秃姐学AI系列之:LeNet + 代码实现

2024-08-21 00:52

本文主要是介绍秃姐学AI系列之:LeNet + 代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

LeNet 

MNIST数据集

LeNet模型图

​编辑

总结

代码实现:卷积神经网络 LeNet

LeNet(LeNet-5)由两个部分组成:卷积编码器核全连接层密集块

 检查模型


LeNet 

卷积神经网络里面最为著名的一个网络,80年代末提出来的,被广泛应用在银行、邮递行业

用于手写数字识别的一个模型

MNIST数据集

  • 50000个训练数据(在80年代末期是一个很大的数据集了,那时候内存都只有几兆)
  • 10000个测试数据
  • 图像大小 28 x 28,已经scale好了,数字都放在图像的中间,是一个灰度图
  • 10类

很长一段时间,这个数据集的知名度远高于LeNet模型 

LeNet模型图

输入一个32 x 32的image图像

-->放入一个5 x 5的卷积层,输出通道为6

-->一个2 x 2的pooling池化层,把28 x 28池化为6通道的14 x 14

-->再接一个卷积层,仍然是 5 x 5的卷积核,输出编程16通道的10 x 10

--> 一个pooling层,高宽减半,输出通道数不变

-->拉成一个向量,放到mlp全连接层,输出为120

-->第二个全连接层输出为84

-->一个高斯层(但是现在不用了,也可以理解成一个全连接层),最后输出为10(Softmax一下转换为10类的可能性输出)

总结

  • LeNet是早期成功的神经网络
  • 先使用卷积层来学习图片空间信息
  • 通过池化层来降低图片位置敏感度
  • 最后使用全连接层来转换到类别空间(10类)

代码实现:卷积神经网络 LeNet

LeNet(LeNet-5)由两个部分组成:卷积编码器核全连接层密集块

为了非线性,在每个卷积后面都加了一个Sigmoid激活函数

import torch
from torch import nn
from d2l import as d2lclass Reshape(torch.nn.Module):def forward(self, x):return x.view(-1, 1, 28, 28)net = torch.nn.Sequential(Reshape(), nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid()nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),# Flatten():第一维度保持住,后面全部拉成一个维度nn.AbgPool2d(kernel_size=2, strid=2), nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

 检查模型

因为我们是nn.Sequential定义的,所以可以每一层拿出来算一下

这里是用了__name__,也可以使用PyTorch的summary

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)

输出如下: 

 我们可以看到

  • 第一个block(卷积层+激活+池化):

        【1, 1, 28, 28】 ---> 【1, 6, 14, 14】

        做了一个图片大小减半,通道数从1扩到了6的操作,总体来说数据其实是变多了

  • 第二个block:

        【1, 6, 14, 14】---> 【1, 16, 5, 5】

        图片大小减少了大概三倍,通道数从6扩到16

  • 三层MLP:

        【1,400】--->【1, 120】--->【1, 84】--->【1, 10】

模型核心思想:我们前面讲过一个通道可以看成是一个模式,整个LeNet做的就是不断地把空间信息压缩压缩,然后把抽出来压缩的信息放在不同的通道里面,最后通过几个MLP将不同模式的通道进行融合成我们最后的输出

这篇关于秃姐学AI系列之:LeNet + 代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1091654

相关文章

Spring Security自定义身份认证的实现方法

《SpringSecurity自定义身份认证的实现方法》:本文主要介绍SpringSecurity自定义身份认证的实现方法,下面对SpringSecurity的这三种自定义身份认证进行详细讲解,... 目录1.内存身份认证(1)创建配置类(2)验证内存身份认证2.JDBC身份认证(1)数据准备 (2)配置依

利用python实现对excel文件进行加密

《利用python实现对excel文件进行加密》由于文件内容的私密性,需要对Excel文件进行加密,保护文件以免给第三方看到,本文将以Python语言为例,和大家讲讲如何对Excel文件进行加密,感兴... 目录前言方法一:使用pywin32库(仅限Windows)方法二:使用msoffcrypto-too

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删

SpringBoot整合mybatisPlus实现批量插入并获取ID详解

《SpringBoot整合mybatisPlus实现批量插入并获取ID详解》这篇文章主要为大家详细介绍了SpringBoot如何整合mybatisPlus实现批量插入并获取ID,文中的示例代码讲解详细... 目录【1】saveBATch(一万条数据总耗时:2478ms)【2】集合方式foreach(一万条数

使用Python实现矢量路径的压缩、解压与可视化

《使用Python实现矢量路径的压缩、解压与可视化》在图形设计和Web开发中,矢量路径数据的高效存储与传输至关重要,本文将通过一个Python示例,展示如何将复杂的矢量路径命令序列压缩为JSON格式,... 目录引言核心功能概述1. 路径命令解析2. 路径数据压缩3. 路径数据解压4. 可视化代码实现详解1

PyQt6/PySide6中QTableView类的实现

《PyQt6/PySide6中QTableView类的实现》本文主要介绍了PyQt6/PySide6中QTableView类的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学... 目录1. 基本概念2. 创建 QTableView 实例3. QTableView 的常用属性和方法

PyQt6/PySide6中QTreeView类的实现

《PyQt6/PySide6中QTreeView类的实现》QTreeView是PyQt6或PySide6库中用于显示分层数据的控件,本文主要介绍了PyQt6/PySide6中QTreeView类的实现... 目录1. 基本概念2. 创建 QTreeView 实例3. QTreeView 的常用属性和方法属性

Android使用ImageView.ScaleType实现图片的缩放与裁剪功能

《Android使用ImageView.ScaleType实现图片的缩放与裁剪功能》ImageView是最常用的控件之一,它用于展示各种类型的图片,为了能够根据需求调整图片的显示效果,Android提... 目录什么是 ImageView.ScaleType?FIT_XYFIT_STARTFIT_CENTE

pandas中位数填充空值的实现示例

《pandas中位数填充空值的实现示例》中位数填充是一种简单而有效的方法,用于填充数据集中缺失的值,本文就来介绍一下pandas中位数填充空值的实现,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是中位数填充?为什么选择中位数填充?示例数据结果分析完整代码总结在数据分析和机器学习过程中,处理缺失数