怎么在Pytorch中只导入部分模型参数
怎么在Pytorch中只导入部分模型参数?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。
pytorch的优点
1.PyTorch是相当简洁且高效快速的框架;2.设计追求最少的封装;3.设计符合人类思维,它让用户尽可能地专注于实现自己的想法;4.与google的Tensorflow类似,FAIR的支持足以确保PyTorch获得持续的开发更新;5.PyTorch作者亲自维护的论坛 供用户交流和求教问题6.入门简单
importtorchast fromtorch.nnimportModule fromtorchimportnn fromtorch.nnimportfunctionalasF classNet(Module): def__init__(self): super(Net,self).__init__() self.conv1=nn.Conv2d(3,32,3,1) self.conv2=nn.Conv2d(32,3,3,1) self.w=nn.Parameter(t.randn(3,10)) forpinself.children(): nn.init.xavier_normal_(p.weight.data) nn.init.constant_(p.bias.data,0) defforward(self,x): out=self.conv1(x) out=self.conv2(x) out=F.avg_pool2d(out,(out.shape[2],out.shape[3])) out=F.linear(out,weight=self.w) returnout
然后我们保存这个网络的初始值。
model=Net() t.save(model.state_dict(),'xxx.pth')
现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。
importtorchast fromtorch.nnimportModule fromtorchimportnn fromtorch.nnimportfunctionalasF classNet(Module): def__init__(self): super(Net,self).__init__() self.conv1=nn.Conv2d(3,32,3,1) self.conv2=nn.Conv2d(32,3,3,1) self.conv3=nn.Conv2d(3,64,3,1) self.conv4=nn.Conv2d(64,32,3,1) forpinself.children(): nn.init.xavier_normal_(p.weight.data) nn.init.constant_(p.bias.data,0) self.w=nn.Parameter(t.randn(3,10)) defforward(self,x): out=self.conv1(x) out=self.conv2(x) out=F.avg_pool2d(out,(out.shape[2],out.shape[3])) out=F.linear(out,weight=self.w) returnout
我们现在试着导入之前保存的模型参数。
path='xxx.pth' model=Net() model.load_state_dict(t.load(path)) ''' RuntimeError:Error(s)inloadingstate_dictforNet: Missingkey(s)instate_dict:"conv3.weight","conv3.bias","conv4.weight","conv4.bias". '''
出现了没有在模型文件中找到error中的关键字的错误。
现在我们这样导入模型
path='xxx.pth' model=Net() save_model=t.load(path) model_dict=model.state_dict() state_dict={k:vfork,vinsave_model.items()ifkinmodel_dict.keys()} print(state_dict.keys())#dict_keys(['w','conv1.weight','conv1.bias','conv2.weight','conv2.bias']) model_dict.update(state_dict) model.load_state_dict(model_dict)
看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。
为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。
forkinstate_dict.keys(): print(k) ''' w conv1.weight conv1.bias conv2.weight conv2.bias ''' forkinmodel_dict.keys(): print(k) ''' w conv1.weight conv1.bias conv2.weight conv2.bias conv3.weight conv3.bias conv4.weight conv4.bias '''
关于怎么在Pytorch中只导入部分模型参数问题的解答就分享到这里了,希望以上内容可以对大家有一定的帮助,如果你还有很多疑惑没有解开,可以关注恰卡编程网行业资讯频道了解更多相关知识。
推荐阅读
-
Pytorch中model.eval()的作用是什么
Pytorch中model.eval()的作用是什么这篇文章主要介...
-
怎么使用pytorch读取数据集
-
pytorch中的view()函数怎么使用
pytorch中的view()函数怎么使用这篇文章主要介绍了pyt...
-
PyTorch中的torch.cat怎么用
PyTorch中的torch.cat怎么用这篇文章主要介绍PyTo...
-
pytorch中的hook机制是什么
pytorch中的hook机制是什么本篇内容介绍了“pytorch...
-
pytorch中的广播语义是什么
pytorch中的广播语义是什么这篇文章主要介绍“pytorch中...
-
PyTorch梯度下降反向传播实例分析
-
python中的Pytorch建模流程是什么
python中的Pytorch建模流程是什么小编给大家分享一下py...
-
pytorch如何实现多项式回归
-
pytorch怎样实现线性回归