本文主要是介绍register_backward_hook()和register_forward_hook(),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
结论:
一:register_forward_hook()在指定网络层执行完前向传播后调用钩子函数
二:
1:register_backward_hook()在指定网络层执行完backward()之后调用钩子函数
2:register_backward_hook()返回的grad_input是关于所有输入变量的梯度,也就是说grad_input是个元组,包含有对该层网络的权重weight的梯度,偏置bias的梯度,以及该层输入x的梯度
3:grad_input元组中,关于权重,偏置和输入x的梯度的顺序,不同网络层是不一样的,比如
对于nn.Linear层,grad_input是按照bias的梯度,x的梯度,weight的梯度排列的。而在nn.Conv2d()层,grad_input是按照x的梯度,weight的梯度,bias的梯度排列的。
最后,这两个函数都会返回一个句柄handle,这个handle有一个remove()方法,用于将钩子函数从网络中去除。
一:
register_forward_hook(hook_fuc)中的hook_fuc函数需要有三个hook_func(model, input, output)这里的input和output是比较好理解的,因为是前向传播,所以input就是输入网络层的输入,output就是该层网络的输出。(注意,hook_func是在该层网络前向传播完成以后执行)
二:
1:register_backward_hook(hook_func),首先要明确,这里的hook_func只有在网络执行backward()之后才会调用,这里的hook_func(model,grad_input,grad_output)也包含三个参数,model即需要调用的网络层,grad_input是该层网络的所有输入的梯度,也就是包含有该层网络输入偏差的梯度,该层网络输入变量x的梯度(也就是变量x所有权重之和),以及改成网络权重的梯度(也就是输入x);而grad_output是指该层网络输出的梯度(这里可能有点疑问,接下来重点讲一下)。
2:在nn.Linear层可以看到,这里的grad_input实际上是一个元组,包含三a:biasb:输入变量x的梯度(实际上就是weight的c:模型权重weight的梯度(实际上就是x)。
注意,我们这里y在求导的时候,是给了权重[[2,1]]的,计算过程如下:
3:我们可以看在nn.Conv2d层:结果如下
这篇关于register_backward_hook()和register_forward_hook()的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!