本文主要是介绍pytorch使用DataParallel并行化保存和加载模型(单卡、多卡各种情况讲解),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
话不多说,直接进入正题。
!!!不过要注意一点,本文保存模型采用的都是只保存模型参数的情况,而不是保存整个模型的情况。一定要看清楚再用啊!
1 单卡训练,单卡加载
#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()#MyModel()是你定义的创建模型的函数,就是先初始化得到一个模型实例,之后再将模型参数加载到该实例上
model.load_state_dict(torch.load('model.pt'))
2 单卡训练,多卡加载
保存模型的过程同第一种情况一样,但是要注意,多卡加载模型时, 是先加载模型参数,再对模型做并行化处理。
#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()
model.load_state_dict(torch.load('model.pt'))model=nn.DataParallel(model)#将模型进行并行化处理
3 多卡保存,单卡加载
方法一:
考虑到之后可能需要单卡加载你多卡训练的模型,所以建议在保存的时候,要去除模型参数字典里面的module,即使用model.module.state_dict()代替model.state_dict()来进行去除。
因为是单卡加载,所以还是要先加载 模型参数,再对模型做并行化处理。
#保存模型
torch.save(model.module.state_dict(),'modle.pt')#加载模型
model=MyModel()
model.load_state_dict(torch.load('model.pt'))model=nn.DataParallel(model)
方法二:
仍然使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行),加载的时候要注意:由于我们保存到 方式是以多卡方式保存的,所以无论加载之后的模型是 在答案卡上运行还是在多卡上运行,都要先把模型并行化处理,然后再去加载模型。
#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()model=nn.DataParallel(model)model.load_state_dict(torch.load('model.pt'))
4 多卡保存,多卡加载
这里保存模型采用”多卡保存,单卡加载“的第二种方法,加载的时候,要先把模型做并行化(在多卡上并行),然后再加载。
#保存模型
torch.save(model.state_dict(),'model.pt')#加载模型
model=MyModel()model=nn.DataParallel(model)model.load_state_dict(torch.load('model.pt'))
希望以上内容能够帮助到你,这里是希望你能越来越好的 小白冲鸭 ~~~
这篇关于pytorch使用DataParallel并行化保存和加载模型(单卡、多卡各种情况讲解)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!