利用PyTorch构建三层线性网络完成对MNIST数据集识别

2024-01-02 03:48

本文主要是介绍利用PyTorch构建三层线性网络完成对MNIST数据集识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里首先简单介绍一些MNIST数据集:

MNIST数据集内共包含70000张手写数字图像,数字范围0~9,大小为28*28,其中60000张用于训练学习,10000张用于数据测试,图像为灰度图像,数字位置居中,可以减少预处理和加快运行

在学习编程入门时,无论哪个语言,hello world往往是第一步,再进行深度学习入门时,MNIST数据集研究透了,基本就可以入门了。

下面向大家展示一下MNIST数据集内图像:

--------------------------------------------------------------------------------------------------------------------------------

下面进入今天的正题:利用PyTorch搭建一个三层线性网络,完成对MNIST数据集的训练并且进行测试 :

本次demo包含两个目录文件,一个是utils.py,另一个是mnist_train.py,在utils.py内我们放置了三个函数,分别是plot_curve,plot_image和one_hot。

plot_curve:用于绘制对MNIST数据集进行训练时损失函数曲线,方便观察

def plot_curve(data):fig = plt.figure()plt.plot(range(len(data)),data,color = 'blue')plt.legend(['value'],loc = 'upper right')plt.xlabel('step')plt.ylabel('value')plt.show()

plot_image:对于训练和识别过程,可以很方便的将训练结果可视化

def plot_image(img,label,name):fig = plt.figure()for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')plt.title("{}:{}".format(name,label[i].item()))plt.xticks([])plt.yticks([])plt.show()

one_hot:PyTorch内还没有对one-hot函数的实现,在这用scatter完成简单的一个编码

注:one_hot编码:

One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。

One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。

def one_hot(label,depth=10):out = torch.zeros(label.size(0),depth)idx = torch.LongTensor(label).view(-1,1)out.scatter_(dim = 1,index = idx,value = 1)return out

这样,我们的utils.py里面的三个工具函数就已经编码完毕,这三个函数只是达到辅助可视化的作用,不会对训练和测试产生任何影响,所以大家如果图省事,放在一个函数里也可以

下面我们实现MNIST的train:

1.导入相关包:

import torch
from torch import nn
from torch.nn import functional as F
from torch import  optim
import torchvision
from utils import plot_image,plot_curve,one_hot

2.准备数据集: 

在这儿我们使用DataLoader完成数据集的下载,是一种十分方便的方式,这个地方是通过torch vision实现的。

torchvision是PyTorch的一个图形库,服务于PyTorch深度学习框架,构建计算机视觉模型

torchvision.transforms:常用的图像预处理方法,利用Compose将对图片的操作整合起来

torchvision.datasets:常用的datasets数据集实现,如MNIST,CIFAR10等

torchvision.model:常用的模型预训练,如LeNet,ResNet,VGG等

解释一下这里的ToTensor:数据归一化到均值为0,方差为1(是将数据除以255),即图像进来以后,先进行通道转换,然后判断图像类型,若是uint8类型,就除以255;否则返回原图。

而这里的Normalize是对数据按通道进行标准化,即减去均值,再除以方差

其中,0.1307和0.3081是mnist数据集的均值和标准差,因为mnist数据值都是灰度图,所以图像的通道数只有一个,因此均值和标准差各一个。要是imagenet数据集的话,由于它的图像都是RGB图像,因此他们的均值和标准差各3个,分别对应其R,G,B值。例如([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])就是Imagenet dataset的标准化系数(RGB三个通道对应三组系数)。数据集给出的均值和标准差系数,每个数据集都不同的,都是数据集提供方给出的。

transforms.Normalize(mean,std)的计算公式是:input=\frac{input - mean}{std}

train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)

因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。

那transform.Normalize()是怎么工作的呢?以上面代码为例,ToTensor()能够把灰度范围从0-255变换到0-1之间,而后面的transform.Normalize()则把0-1变换到(-1,1)

3.创建网络

class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return x

这里采用的是最最最基本的线性层连接,共有三层,激活函数采用的relu函数

在这里Linear会把28*28的图像铺平展开为784个元素的一维数组后进行处理,在forward内不断传给下一层。

这个地方只写了前向传播的forward(),并没有写反向传播的backward()是因为在pytorch的求导过程中,有以下两种情况:

如果是标量对向量求导(scalar对tensor求导),那么就可以保证上面的计算图的根节点只有一个,此时不用引入grad_tensors参数,直接调用backward函数即可
如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。

4.Train

for epoch in range(5):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot)   #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)

经过五轮peoch完成对MNIST60000张图片的训练,并将每轮结果打印出来,将损失函数记录,并调用plot_curve展现train_loss下降折线图

5.Test

total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')

 预测值pred取结果中间概率最大的index值作为他的label,可以用过argmax返回最大值的索引,上述代码即可以理解为在dim=1处,取最大值索引

