本文主要是介绍利用PyTorch构建三层线性网络完成对MNIST数据集识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
在这里首先简单介绍一些MNIST数据集:
MNIST数据集内共包含70000张手写数字图像,数字范围0~9,大小为28*28,其中60000张用于训练学习,10000张用于数据测试,图像为灰度图像,数字位置居中,可以减少预处理和加快运行
在学习编程入门时,无论哪个语言,hello world往往是第一步,再进行深度学习入门时,MNIST数据集研究透了,基本就可以入门了。
下面向大家展示一下MNIST数据集内图像:
--------------------------------------------------------------------------------------------------------------------------------
下面进入今天的正题:利用PyTorch搭建一个三层线性网络,完成对MNIST数据集的训练并且进行测试 :
本次demo包含两个目录文件,一个是utils.py,另一个是mnist_train.py,在utils.py内我们放置了三个函数,分别是plot_curve,plot_image和one_hot。
plot_curve:用于绘制对MNIST数据集进行训练时损失函数曲线,方便观察
def plot_curve(data):fig = plt.figure()plt.plot(range(len(data)),data,color = 'blue')plt.legend(['value'],loc = 'upper right')plt.xlabel('step')plt.ylabel('value')plt.show()
plot_image:对于训练和识别过程,可以很方便的将训练结果可视化
def plot_image(img,label,name):fig = plt.figure()for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')plt.title("{}:{}".format(name,label[i].item()))plt.xticks([])plt.yticks([])plt.show()
one_hot:PyTorch内还没有对one-hot函数的实现,在这用scatter完成简单的一个编码
注:one_hot编码:
One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。
One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。
def one_hot(label,depth=10):out = torch.zeros(label.size(0),depth)idx = torch.LongTensor(label).view(-1,1)out.scatter_(dim = 1,index = idx,value = 1)return out
这样,我们的utils.py里面的三个工具函数就已经编码完毕,这三个函数只是达到辅助可视化的作用,不会对训练和测试产生任何影响,所以大家如果图省事,放在一个函数里也可以
下面我们实现MNIST的train:
1.导入相关包:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from utils import plot_image,plot_curve,one_hot
2.准备数据集:
在这儿我们使用DataLoader完成数据集的下载,是一种十分方便的方式,这个地方是通过torch vision实现的。
torchvision是PyTorch的一个图形库,服务于PyTorch深度学习框架,构建计算机视觉模型
torchvision.transforms:常用的图像预处理方法,利用Compose将对图片的操作整合起来
torchvision.datasets:常用的datasets数据集实现,如MNIST,CIFAR10等
torchvision.model:常用的模型预训练,如LeNet,ResNet,VGG等
解释一下这里的ToTensor:数据归一化到均值为0,方差为1(是将数据除以255),即图像进来以后,先进行通道转换,然后判断图像类型,若是uint8类型,就除以255;否则返回原图。
而这里的Normalize是对数据按通道进行标准化,即减去均值,再除以方差
其中,0.1307和0.3081是mnist数据集的均值和标准差,因为mnist数据值都是灰度图,所以图像的通道数只有一个,因此均值和标准差各一个。要是imagenet数据集的话,由于它的图像都是RGB图像,因此他们的均值和标准差各3个,分别对应其R,G,B值。例如([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])就是Imagenet dataset的标准化系数(RGB三个通道对应三组系数)。数据集给出的均值和标准差系数,每个数据集都不同的,都是数据集提供方给出的。
transforms.Normalize(mean,std)的计算公式是:
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)
因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。
那transform.Normalize()是怎么工作的呢?以上面代码为例,ToTensor()能够把灰度范围从0-255变换到0-1之间,而后面的transform.Normalize()则把0-1变换到(-1,1)
3.创建网络
class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return x
这里采用的是最最最基本的线性层连接,共有三层,激活函数采用的relu函数
在这里Linear会把28*28的图像铺平展开为784个元素的一维数组后进行处理,在forward内不断传给下一层。
这个地方只写了前向传播的forward(),并没有写反向传播的backward()是因为在pytorch的求导过程中,有以下两种情况:
如果是标量对向量求导(scalar对tensor求导),那么就可以保证上面的计算图的根节点只有一个,此时不用引入grad_tensors参数,直接调用backward函数即可
如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。
4.Train
for epoch in range(5):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot) #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)
经过五轮peoch完成对MNIST60000张图片的训练,并将每轮结果打印出来,将损失函数记录,并调用plot_curve展现train_loss下降折线图
5.Test
total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')
预测值pred取结果中间概率最大的index值作为他的label,可以用过argmax返回最大值的索引,上述代码即可以理解为在dim=1处,取最大值索引
而正确值correct是将y和pred之间做一个比较,利用sum()可得到当前batch中预测结果正确的一个总个数,最终是Tensor类型,再将其转换为数据类型,加上item(),最后total_correct累加
之后就是调用工具函数将数据可视化
mnist_train.py完整代码如下:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from utils import plot_image,plot_curve,one_hotbatch_size = 512 #GPU单次运行处理图片的数量 批处理大小
#step1.load mnist
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)x,y = next(iter((train_loader)))
print(x.shape,y.shape,x.min(),x.max())
plot_image(x,y,'image sample')#step2.create network
class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return xnet = Net()
#[w1,w2,w3,b1,b2,b3]
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)train_loss = []for epoch in range(3):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot) #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)
# we get optimal [w1,b1,w2,b2,w3,b3]total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')
下面将结果展示一下
上图为读取mnist_train中的数据,展示图片以及sample
这是损失函数train_loss的图像,可以看到,随着学习的深入,损失值是在不断下降的,最后逐渐趋于稳定,但是本次demo只是利用最简单的三层结构,且利用的SGD,在其它方法的训练下,可能会获得更好的训练效果
识别测试结果,可以看到,准确率还是比较高的,最后的acc大概达到了接近90%
这就是本次的小demo,有很多理解自己也不算吃的很透,随着学习的深入会理解的更加透彻
最后,感谢各位看官,欢迎批评,互相学习进步!
这篇关于利用PyTorch构建三层线性网络完成对MNIST数据集识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!