欢迎光临散文网 会员登陆 & 注册

pytorch中的钩子(Hook)

2022-10-30 20:37 作者:熊二爱光头强丫  | 我要投稿

首先明确一点,有哪些hook?

1. torch.autograd.Variable.register_hook (Python method, in Automatic differentiation package

2. torch.nn.Module.register_backward_hook (Python method, in torch.nn)

3. torch.nn.Module.register_forward_hook

第一个是register_hook,是针对Variable对象的,后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。

也就是说,这个函数是拥有改变梯度值的威力的!


至于register_forward_hook和register_backward_hook的用法和这个大同小异。只不过对象从Variable改成了你自己定义的nn.Module。

当你训练一个网络,想要提取中间层的参数、或者特征图的时候,使用hook就能派上用场了

相当于插件。可以实现一些额外的功能,而又不用修改主体代码。把这些额外功能实现了挂在主代码上,所以叫钩子,很形象。

一、Hook函数概念

Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。

Hook函数机制:不改变主体,实现额外的功能,像一个挂件一样;


pytorch中的钩子(Hook)的评论 (共 条)

分享到微博请遵守国家法律