神经网络第四篇:推理处理之手写数字识别

2024-06-24 11:18

本文主要是介绍神经网络第四篇:推理处理之手写数字识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

到目前为止,我们已经介绍完了神经网络的基本结构,现在用一个图像识别示例对前面的知识作整体的总结。本专题知识点如下:

  • MNIST数据集
  • 图像数据转图像
  • 神经网络的推理处理
  • 批处理

  •  MNIST数据集         
mnist数据图像

MNIST数据集由0到9的数字图像构成。像素取值在0到255之间。每个图像数据都相应地标有“7”、“2”、“1”等数字标签。MNIST数据集中,训练数据有6万张,测试图像有1万张。一般先用训练数据进行学习,再用学习到的模型(参数)对测试图像进行识别分类。MNIST数据集可以从官网下载,这里我们用python获取已经下载好并做过处理的MNIST数据集的相关信息:

#数据文件:mnist.pkl,大小约54M。
#文件读取,位置
import pandas as pd
network=pd.read_pickle('F:/deep-learning with python/dataset/mnist.pkl')
type(network) #类型:字典
network.keys() #关键字['train_label', 'train_img', 'test_img', 'test_label']
#训练数据形状,(60000, 784),6万个样本,每个样本由784个数据组成(1·28·28)
network['train_img'].shape
network['train_img'][0,:].max()  #第一个样本的最大值255
network['train_img'][0,:].min()  #第一个样本的最小值0,
network['train_label'].shape #训练标签形状(60000,),由0至9组成的6万个数据
network['train_label'].max() #最大值9
network['train_label'].min() #最小值0
network['train_label']  #训练标签:0~9
network['test_img'].shape #测试数据形状(10000, 784),1万个样本
network['test_label'].shape#测试数据标签形状(10000,)
network['test_label']  #测试标签:0~9

MNIST数据集保存在mnist.pkl,读者可点击:mnist数据集及权重参数下载 进行下载,mnist数据集下载的源码如下:

# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def _download(file_name):file_path = dataset_dir + "/" + file_nameif os.path.exists(file_path):returnprint("Downloading " + file_name + " ... ")urllib.request.urlretrieve(url_base + file_name, file_path)print("Done")def download_mnist():for v in key_file.values():_download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")    with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] =  _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])    dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()

为方便训练和预测,一般以(训练数据,训练标签),(测试数据,测试标签)的形式修改数据格式。


  • 图像数据转图像                                                                                                                                       

​​​​​MNIST数据集是图像数据,每一张图的大小为28像素×28像素,像素值在0至255之间。在MNIST数据集中的形状是以一列数组的形式保存的(784个像素),因此要显示为图像时需要对数据进行相关转换。PIL是Python的图像库,可用于显示图像,这里我们用它来将MNIST图像数据转换为图像进行显示,方便大家对本主题知识的理解 。

import numpy as np
from PIL import Image
img_data=network['train_img'][1]     #训练数据中第二个样本的图像数据,(784,)
img_label=network['train_label'][1] #训练数据中第2个样本的标签为
print(img_label)                     #  0
img=img_data.reshape(28,28)#转为图像尺寸
print(img.shape)           #(28, 28)
show_img=Image.fromarray(np.uint8(img)) #转换为PIL显示图像的数据格式
show_img.show()                         #图像显示,0
第二张训练图像数据的图像显示

  • 神经网络的推理处理 

前面介绍的PIL库显示的训练数据中的第二个样本的图形是“0”,实际标签也是“0”。所谓神经网络的推理,即利用神经网络对训练数据进行学习(这里我们先直接使用学到的参数,保存在sample_weight.pkl文件中),利用学到的参数(w,b)对测试数据(test_img)进行识别,然后将识别结果与实际标签(test_label)进行比较,判断推理是否正确。原理如下图:

具体的推理处理过程为:

