每天进步一点点 -- pytorch学习:MNIST手写数字初体验

2024-01-17 21:59

本文主要是介绍每天进步一点点 -- pytorch学习:MNIST手写数字初体验,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

理论计算方法 -- K临近算法(OpenCV轻松入门_面向python)

        K 近邻算法的本质是将指定对象根据已知特征值分类,根据书中给出的内容,我们可以在手写数字识别上进行分析。
        假设我们要让程序识别图20-2中上方的数字(当然,你一眼就知道是“8”,但是现在要让计算机识别出来)。识别的方式是,依次计算该数字图像(即写有数字的图像)与下方数字图像的距离,与哪个数字图像的距离最近(此时k=1),就认为它与哪幅图像最像,从而确定这幅图像中的数字是多少。
  从特征值提取和数字识别两方面展开介绍:
 

 特征值提取

        步骤1:我们把数字图像划分成很多小块,如图20-3所示。该图中每个数字被分成5行4列,共计5×4=20个小块。此时,每个小块是由很多个像素点构成的。当然,也可以将每一个像素点理解为一个更小的子块
        为了叙述上的方便,将这些小块表示为B(Bigger),将B内的像素点,记为S(Smaller)。因此,待识别的数字“8”的图像可以理解为:
        ● 由5行4列,共计5×4=20个小块B构成。
        ● 每个小块B内其实是由M×N个像素(更小块S)构成的。为了描述上的方便,假设每个小块大小为10×10=100个像素。

         步骤2:计算每个小块B内,有多少个黑色的像素点。或者这样说,计算每个小块B内有多少个更小块S是黑色的。
        仍以数字“8”的图像为例,其第1行中:
        ● 第1个小块B共有0个像素点(更小块S)是黑色的,记为0。
        ● 第2个小块B共有28个像素点(更小块S)是黑色的,记为28。
        ● 第3个小块B共有10个像素点(更小块S)是黑色的,记为10。
        ● 第4个小块B共有0个像素点(更小块S)是黑色的,记为0。
以此类推,计算出数字“8”的图像中每一个小块B中有多少个像素点是黑色的,如图20-4所示。我们观察后会发现,不同的数字图像中每个小块B内黑色像素点的数量是不一样的。正是这种不同,使我们能用该数量(每个小块B内黑色像素点的个数)作为特征来表示每一个数字。

        步骤3:有时,为了处理上的方便,我们会把得到的特征值排成一行(写为数组形式),如图20-5所示。

         当然,在 Python 里完全没有必要这样做,因为 Python 可以非常方便地直接处理图 20-5 中 上方数组(array)形式的数据。这里为了说明上的方便,仍将其特征值处理为一行数字的形式。 经过上述处理,数字“8”图像的特征值变为一行数字,如图 20-6 所示

        步骤4:与数字“8”的图像类似,每个数字图像的特征值都可以用一行数字来表示。从某 种意义上来说,这一行数字类似于我们的身份证号码,一般来说,具有唯一性。 按照同样的方式,获取每个数字图像的特征值,如图 20-7 所示。 

 

数字识别

        数字识别要做的就是比较待识别图像与图像集中的哪个图像最近。这里,最近指的是二者 之间的欧氏距离最短。

        本例中为了便于说明和理解进行了简化,将原来下方的 10 个数字减少为 2 个(也即将分 类从 10 个减少为 2 个)。假设要识别的图像为图 20-8 中上方的数字“8”图像,需要判断该图 像到底属于图 20-8 中下方的数字“8” 图像的分类还是数字“7”图像的分类。

         步骤1:提取特征值,分别提取待识别图像的特征值和特征图像的特征值。
        为了说明和理解上的方便,将特征进行简化,每个数字图像只提取4个特征值(划分为2× 2=4个子块B),如图20-9所示。此时,提取到的特征值分别为:   

        步骤2:计算距离。 按照 20.1 节介绍的欧氏距离计算方法,计算待识别图像与特征图像之 间的距离。

        步骤3:识别。

         根据计算的距离,待识别的数字“8”图像与数字“8”特征图像的距离更近。所以,将待 识别的数字“8”图像识别为数字“8”特征图像所代表的数字“8”。

        上面介绍的是 K 近邻算法只考虑最近的一个邻居的情况,相当于 K 近邻中 k =1 的情况。 在实际操作中,为了提高可靠性,需要选用大量的特征值。例如,每个数字都选用不同的形态 的手写体 100 个,对于 0 ~ 9 这 10 个数字,共需要 100×10 =1000 幅特征图像。在识别数字时, 分别计算待识别的数字图像与这些特征图像之间的距离。这时,可以将 k 调整为稍大的值,例 如 k =11,然后看看其最近的 11 个邻居分属于哪些特征图像。例如,其中:

  •  有 8 个属于数字“6”特征图像。
  •  有 2 个属于数字“8”特征图像。
  •  有 1 个属于数字“9”特征图像。

