基于LeNet5的手写数字识别神经网络

2024-01-02 03:48

本文主要是介绍基于LeNet5的手写数字识别神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于CNN,迄今为止已经提出了各种网络结构。在这里我们着重介绍一下在1998年首次被提出的CNN元组LeNet。LeNet子啊1998年被提出,是进行手写数字识别的网络,如下图所示,他又连续的卷积层和池化层(正确地讲,是只“抽选元素”的子采样层),最后经全连接输出结果。

在初始的LeNet中,输入时32*32的图像,经过卷积层输出channel为6,大小28*28的feature map,在经过子采样(Subsampling)池化后,将图像大小变为14*14,(stride=2)在进行卷积,output_channel变为16,大小10*10,在经过一层子采样池化,将图像最终变为5*5,传给全连接层,经过全连接层处理后输出。具体处理流程如下图:

和现在的CNN相比,LeNet有几个不同点。第一个不同点在于激活函数,LeNet中使用的是sigmoid函数 ,而现在的CNN中主要使用ReLU函数。此外,原始的LeNet中使用子采样(Subsampling)缩小中间数据大小,而现在的CNN中Max池化是主流。

下面我们完成一个基于LeNet5的网络对MNIST数据集的识别:

首先我们先建立数据集,在这里可以说利用datasets下载这样的简易数据集简直不要太好用

mnist_train = datasets.MNIST('MNIST',True,transform=transforms.Compose([transforms.Resize((28,28)),transforms.ToTensor()]),download=True)mnist_train = DataLoader(mnist_train,batch_size=batch_size,shuffle=True)mnist_test = datasets.MNIST('MNIST',False,transform=transforms.Compose([transforms.Resize((28,28)),transforms.ToTensor()]),download=True)mnist_test = DataLoader(mnist_test,batch_size=batch_size,shuffle=True)

我们对下载好的数据集进行输出,看看情况怎么样(batch_size = 32)

 x,label = iter(mnist_train).next()print('x:',x.shape,' label:',label.shape)#输出结果:x: torch.Size([32, 1, 28, 28])  label: torch.Size([32])

下面我们来建立一个LeNet网络:

class lenet5(nn.Module):"""for MNIST DATASET"""def __init__(self):super(lenet5, self).__init__()# convolutionsself.cov_unit = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,stride=1,padding=1),nn.MaxPool2d(kernel_size=2,stride=2,padding=0),nn.Conv2d(6,16,kernel_size=5,stride=1,padding=1),nn.MaxPool2d(kernel_size=2,stride=2,padding=0))#flattenself.fc_unit = nn.Sequential(nn.Linear(16*5*5,120),nn.ReLU(),nn.Linear(120,84),nn.ReLU(),nn.Linear(84,10))def forward(self,x):batchsz = x.size(0)x = self.cov_unit(x)x = x.view(batchsz,16*5*5)logits = self.fc_unit(x)return logits

在这里需要借鉴一下:LeNet论文阅读:LeNet结构以及参数个数计算_silent56_th的博客-CSDN博客icon-default.png?t=L892https://blog.csdn.net/silent56_th/article/details/53456522

博主的博客内对输入数据和隐藏层的参数分析以及为何不全采用全连接做了解释:我自己对于kernel_size部分的参数选定还有理解不到位的地方,在这里借鉴一下:

S1-C2对应关系

已经搭建好了LeNet网络,下面定义优化器和损失函数已经利用GPU进行加速:

device = torch.device('cuda')
model = lenet5().to(device)
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(),lr=1e-3)

 在这里我们可以将model在控制台打印出来,观察一下,在整体观察一下LeNet网络模型:

