利用PyTorch构建三层线性网络完成对MNIST数据集识别

2024-01-02 03:48

本文主要是介绍利用PyTorch构建三层线性网络完成对MNIST数据集识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里首先简单介绍一些MNIST数据集:

MNIST数据集内共包含70000张手写数字图像,数字范围0~9,大小为28*28,其中60000张用于训练学习,10000张用于数据测试,图像为灰度图像,数字位置居中,可以减少预处理和加快运行

在学习编程入门时,无论哪个语言,hello world往往是第一步,再进行深度学习入门时,MNIST数据集研究透了,基本就可以入门了。

下面向大家展示一下MNIST数据集内图像:

--------------------------------------------------------------------------------------------------------------------------------

下面进入今天的正题:利用PyTorch搭建一个三层线性网络,完成对MNIST数据集的训练并且进行测试 :

本次demo包含两个目录文件,一个是utils.py,另一个是mnist_train.py,在utils.py内我们放置了三个函数,分别是plot_curve,plot_image和one_hot。

plot_curve:用于绘制对MNIST数据集进行训练时损失函数曲线,方便观察

def plot_curve(data):fig = plt.figure()plt.plot(range(len(data)),data,color = 'blue')plt.legend(['value'],loc = 'upper right')plt.xlabel('step')plt.ylabel('value')plt.show()

plot_image:对于训练和识别过程,可以很方便的将训练结果可视化

def plot_image(img,label,name):fig = plt.figure()for i in range(6):plt.subplot(2,3,i+1)plt.tight_layout()plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')plt.title("{}:{}".format(name,label[i].item()))plt.xticks([])plt.yticks([])plt.show()

one_hot:PyTorch内还没有对one-hot函数的实现,在这用scatter完成简单的一个编码

注:one_hot编码:

One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。

One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。

def one_hot(label,depth=10):out = torch.zeros(label.size(0),depth)idx = torch.LongTensor(label).view(-1,1)out.scatter_(dim = 1,index = idx,value = 1)return out

这样,我们的utils.py里面的三个工具函数就已经编码完毕,这三个函数只是达到辅助可视化的作用,不会对训练和测试产生任何影响,所以大家如果图省事,放在一个函数里也可以

下面我们实现MNIST的train:

1.导入相关包:

import torch
from torch import nn
from torch.nn import functional as F
from torch import  optim
import torchvision
from utils import plot_image,plot_curve,one_hot

2.准备数据集: 

在这儿我们使用DataLoader完成数据集的下载,是一种十分方便的方式,这个地方是通过torch vision实现的。

torchvision是PyTorch的一个图形库,服务于PyTorch深度学习框架,构建计算机视觉模型

torchvision.transforms:常用的图像预处理方法,利用Compose将对图片的操作整合起来

torchvision.datasets:常用的datasets数据集实现,如MNIST,CIFAR10等

torchvision.model:常用的模型预训练,如LeNet,ResNet,VGG等

解释一下这里的ToTensor:数据归一化到均值为0,方差为1(是将数据除以255),即图像进来以后,先进行通道转换,然后判断图像类型,若是uint8类型,就除以255;否则返回原图。

而这里的Normalize是对数据按通道进行标准化,即减去均值,再除以方差

其中,0.1307和0.3081是mnist数据集的均值和标准差,因为mnist数据值都是灰度图,所以图像的通道数只有一个,因此均值和标准差各一个。要是imagenet数据集的话,由于它的图像都是RGB图像,因此他们的均值和标准差各3个,分别对应其R,G,B值。例如([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])就是Imagenet dataset的标准化系数(RGB三个通道对应三组系数)。数据集给出的均值和标准差系数,每个数据集都不同的,都是数据集提供方给出的。

transforms.Normalize(mean,std)的计算公式是:input=\frac{input - mean}{std}

train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)

因为对RGB图片而言,数据范围是[0-255]的,需要先经过ToTensor除以255归一化到[0,1]之后,再通过Normalize计算过后,将数据归一化到[-1,1]。

那transform.Normalize()是怎么工作的呢?以上面代码为例,ToTensor()能够把灰度范围从0-255变换到0-1之间,而后面的transform.Normalize()则把0-1变换到(-1,1)

3.创建网络

class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return x

这里采用的是最最最基本的线性层连接,共有三层,激活函数采用的relu函数

在这里Linear会把28*28的图像铺平展开为784个元素的一维数组后进行处理,在forward内不断传给下一层。

这个地方只写了前向传播的forward(),并没有写反向传播的backward()是因为在pytorch的求导过程中,有以下两种情况:

如果是标量对向量求导(scalar对tensor求导),那么就可以保证上面的计算图的根节点只有一个,此时不用引入grad_tensors参数,直接调用backward函数即可
如果是(向量)矩阵对(向量)矩阵求导(tensor对tensor求导),实际上是先求出Jacobian矩阵中每一个元素的梯度值(每一个元素的梯度值的求解过程对应上面的计算图的求解方法),然后将这个Jacobian矩阵与grad_tensors参数对应的矩阵进行对应的点乘,得到最终的结果。

