本文主要是介绍莫凡Pytorch学习笔记(五),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
Pytorch模型保存与提取
本篇笔记主要对应于莫凡Pytorch中的3.4节。主要讲了如何使用Pytorch保存和提取我们的神经网络。
我们将通过两种方式展示模型的保存和提取。
第一种保存方式是保存整个模型,在重新提取时直接加载整个模型。第二种保存方法是只保存模型的参数,这种方式只保存了参数,而不会保存模型的结构等信息。
两种方式各有优缺点。保存完整模型不需要知道网络的结构,一次性保存一次性读入。缺点是模型比较大时耗时较长,保存的文件也大。而只保存参数的方式存储快捷,保存的文件也小一些,但缺点是丢失了网络的结构信息,恢复模型时需要提前建立一个特定结构的网络再读入参数。
以下使用代码展示。
数据生成与展示
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
这里还是生成一组带有噪声的 y = x 2 y=x^{2} y=x2数据进行回归拟合。
# torch.manual_seed(1) # reproducible# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
基本网络搭建与保存
我们使用nn.Sequential模块来快速搭建一个网络完成回归操作。这里使用两种方式进行保存。
def save():# save net1net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)loss_func = torch.nn.MSELoss()for step in range(100):prediction = net1(x)loss = loss_func(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()# plot resultplt.figure(1, figsize=(10, 3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.savefig("./img/05_save.png")torch.save(net1, 'net.pkl') # entire networktorch.save(net1.state_dict(), 'net_params.pkl') # parameters
在这个save函数中,我们首先使用nn.Sequential模块构建了一个基础的二层神经网络。然后对其进行训练。展示训练结果。之后使用两种方式进行保存。
第一种方式直接保存整个网络,代码为
torch.save(net1, 'net.pkl') # entire network
第二种方式只保存网络参数,代码为
torch.save(net1.state_dict(), 'net_params.pkl') # parameters
对保存的模型进行提取恢复
这里我们为两种不同存储方式保存的模型分别定义恢复提取的函数
首先是对整个网络的提取。直接使用torch.load就可以。
def restore_net():# 提取神经网络net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net2')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.savefig("./img/05_res_net.png")
而对于参数的读取,我们首先需要先搭建好一个与之前保存的模型相同架构的网络。然后使用这个网络的load_state_dict方法进行参数读取和恢复。
def restore_params():# 提取神经网络net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),torch.nn.ReLU(),torch.nn.Linear(10, 1))net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net3')plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)plt.savefig("./img/05_res_para.png")plt.show()
对比不同提取方法的效果
接下来我们对比一下这两种方法的提取效果
# save net1
save()# restore entire net (may slow)
restore_net()# restore only the net parameters
restore_params()
最后,得到的展示输出如下:
这里Net1即我们训练好的网络,我们使用两种方式保存了Net1。使用第一种方式存储和提取的结果为Net2,使用第二种方式存储和提取的结果为Net3。通过对比可以看出,这三个网络一模一样,证明不同的存储提取方式的效果是相同的,不会有差异。
参考
- 莫凡Python:Pytorch动态神经网络,https://mofanpy.com/tutorials/machine-learning/torch/
这篇关于莫凡Pytorch学习笔记(五)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!