pytorch中的钩子(Hook)
首先明确一点,有哪些hook?
1. torch.autograd.Variable.register_hook (Python method, in Automatic differentiation package
2. torch.nn.Module.register_backward_hook (Python method, in torch.nn)
3. torch.nn.Module.register_forward_hook
第一个是register_hook,是针对Variable对象的,后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。
也就是说,这个函数是拥有改变梯度值的威力的!
至于register_forward_hook和register_backward_hook的用法和这个大同小异。只不过对象从Variable改成了你自己定义的nn.Module。
当你训练一个网络,想要提取中间层的参数、或者特征图的时候,使用hook就能派上用场了
相当于插件。可以实现一些额外的功能,而又不用修改主体代码。把这些额外功能实现了挂在主代码上,所以叫钩子,很形象。
一、Hook函数概念
Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。
Hook函数机制:不改变主体,实现额外的功能,像一个挂件一样;