图像数据有784个像素(28×28),即输入层有784个神经元,推理的结果是识别数字0到9,因此输出层有10个神经元。此外,为和前面知识点相连,我们在这个神经网络中添加了2个隐藏层,第一个隐藏层有50个神经元,第二个隐藏层有100个神经元,隐藏层中神经元的个数可自己设置。由于本主题不涉及参数的学习,因此我们直接使用已学习到的参数,它保存在文件sample_weight.pkl中,根据我们设计的神经网络的结构,我们应该知道参数的结构。下面是神经网络推理的代码:

parms=pd.read_pickle('F:/deep-learning with python/ch03/sample_weight.pkl')
type(parms)  #字典
parms.keys() #['b1', 'W1', 'b2', 'W2', 'b3', 'W3']
parms['b1'].shape #(50,)
parms['W1'].shape #(784,50)
parms['b2'].shape #(100,)
parms['W2'].shape #(50,100)
parms['b3'].shape #(10,)
parms['W3'].shape #(100,10)     
def predict(parms,x):"""代码同三层神经网络的实现一样,只是将随机参数改为实际学到的参数 6激活函数在前面专题已给出"""W1,W2,W3=parms['W1'],parms['W2'],parms['W3']b1,b2,b3=parms['b1'],parms['b2'],parms['b3']a1=np.dot(x,W1)+b1z1=sigmoid(a1)  #激活函数可在以前的文章中找到a2=np.dot(z1,W2)+b2z2=sigmoid(a2)a3=np.dot(z2,W3)+b3y=softmax(a3)   #激活函数可在以前的文章中找到return ytest_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签
accuracy_cnt=0
for i in range(len(test_img)):y=predict(parms,test_img[i])p=np.argmax(y)#获取概率最高的元素的索引if p==test_label[i]:accuracy_cnt+=1
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))  #0.9352

下面我们对代码做简单介绍,首先提取测试数据和测试标签。接着用for循环逐一取出测试数据中的图像数据,然后用predict()函数进行分类,该函数输出各个标签对应的概率,比如输出[0.1,0.2,0.4…,0.03],表示“0”的概率为0.1,1的概率为0.2,9的概率为0.03。然后我们取出这个概率列表中的最大值的索引即为分类结果。最后比较神经网络预测的结果和正确解标签(test_label),对1万张图预测正确的概率作为识别精度(93.52%)。         

在机器学习领域中,一般需要考虑数据预处理,这里我们可将像素0至255可缩小到0至1的范围内(即对所有数据均除以255),然后再输入至神经网络中,这种将数据限制在某个范围内的处理称为正则化,一般情况下,预处理会改善机器学习模型。读者可比较一下图像数据正则化后神经网络的识别精度。


  • 批处理

 上面只介绍了输入一张图像数据时的处理流程。即每次向神经网络中输入一个由784个元素(原本是一个28·28的二维数组)构成的一维数组后,输出一个有10个元素的一维数组。其数据形状如下:

现在我们想predict()函数一次性打包处理100张图像。为此可把X的形状改为100×784,将100张图片打包作为输入数据。数据形状如下:

批处理数据形状

可见,输入数据的形状为100×784,输出数据的形状为100×100,这说明了输入的100张图像的推理结果被一次性输出了。例如x[0]、x[1]....x[99]和y[0]、y[1]....y[99]保存了第1、2....到100张图像的图像数据及其推理结果。这种被打包的输入数据被称为(batch),批处理主要集中在数据计算上,而不是数据读入,因此批处理可缩短时间开销。下面用代码实现如下:

test_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签batch_size=100 #批数量
accuracy_cnt=0 #初始识别精度for i in range(0,len(test_img),batch_size):x_batch=test_img[i:i+batch_size]y_batch=predict(parms,x_batch)p=np.argmax(y_batch,axis=1)#取每列最大值accuracy_cnt+=np.sum(p==t[i:i+batch_size])
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))

批处理代码核心在于for循环语句添加了步数batch_size,输入predict()函数的参数x由以前的单条数据变为x_batch表示的100条数据,寻找二维数组中每行的最大值所在的列位置使用了参数axis=1。

至此,神经网络的基本知识已经讲解完了,后面的内容主要讲解权重参数的学习!欢迎关注微信公众号“Python生态智联”,学知识,享生活!

这篇关于神经网络第四篇:推理处理之手写数字识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1089999

