本文主要是介绍Error(s) in loading state_dict for XXX Unexpected key(s) in state_dict, 找不到num_batches_tracked,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
今天在训练的时候发现加载模型的时候提示找不到num_batches_tracked,感到奇怪,因为之前已经成功训练过一次了怎么这次就报错了呢,后来发现,第一次训练的时候我用的是0.4.0的pytorch,这次用的是1.0的Pytorch,因为torch的版本不一样引起的问题
KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'
得到类似这样的报错
以下参考自这篇文章 https://zhuanlan.zhihu.com/p/91485607
经过研究发现,在pytorch 0.4.1及后面的版本里,BatchNorm层新增了num_batches_tracked参数,用来统计训练时的forward过的batch数目,源码如下(pytorch0.4.1):
if self.training and self.track_running_stats:self.num_batches_tracked += 1if self.momentum is None: # use cumulative moving averageexponential_average_factor = 1.0 / self.num_batches_tracked.item()else: # use exponential moving averageexponential_average_factor = self.momentum
知道原因就知道怎么处理了,我自己的模型里没有num_batches_tracked这个键,要把我预训练模型里的这个键给剔除掉
这是我对我文件里做的修改,注释掉的那行是原来的代码,可以对比一下 新增加的三行和原来的这行,就是简单的做了一个字典删除
这篇关于Error(s) in loading state_dict for XXX Unexpected key(s) in state_dict, 找不到num_batches_tracked的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!