Pytorch中实现只导入部分模型参数的方式
我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?
好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。
请看下面的一个例子。
我们先搭建一个小小的网络。
import torch as t from torch.nn import Module from torch import nn from torch.nn import functional as F class Net(Module): def __init__(self): super(Net,self).__init__() self.conv1 = nn.Conv2d(3,32,3,1) self.conv2 = nn.Conv2d(32,3,3,1) self.w = nn.Parameter(t.randn(3,10)) for p in self.children(): nn.init.xavier_normal_(p.weight.data) nn.init.constant_(p.bias.data, 0) def forward(self, x): out = self.conv1(x) out = self.conv2(x) out = F.avg_pool2d(out,(out.shape[2],out.shape[3])) out = F.linear(out,weight=self.w) return out
然后我们保存这个网络的初始值。
model = Net() t.save(model.state_dict(),'xxx.pth')
现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。
import torch as t from torch.nn import Module from torch import nn from torch.nn import functional as F class Net(Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 32, 3, 1) self.conv2 = nn.Conv2d(32, 3, 3, 1) self.conv3 = nn.Conv2d(3,64,3,1) self.conv4 = nn.Conv2d(64,32,3,1) for p in self.children(): nn.init.xavier_normal_(p.weight.data) nn.init.constant_(p.bias.data, 0) self.w = nn.Parameter(t.randn(3, 10)) def forward(self, x): out = self.conv1(x) out = self.conv2(x) out = F.avg_pool2d(out, (out.shape[2], out.shape[3])) out = F.linear(out, weight=self.w) return out
我们现在试着导入之前保存的模型参数。
path = 'xxx.pth' model = Net() model.load_state_dict(t.load(path)) ''' RuntimeError: Error(s) in loading state_dict for Net: Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". '''
出现了没有在模型文件中找到error中的关键字的错误。
现在我们这样导入模型
path = 'xxx.pth' model = Net() save_model = t.load(path) model_dict = model.state_dict() state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()} print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias']) model_dict.update(state_dict) model.load_state_dict(model_dict)
看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。
为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。
for k in state_dict.keys(): print(k) ''' w conv1.weight conv1.bias conv2.weight conv2.bias ''' for k in model_dict.keys(): print(k) ''' w conv1.weight conv1.bias conv2.weight conv2.bias conv3.weight conv3.bias conv4.weight conv4.bias '''
这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。
以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持猪先飞。
相关文章
pytorch nn.Conv2d()中的padding以及输出大小方式
今天小编就为大家分享一篇pytorch nn.Conv2d()中的padding以及输出大小方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27- 这篇文章主要介绍了PyTorch一小时掌握之迁移学习篇,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-09-08
Linux安装Pytorch1.8GPU(CUDA11.1)的实现
这篇文章主要介绍了Linux安装Pytorch1.8GPU(CUDA11.1)的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-03-25- 这篇文章主要介绍了Pytorch之扩充tensor的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-05
- 今天小编就为大家分享一篇pytorch 自定义卷积核进行卷积操作方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-06
详解在IDEA中将Echarts引入web两种方式(使用js文件和maven的依赖导入)
这篇文章主要介绍了在IDEA中将Echarts引入web两种方式(使用js文件和maven的依赖导入),本文通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2020-07-11- 这篇文章主要介绍了解决pytorch 交叉熵损失输出为负数的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-07-08
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
今天小编就为大家分享一篇pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了pytorch 实现冻结部分参数训练另一部分,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-27
从Pytorch模型pth文件中读取参数成numpy矩阵的操作
这篇文章主要介绍了从Pytorch模型pth文件中读取参数成numpy矩阵的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2021-03-04Pytorch 的损失函数Loss function使用详解
今天小编就为大家分享一篇Pytorch 的损失函数Loss function使用详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 今天小编就为大家分享一篇pytorch中的上采样以及各种反操作,求逆操作详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-30
- 今天小编就为大家分享一篇Pytorch实现LSTM和GRU示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
- 这篇文章主要给大家总结介绍了R语言导入导出数据的几种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧...2021-05-06
- 这篇文章主要介绍了基于Pytorch版yolov5的滑块验证码破解思路详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...2021-02-25
解决Pytorch dataloader时报错每个tensor维度不一样的问题
这篇文章主要介绍了解决Pytorch dataloader时报错每个tensor维度不一样的问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-28pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
今天小编就为大家分享一篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-05-02- 这篇文章主要介绍了pytorch深度学习中对softmax实现进行了详细解析,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步...2021-09-30
- 今天小编就为大家分享一篇Pytorch 计算误判率,计算准确率,计算召回率的例子,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...2020-04-27
- 这篇文章主要介绍了pytorch中的squeeze函数、cat函数使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教...2021-05-20