相关文章

捷瑞数字业绩波动性明显:关联交易不低,募资必要性遭质疑

《港湾商业观察》施子夫 5月22日,山东捷瑞数字科技股份有限公司(以下简称,捷瑞数字)及保荐机构国新证券披露第三轮问询的回复,继续推进北交所上市进程。 从2023年6月递表开始,监管层已下发三轮审核问询函,关注到捷瑞数字存在同业竞争、关联交易、募资合理性、期后业绩波动等焦点问题。公司的上市之路多少被阴影笼罩。​ 业绩波动遭问询 捷瑞数字成立于2000年,公司是一家以数字孪生驱动的工

19.手写Spring AOP

1.Spring AOP顶层设计 2.Spring AOP执行流程 下面是代码实现 3.在 application.properties中增加如下自定义配置: #托管的类扫描包路径#scanPackage=com.gupaoedu.vip.demotemplateRoot=layouts#切面表达式expression#pointCut=public .* com.gupaoedu

17.用300行代码手写初体验Spring V1.0版本

1.1.课程目标 1、了解看源码最有效的方式,先猜测后验证,不要一开始就去调试代码。 2、浓缩就是精华,用 300行最简洁的代码 提炼Spring的基本设计思想。 3、掌握Spring框架的基本脉络。 1.2.内容定位 1、 具有1年以上的SpringMVC使用经验。 2、 希望深入了解Spring源码的人群,对 Spring有一个整体的宏观感受。 3、 全程手写实现SpringM

大语言模型(LLMs)能够进行推理和规划吗?

大语言模型(LLMs),基本上是经过强化训练的 n-gram 模型,它们在网络规模的语言语料库(实际上,可以说是我们文明的知识库)上进行了训练,展现出了一种超乎预期的语言行为,引发了我们的广泛关注。从训练和操作的角度来看,LLMs 可以被认为是一种巨大的、非真实的记忆库,相当于为我们所有人提供了一个外部的系统 1(见图 1)。然而,它们表面上的多功能性让许多研究者好奇,这些模型是否也能在通常需要系

人工智能机器学习算法总结神经网络算法(前向及反向传播)

1.定义,意义和优缺点 定义: 神经网络算法是一种模仿人类大脑神经元之间连接方式的机器学习算法。通过多层神经元的组合和激活函数的非线性转换,神经网络能够学习数据的特征和模式,实现对复杂数据的建模和预测。(我们可以借助人类的神经元模型来更好的帮助我们理解该算法的本质,不过这里需要说明的是,虽然名字是神经网络,并且结构等等也是借鉴了神经网络,但其原型以及算法本质上还和生物层面的神经网络运行原理存在

python实现最简单循环神经网络(RNNs)

Recurrent Neural Networks(RNNs) 的模型: 上图中红色部分是输入向量。文本、单词、数据都是输入,在网络里都以向量的形式进行表示。 绿色部分是隐藏向量。是加工处理过程。 蓝色部分是输出向量。 python代码表示如下: rnn = RNN()y = rnn.step(x) # x为输入向量,y为输出向量 RNNs神经网络由神经元组成, python

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

百度OCR识别结构结构化处理视频

https://edu.csdn.net/course/detail/10506

数据时代的数字企业

1.写在前面 讨论数据治理在数字企业中的影响和必要性,并介绍数据治理的核心内容和实践方法。作者强调了数据质量、数据安全、数据隐私和数据合规等方面是数据治理的核心内容,并介绍了具体的实践措施和案例分析。企业需要重视这些方面以实现数字化转型和业务增长。 数字化转型行业小伙伴可以加入我的星球,初衷成为各位数字化转型参考库,星球内容每周更新 个人工作经验资料全部放在这里,包含数据治理、数据要

如何在Java中处理JSON数据?

如何在Java中处理JSON数据? 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天我们将探讨在Java中如何处理JSON数据。JSON(JavaScript Object Notation)作为一种轻量级的数据交换格式,在现代应用程序中被广泛使用。Java通过多种库和API提供了处理JSON的能力,我们将深入了解其用法和最佳