而正确值correct是将y和pred之间做一个比较,利用sum()可得到当前batch中预测结果正确的一个总个数,最终是Tensor类型,再将其转换为数据类型,加上item(),最后total_correct累加

之后就是调用工具函数将数据可视化

mnist_train.py完整代码如下:

import torch
from torch import nn
from torch.nn import functional as F
from torch import  optim
import torchvision
from utils import plot_image,plot_curve,one_hotbatch_size = 512   #GPU单次运行处理图片的数量         批处理大小
#step1.load mnist
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)x,y = next(iter((train_loader)))
print(x.shape,y.shape,x.min(),x.max())
plot_image(x,y,'image sample')#step2.create network
class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return xnet = Net()
#[w1,w2,w3,b1,b2,b3]
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)train_loss = []for epoch in range(3):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot)   #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)
# we get optimal [w1,b1,w2,b2,w3,b3]total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')

下面将结果展示一下

上图为读取mnist_train中的数据,展示图片以及sample

 

 这是损失函数train_loss的图像,可以看到,随着学习的深入,损失值是在不断下降的,最后逐渐趋于稳定,但是本次demo只是利用最简单的三层结构,且利用的SGD,在其它方法的训练下,可能会获得更好的训练效果

识别测试结果,可以看到,准确率还是比较高的,最后的acc大概达到了接近90%

这就是本次的小demo,有很多理解自己也不算吃的很透,随着学习的深入会理解的更加透彻

最后,感谢各位看官,欢迎批评,互相学习进步! 

 

这篇关于利用PyTorch构建三层线性网络完成对MNIST数据集识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561262

相关文章

Python获取中国节假日数据记录入JSON文件

《Python获取中国节假日数据记录入JSON文件》项目系统内置的日历应用为了提升用户体验,特别设置了在调休日期显示“休”的UI图标功能,那么问题是这些调休数据从哪里来呢?我尝试一种更为智能的方法:P... 目录节假日数据获取存入jsON文件节假日数据读取封装完整代码项目系统内置的日历应用为了提升用户体验,

Java利用JSONPath操作JSON数据的技术指南

《Java利用JSONPath操作JSON数据的技术指南》JSONPath是一种强大的工具,用于查询和操作JSON数据,类似于SQL的语法,它为处理复杂的JSON数据结构提供了简单且高效... 目录1、简述2、什么是 jsONPath?3、Java 示例3.1 基本查询3.2 过滤查询3.3 递归搜索3.4

MySQL大表数据的分区与分库分表的实现

《MySQL大表数据的分区与分库分表的实现》数据库的分区和分库分表是两种常用的技术方案,本文主要介绍了MySQL大表数据的分区与分库分表的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. mysql大表数据的分区1.1 什么是分区?1.2 分区的类型1.3 分区的优点1.4 分

一文详解如何从零构建Spring Boot Starter并实现整合

《一文详解如何从零构建SpringBootStarter并实现整合》SpringBoot是一个开源的Java基础框架,用于创建独立、生产级的基于Spring框架的应用程序,:本文主要介绍如何从... 目录一、Spring Boot Starter的核心价值二、Starter项目创建全流程2.1 项目初始化(

Mysql删除几亿条数据表中的部分数据的方法实现

《Mysql删除几亿条数据表中的部分数据的方法实现》在MySQL中删除一个大表中的数据时,需要特别注意操作的性能和对系统的影响,本文主要介绍了Mysql删除几亿条数据表中的部分数据的方法实现,具有一定... 目录1、需求2、方案1. 使用 DELETE 语句分批删除2. 使用 INPLACE ALTER T

Python Dash框架在数据可视化仪表板中的应用与实践记录

《PythonDash框架在数据可视化仪表板中的应用与实践记录》Python的PlotlyDash库提供了一种简便且强大的方式来构建和展示互动式数据仪表板,本篇文章将深入探讨如何使用Dash设计一... 目录python Dash框架在数据可视化仪表板中的应用与实践1. 什么是Plotly Dash?1.1

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Redis 中的热点键和数据倾斜示例详解

《Redis中的热点键和数据倾斜示例详解》热点键是指在Redis中被频繁访问的特定键,这些键由于其高访问频率,可能导致Redis服务器的性能问题,尤其是在高并发场景下,本文给大家介绍Redis中的热... 目录Redis 中的热点键和数据倾斜热点键(Hot Key)定义特点应对策略示例数据倾斜(Data S

Python实现将MySQL中所有表的数据都导出为CSV文件并压缩

《Python实现将MySQL中所有表的数据都导出为CSV文件并压缩》这篇文章主要为大家详细介绍了如何使用Python将MySQL数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到... python将mysql数据库中所有表的数据都导出为CSV文件到一个目录,并压缩为zip文件到另一个