# 控制台打印输出
model: lenet5((cov_unit): Sequential((0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fc_unit): Sequential((0): Linear(in_features=400, out_features=120, bias=True)(1): ReLU()(2): Linear(in_features=120, out_features=84, bias=True)(3): ReLU()(4): Linear(in_features=84, out_features=10, bias=True))
)

 从这里可以看到,基本是我们所设置的一个网络模型,现在网络已经搭建完毕,优化器和参数都以设定,下面开始进行训练:

​
for batchidx,(x,label) in enumerate(mnist_train):x,label = x.to(device),label.to(device)logits = model(x)loss = criteon(logits,label)optimizer.zero_grad()loss.backward()optimizer.step()print('epoch:',epoch,' loss:',loss.item())

这里的logits原指sigmoid函数(标准logits函数),但是在这里用来表示最终全连接层输出,而非其本意。在每个epoch结束后,将损失函数loss的值打印在控制台

下面是进行的测试:

 model.eval()with torch.no_grad():total_num = 0total_correct = 0for x,label in mnist_test:x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum().item()total_num += x.size(0)acc = total_correct/total_numprint('epoch:',epoch,' accuarcy:',acc)

 这里的pred = logits.argmax(dim=1),argmax函数是返回最大值的索引,即经过训练后预测结果概率最大的索引,这里将pred和监督标签label进行比较,如果equal便加到total_correct中,最后计算acc。

在进行训练和测试之前,分别添加了model.train()以及model.eval()

(1). model.train()
启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为True
(2). model.eval()
不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False

BatchNormalization的思路是调整各层的激活值分布使其拥有适当的广度,要向神经网络中插入数据分布进行正规化的层,可以使学习快速进行、不那么依赖初始值同时还可以一定程度抑制过拟合

Dropout是一种在学习过程中随机删除神经元的方法,通过随机选择并删除神经元,停止向前传递信号,使用Dropout可以使训练数据和测试数据的识别精度的差距变小了,即使是表现力很强的网络,也可以抑制过拟合。

最后,我们为了使数据可以更好的展现和反馈,我们利用visdom进行可视化

viz = Visdom()
viz.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))
global_step = 0

对于train_loss,从[0,0]坐标开始,每一个epoch执行完,global_step += 1 

  global_step += 1viz.line([loss.item()],[global_step],win='train_loss', update='append')

 将本次epoch内计算的loss以折线图的方式绘制

viz.images(x.view(-1, 1, 28, 28), win='x')
viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))

此时x.shape为[16,1,28,28],str(pred.detach().cpu.numpy())是将预测值变为数据类型打印出来

经过15个epoch我们可以看到,这个识别的精确度已经很高了,我们在看一下visdom可视化的结果:

 也是可以看到的,虽然略有起伏,但是train_loss还是在逐步下降的,我们抽取了10个数据进行展示,可以看到, 预测的结果也是十分准确的。

Conclusion:LeNet只是在1998年最早提出来的CNN,与现在的CNN虽然有些许不同,但是差别也不是很大,考虑到提出的时间很早,所以LeNet还是十分令人称奇的

这篇关于基于LeNet5的手写数字识别神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561264

相关文章

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推

usaco 1.2 Name That Number(数字字母转化)

巧妙的利用code[b[0]-'A'] 将字符ABC...Z转换为数字 需要注意的是重新开一个数组 c [ ] 存储字符串 应人为的在末尾附上 ‘ \ 0 ’ 详见代码: /*ID: who jayLANG: C++TASK: namenum*/#include<stdio.h>#include<string.h>int main(){FILE *fin = fopen (

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

AIGC6: 走进腾讯数字盛会

图中是一个程序员,去参加一个技术盛会。AI大潮下,五颜六色,各种不确定。 背景 AI对各行各业的冲击越来越大,身处职场的我也能清晰的感受到。 我所在的行业为全球客服外包行业。 业务模式为: 为国际跨境公司提供不同地区不同语言的客服外包解决方案,除了人力,还有软件系统。 软件系统主要是提供了客服跟客人的渠道沟通和工单管理,内部管理跟甲方的合同对接,绩效评估,BI数据透视。 客服跟客人

机器学习之监督学习(三)神经网络

机器学习之监督学习(三)神经网络基础 0. 文章传送1. 深度学习 Deep Learning深度学习的关键特点深度学习VS传统机器学习 2. 生物神经网络 Biological Neural Network3. 神经网络模型基本结构模块一:TensorFlow搭建神经网络 4. 反向传播梯度下降 Back Propagation Gradient Descent模块二:激活函数 activ

Clion不识别C代码或者无法跳转C语言项目怎么办?

如果是中文会显示: 此时只需要右击项目,或者你的源代码目录,将这个项目或者源码目录标记为项目源和头文件即可。 英文如下:

NC 把数字翻译成字符串

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 描述 有一种将字母编码成数字的方式:‘a’->1, ‘b->2’, … , ‘z->26’。 现在给一串数字,返回有多少种可能的译码结果 import java.u

34465A-61/2 数字万用表(六位半)

34465A-61/2 数字万用表(六位半) 文章目录 34465A-61/2 数字万用表(六位半)前言一、测DC/AC电压二、测DC/AC电流四、测电阻五、测电容六、测二极管七、保存截图流程 前言 1、6位半数字万用表通常具有200,000个计数器,可以显示最大为199999的数值。相比普通数字万用表,6位半万用表具有更高的测量分辨率和更高的测量准确度,适用于精度比较高的测

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04