4.Train

for epoch in range(5):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot)   #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)

经过五轮peoch完成对MNIST60000张图片的训练,并将每轮结果打印出来,将损失函数记录,并调用plot_curve展现train_loss下降折线图

5.Test

total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')

 预测值pred取结果中间概率最大的index值作为他的label,可以用过argmax返回最大值的索引,上述代码即可以理解为在dim=1处,取最大值索引

而正确值correct是将y和pred之间做一个比较,利用sum()可得到当前batch中预测结果正确的一个总个数,最终是Tensor类型,再将其转换为数据类型,加上item(),最后total_correct累加

之后就是调用工具函数将数据可视化

mnist_train.py完整代码如下:

import torch
from torch import nn
from torch.nn import functional as F
from torch import  optim
import torchvision
from utils import plot_image,plot_curve,one_hotbatch_size = 512   #GPU单次运行处理图片的数量         批处理大小
#step1.load mnist
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = True)
test_load = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size = batch_size,shuffle = False)x,y = next(iter((train_loader)))
print(x.shape,y.shape,x.min(),x.max())
plot_image(x,y,'image sample')#step2.create network
class Net(nn.Module):def __init__(self):super(Net, self).__init__()#xw+bself.fc1 = nn.Linear(28*28,256)self.fc2 = nn.Linear(256,64)self.fc3 = nn.Linear(64,10)def forward(self,x):# x:[b,1,28,28]# h1:relu(xw1+b1)x = F.relu(self.fc1(x))# h2:relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return xnet = Net()
#[w1,w2,w3,b1,b2,b3]
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)train_loss = []for epoch in range(3):for batch_idx,(x,y) in enumerate(train_loader):# x:[b,1,28,28] , y:[512]# [b,1,28,28] => [b,784]x = x.view(x.size(0),28*28)# => [b,10]out = net(x)y_onehot = one_hot(y)#loss = mse(out,y_onehot)loss = F.mse_loss(out,y_onehot)   #均方差#清零梯度optimizer.zero_grad()loss.backward()# w' = w -lr * gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10 ==0:print(epoch,batch_idx,loss.item())
plot_curve(train_loss)
# we get optimal [w1,b1,w2,b2,w3,b3]total_correct = 0
for x,y in test_load:x = x.view(x.size(0),28*28)out = net(x)#out :[b,10] => pred:[b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_load.dataset)
acc = total_correct / total_num
print('test acc:',acc)x,y = next(iter(test_load))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim=1)
plot_image(x,pred,'test')

下面将结果展示一下

上图为读取mnist_train中的数据,展示图片以及sample

 

 这是损失函数train_loss的图像,可以看到,随着学习的深入,损失值是在不断下降的,最后逐渐趋于稳定,但是本次demo只是利用最简单的三层结构,且利用的SGD,在其它方法的训练下,可能会获得更好的训练效果

识别测试结果,可以看到,准确率还是比较高的,最后的acc大概达到了接近90%

这就是本次的小demo,有很多理解自己也不算吃的很透,随着学习的深入会理解的更加透彻

最后,感谢各位看官,欢迎批评,互相学习进步! 

 

这篇关于利用PyTorch构建三层线性网络完成对MNIST数据集识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561262

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

关于数据埋点,你需要了解这些基本知识

产品汪每天都在和数据打交道,你知道数据来自哪里吗? 移动app端内的用户行为数据大多来自埋点,了解一些埋点知识,能和数据分析师、技术侃大山,参与到前期的数据采集,更重要是让最终的埋点数据能为我所用,否则可怜巴巴等上几个月是常有的事。   埋点类型 根据埋点方式,可以区分为: 手动埋点半自动埋点全自动埋点 秉承“任何事物都有两面性”的道理:自动程度高的,能解决通用统计,便于统一化管理,但个性化定

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

异构存储(冷热数据分离)

异构存储主要解决不同的数据,存储在不同类型的硬盘中,达到最佳性能的问题。 异构存储Shell操作 (1)查看当前有哪些存储策略可以用 [lytfly@hadoop102 hadoop-3.1.4]$ hdfs storagepolicies -listPolicies (2)为指定路径(数据存储目录)设置指定的存储策略 hdfs storagepolicies -setStoragePo

Hadoop集群数据均衡之磁盘间数据均衡

生产环境,由于硬盘空间不足,往往需要增加一块硬盘。刚加载的硬盘没有数据时,可以执行磁盘数据均衡命令。(Hadoop3.x新特性) plan后面带的节点的名字必须是已经存在的,并且是需要均衡的节点。 如果节点不存在,会报如下错误: 如果节点只有一个硬盘的话,不会创建均衡计划: (1)生成均衡计划 hdfs diskbalancer -plan hadoop102 (2)执行均衡计划 hd

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了