通过判断,当前待识别的数字为数字“6”特征图像所代表的数字“6”

Pytorch实战  -- 神经网络

        上面的方法已经证明了,将图像一维化之后,仍然可以作为特征找到属于哪个数字,那么,Pytorch实战的理论部分就很容易理解了,课程使用的是神经网络,不是用的K临近。

数据集获取

        数据集使用的是标准MNIST数据集,相关介绍可以看官网,也可以随意百度,简单来说就是一个X矩阵,一个Y矩阵,X矩阵的一行代表一个28*28的图片矩阵按行打平之后的一行784列的矩阵,Y就是上面X对应的数字。

模型选择

        既然课程使用的神经网络,自己对神经网络也不是很懂,那就参考课程理一下思路,构建的是三层神经网络,输入为X,则为一行784列的行向量,网络输入需要784维,中间层输入参考课程使用64维,outPut需要按照数字对比,参考上面K临近时的输出,为one_shot的结果,0-9共10个数,即需要输出10维,激活函数前两层选择ReLU(容易求导),输出层的激活函数选择线性,1

此方法使用三层神经网络来做,三层网络如下:  

注:如果听原课程,一定要注意X:[1,dx]代表了X为1行784列的行向量,依次类推

代码实现

"""@Author: Administrator@DateTime: 2022/12/10 20:16Copyright (c) 2022-2025 zhaofeiTec
"""
# !/usr/bin/python
# coding: utf-8import torch
import torch.nn as nn  # 网络模型
import torch.nn.functional as F  # 方法
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 数据批处理的的图片个数
from utils import one_hot, plot_curveBATCH_SIZE = 64
# 使用训练的设备
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 训练的总轮次
EPOCHS = 1# 构建transform,对图像进行变换
pipeline = transforms.Compose([transforms.ToTensor(),  # 将图片转换成tensor形式transforms.Normalize((0.1307,), (0.3081,))  # 参数使用的官网的,需要自己设定,模型出现过拟合现象时,降低模型复杂度
])# 下载数据集
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)  # 下载现有的数据集,保存到data
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)  # 下载测试数据集
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  #
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)class Net(nn.Module):def __init__(self):super(Net, self).__init__()# xW+bself.fc1 = nn.Linear(28 * 28, 256)  # 784维的输入,输出256维self.fc2 = nn.Linear(256, 64)  # 256维的输入,64维的输出self.fc3 = nn.Linear(64, 10)  # 64维的输入,10维的输出def forward(self, x):# x为1行784列的矩阵# h1 = relu(xW1 +b)x = F.relu(self.fc1(x))# h2 = relu(h1W2 +b)x = F.relu(self.fc2(x))# 先不加激活函数# h3 = h2W3+bx = self.fc3(x)return xnet = Net()
# net.parameters() 为net的优化对象,w1,b1,w2,b2,w3,b3 lr为学习率,momentum暂时未知
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)train_loss = []
# 循环所有训练集3次
for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):# x [64,1,28,28] 64个图片, y 为64行# 打平后 x [64,748] 64行748列x = x.view(x.size(0), 28 * 28)out = net(x)y_one_hot = one_hot(y)# loss mse为方差loss = F.mse_loss(out, y_one_hot)# 优化器optimizer.zero_grad()loss.backward()optimizer.step()# 保存训练集的损失函数train_loss.append(loss.item())plot_curve(train_loss)

util.py:

"""@Author: Administrator@DateTime: 2022/12/10 23:19Copyright (c) 2022-2025 zhaofeiTec
"""
import matplotlib.pyplot as plt
# !/usr/bin/python
# coding: utf-8
import torchdef plot_curve(data):fig = plt.figure()plt.plot(range(len(data)), data, color="blue")plt.legend(['value'], loc='upper right')plt.xlabel('step')plt.ylabel('value')plt.show()def plot_image(img, label, name):fig = plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)plt.tight_layout()plt.imshow(img[0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')plt.title("{}:{}".format(name, label[i].item()))plt.xticks()plt.yticks()plt.show()def one_hot(labels, depth=10):out = torch.zeros(labels.size(0), depth)idx = torch.LongTensor(labels).view(-1, 1)out.scatter_(dim=1, index=idx, value=1)return out

测试集验证和模型泛化性能分析

此处缺失

仅保存损失函数随训练过程的下降 

问题解决

(1)pytorch是使用Anaconda安装的python环境,使用Anaconda无法安装matplotlib,解决方法如下:

进去Anaconda的安装目录下pytorch环境的安装路径:D:\software\Anaconda3\envs\pytorch\Scripts,右击,本地打开PowerShell,执行下面的命令即可:

 .\pip.exe install matplotlib -i http://pypi.douban.com/simple --trusted-host pypi.douban.com

重启Pycharm后,已经可以使用matplotlib

这篇关于每天进步一点点 -- pytorch学习:MNIST手写数字初体验的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/617346

相关文章

捷瑞数字业绩波动性明显:关联交易不低,募资必要性遭质疑

《港湾商业观察》施子夫 5月22日,山东捷瑞数字科技股份有限公司(以下简称,捷瑞数字)及保荐机构国新证券披露第三轮问询的回复,继续推进北交所上市进程。 从2023年6月递表开始,监管层已下发三轮审核问询函,关注到捷瑞数字存在同业竞争、关联交易、募资合理性、期后业绩波动等焦点问题。公司的上市之路多少被阴影笼罩。​ 业绩波动遭问询 捷瑞数字成立于2000年,公司是一家以数字孪生驱动的工

51单片机学习记录———定时器

文章目录 前言一、定时器介绍二、STC89C52定时器资源三、定时器框图四、定时器模式五、定时器相关寄存器六、定时器练习 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、定时器介绍 定时器介绍:51单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 定时器作用: 1.用于计数系统,可

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

AssetBundle学习笔记

AssetBundle是unity自定义的资源格式,通过调用引擎的资源打包接口对资源进行打包成.assetbundle格式的资源包。本文介绍了AssetBundle的生成,使用,加载,卸载以及Unity资源更新的一个基本步骤。 目录 1.定义: 2.AssetBundle的生成: 1)设置AssetBundle包的属性——通过编辑器界面 补充:分组策略 2)调用引擎接口API

Javascript高级程序设计(第四版)--学习记录之变量、内存

原始值与引用值 原始值:简单的数据即基础数据类型,按值访问。 引用值:由多个值构成的对象即复杂数据类型,按引用访问。 动态属性 对于引用值而言,可以随时添加、修改和删除其属性和方法。 let person = new Object();person.name = 'Jason';person.age = 42;console.log(person.name,person.age);//'J

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

《offer来了》第二章学习笔记

1.集合 Java四种集合:List、Queue、Set和Map 1.1.List:可重复 有序的Collection ArrayList: 基于数组实现,增删慢,查询快,线程不安全 Vector: 基于数组实现,增删慢,查询快,线程安全 LinkedList: 基于双向链实现,增删快,查询慢,线程不安全 1.2.Queue:队列 ArrayBlockingQueue:

19.手写Spring AOP

1.Spring AOP顶层设计 2.Spring AOP执行流程 下面是代码实现 3.在 application.properties中增加如下自定义配置: #托管的类扫描包路径#scanPackage=com.gupaoedu.vip.demotemplateRoot=layouts#切面表达式expression#pointCut=public .* com.gupaoedu

17.用300行代码手写初体验Spring V1.0版本

1.1.课程目标 1、了解看源码最有效的方式,先猜测后验证,不要一开始就去调试代码。 2、浓缩就是精华,用 300行最简洁的代码 提炼Spring的基本设计思想。 3、掌握Spring框架的基本脉络。 1.2.内容定位 1、 具有1年以上的SpringMVC使用经验。 2、 希望深入了解Spring源码的人群,对 Spring有一个整体的宏观感受。 3、 全程手写实现SpringM