pytorch中的hook机制是什么
pytorch中的hook机制是什么
本篇内容介绍了“pytorch中的hook机制是什么”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!
1、hook背景
Hook
被成为钩子机制,这不是pytorch的首创,在Windows
的编程中已经被普遍采用,包括进程内钩子和全局钩子。按照自己的理解,hook的作用是通过系统来维护一个链表,使得用户拦截(获取)通信消息,用于处理事件。
pytorch中包含forward
和backward
两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。
2、源码阅读
register_forward_hook()
函数必须在forward()函数调用之前被使用,因为这个函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:`forward` is called”,也就是这个函数在forward()之后就没有作用了!!!):
作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。
defregister_forward_hook(self,hook):r"""Registersaforwardhookonthemodule.Thehookwillbecalledeverytimeafter:func:`forward`hascomputedanoutput.Itshouldhavethefollowingsignature::hook(module,input,output)->NoneormodifiedoutputThehookcanmodifytheoutput.Itcanmodifytheinputinplacebutitwillnothaveeffectonforwardsincethisiscalledafter:func:`forward`iscalled.Returns::class:`torch.utils.hooks.RemovableHandle`:ahandlethatcanbeusedtoremovetheaddedhookbycalling``handle.remove()``"""handle=hooks.RemovableHandle(self._forward_hooks)self._forward_hooks[handle.id]=hookreturnhandle
3、定义一个用于测试hooker的类
如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward
中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),__init__()
中调用initialize
函数对所有层进行初始化。
注意:在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。
classTestForHook(nn.Module):def__init__(self):super().__init__()self.linear_1=nn.Linear(in_features=2,out_features=2)self.linear_2=nn.Linear(in_features=2,out_features=1)self.relu=nn.ReLU()self.relu6=nn.ReLU6()self.initialize()defforward(self,x):linear_1=self.linear_1(x)linear_2=self.linear_2(linear_1)relu=self.relu(linear_2)relu_6=self.relu6(relu)layers_in=(x,linear_1,linear_2)layers_out=(linear_1,linear_2,relu)returnrelu_6,layers_in,layers_outdefinitialize(self):"""定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight=torch.nn.Parameter(torch.FloatTensor([[1,1],[1,1]]))self.linear_1.bias=torch.nn.Parameter(torch.FloatTensor([1,1]))self.linear_2.weight=torch.nn.Parameter(torch.FloatTensor([[1,1]]))self.linear_2.bias=torch.nn.Parameter(torch.FloatTensor([1]))returnTrue
4、定义hook函数
hook()
函数是register_forward_hook()
函数必须提供的参数,好处是“用户可以自行决定拦截了中间信息之后要做什么!”,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。
首先定义几个容器用于记录:
定义用于获取网络各层输入输出tensor的容器:
#并定义module_name用于记录相应的module名字module_name=[]features_in_hook=[]features_out_hook=[]hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:
hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字
defhook(module,fea_in,fea_out):print("hookerworking")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)returnNone
5、对需要的层注册hook
注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):
注册钩子可以对某些层单独进行:
net=TestForHook()net_chilren=net.children()forchildinnet_chilren:ifnotisinstance(child,nn.ReLU6):child.register_forward_hook(hook=hook)
6、测试forward()返回的特征和hook记录的是否一致
6.1 测试forward()提供的输入输出特征
由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:
out,features_in_forward,features_out_forward=net(x)print("*"*5+"forwardreturnfeatures"+"*"*5)print(features_in_forward)print(features_out_forward)print("*"*5+"forwardreturnfeatures"+"*"*5)
得到下面的输出是理所当然的:
*****forward return features*****
(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****
6.2 hook记录的输入特征和输出特征
hook通过list结构进行记录,所以可以直接print
测试features_in是不是存储了输入:
print("*"*5+"hookrecordfeatures"+"*"*5)print(features_in_hook)print(features_out_hook)print(module_name)print("*"*5+"hookrecordfeatures"+"*"*5)
得到和forward一样的结果:
*****hook record features*****
[(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****
6.3 把hook记录的和forward做减法
如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook
记录的特征和forward记录的特征做减法:
测试forward返回的feautes_in是不是和hook记录的一致:
print("subresult'")forforward_return,hook_recordinzip(features_in_forward,features_in_hook):print(forward_return-hook_record[0])
得到的全部都是0,说明hook没问题:
subresulttensor([[0.,0.],[0.,0.]])tensor([[0.,0.],[0.,0.]],grad_fn=<SubBackward0>)tensor([[0.],[0.]],grad_fn=<SubBackward0>)
7、完整代码
importtorchimporttorch.nnasnnclassTestForHook(nn.Module):def__init__(self):super().__init__()self.linear_1=nn.Linear(in_features=2,out_features=2)self.linear_2=nn.Linear(in_features=2,out_features=1)self.relu=nn.ReLU()self.relu6=nn.ReLU6()self.initialize()defforward(self,x):linear_1=self.linear_1(x)linear_2=self.linear_2(linear_1)relu=self.relu(linear_2)relu_6=self.relu6(relu)layers_in=(x,linear_1,linear_2)layers_out=(linear_1,linear_2,relu)returnrelu_6,layers_in,layers_outdefinitialize(self):"""定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight=torch.nn.Parameter(torch.FloatTensor([[1,1],[1,1]]))self.linear_1.bias=torch.nn.Parameter(torch.FloatTensor([1,1]))self.linear_2.weight=torch.nn.Parameter(torch.FloatTensor([[1,1]]))self.linear_2.bias=torch.nn.Parameter(torch.FloatTensor([1]))returnTrue
定义用于获取网络各层输入输出tensor
的容器,并定义module_name
用于记录相应的module名字
module_name=[]features_in_hook=[]features_out_hook=[]
hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字
defhook(module,fea_in,fea_out):print("hookerworking")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)returnNone
定义全部是1的输入:
x=torch.FloatTensor([[0.1,0.1],[0.1,0.1]])
注册钩子可以对某些层单独进行:
net=TestForHook()net_chilren=net.children()forchildinnet_chilren:ifnotisinstance(child,nn.ReLU6):child.register_forward_hook(hook=hook)
测试网络输出:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)
测试features_in是不是存储了输入:
print("*"*5+"hookrecordfeatures"+"*"*5)print(features_in_hook)print(features_out_hook)print(module_name)print("*"*5+"hookrecordfeatures"+"*"*5)
测试forward返回的feautes_in是不是和hook记录的一致:
print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
print(forward_return-hook_record[0])
“pytorch中的hook机制是什么”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注恰卡编程网网站,小编将为大家输出更多高质量的实用文章!
推荐阅读
-
Pytorch中model.eval()的作用是什么
Pytorch中model.eval()的作用是什么这篇文章主要介...
-
怎么使用pytorch读取数据集
-
React中常用的两个Hook是什么
React中常用的两个Hook是什么这篇文章给大家分享的是有关Re...
-
pytorch中的view()函数怎么使用
pytorch中的view()函数怎么使用这篇文章主要介绍了pyt...
-
PyTorch中的torch.cat怎么用
PyTorch中的torch.cat怎么用这篇文章主要介绍PyTo...
-
pytorch中的广播语义是什么
pytorch中的广播语义是什么这篇文章主要介绍“pytorch中...
-
PyTorch梯度下降反向传播实例分析
-
python中的Pytorch建模流程是什么
python中的Pytorch建模流程是什么小编给大家分享一下py...
-
pytorch如何实现多项式回归
-
pytorch怎样实现线性回归