神经网络第四篇:推理处理之手写数字识别

2024-06-24 11:18

本文主要是介绍神经网络第四篇:推理处理之手写数字识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

到目前为止,我们已经介绍完了神经网络的基本结构,现在用一个图像识别示例对前面的知识作整体的总结。本专题知识点如下:

  • MNIST数据集
  • 图像数据转图像
  • 神经网络的推理处理
  • 批处理

  •  MNIST数据集         
mnist数据图像

MNIST数据集由0到9的数字图像构成。像素取值在0到255之间。每个图像数据都相应地标有“7”、“2”、“1”等数字标签。MNIST数据集中,训练数据有6万张,测试图像有1万张。一般先用训练数据进行学习,再用学习到的模型(参数)对测试图像进行识别分类。MNIST数据集可以从官网下载,这里我们用python获取已经下载好并做过处理的MNIST数据集的相关信息:

#数据文件:mnist.pkl,大小约54M。
#文件读取,位置
import pandas as pd
network=pd.read_pickle('F:/deep-learning with python/dataset/mnist.pkl')
type(network) #类型:字典
network.keys() #关键字['train_label', 'train_img', 'test_img', 'test_label']
#训练数据形状,(60000, 784),6万个样本,每个样本由784个数据组成(1·28·28)
network['train_img'].shape
network['train_img'][0,:].max()  #第一个样本的最大值255
network['train_img'][0,:].min()  #第一个样本的最小值0,
network['train_label'].shape #训练标签形状(60000,),由0至9组成的6万个数据
network['train_label'].max() #最大值9
network['train_label'].min() #最小值0
network['train_label']  #训练标签:0~9
network['test_img'].shape #测试数据形状(10000, 784),1万个样本
network['test_label'].shape#测试数据标签形状(10000,)
network['test_label']  #测试标签:0~9

MNIST数据集保存在mnist.pkl,读者可点击:mnist数据集及权重参数下载 进行下载,mnist数据集下载的源码如下:

# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def _download(file_name):file_path = dataset_dir + "/" + file_nameif os.path.exists(file_path):returnprint("Downloading " + file_name + " ... ")urllib.request.urlretrieve(url_base + file_name, file_path)print("Done")def download_mnist():for v in key_file.values():_download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")    with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] =  _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])    dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()

为方便训练和预测,一般以(训练数据,训练标签),(测试数据,测试标签)的形式修改数据格式。


  • 图像数据转图像                                                                                                                                       

​​​​​MNIST数据集是图像数据,每一张图的大小为28像素×28像素,像素值在0至255之间。在MNIST数据集中的形状是以一列数组的形式保存的(784个像素),因此要显示为图像时需要对数据进行相关转换。PIL是Python的图像库,可用于显示图像,这里我们用它来将MNIST图像数据转换为图像进行显示,方便大家对本主题知识的理解 。

import numpy as np
from PIL import Image
img_data=network['train_img'][1]     #训练数据中第二个样本的图像数据,(784,)
img_label=network['train_label'][1] #训练数据中第2个样本的标签为
print(img_label)                     #  0
img=img_data.reshape(28,28)#转为图像尺寸
print(img.shape)           #(28, 28)
show_img=Image.fromarray(np.uint8(img)) #转换为PIL显示图像的数据格式
show_img.show()                         #图像显示,0
第二张训练图像数据的图像显示

  • 神经网络的推理处理 

前面介绍的PIL库显示的训练数据中的第二个样本的图形是“0”,实际标签也是“0”。所谓神经网络的推理,即利用神经网络对训练数据进行学习(这里我们先直接使用学到的参数,保存在sample_weight.pkl文件中),利用学到的参数(w,b)对测试数据(test_img)进行识别,然后将识别结果与实际标签(test_label)进行比较,判断推理是否正确。原理如下图:

具体的推理处理过程为:

图像数据有784个像素(28×28),即输入层有784个神经元,推理的结果是识别数字0到9,因此输出层有10个神经元。此外,为和前面知识点相连,我们在这个神经网络中添加了2个隐藏层,第一个隐藏层有50个神经元,第二个隐藏层有100个神经元,隐藏层中神经元的个数可自己设置。由于本主题不涉及参数的学习,因此我们直接使用已学习到的参数,它保存在文件sample_weight.pkl中,根据我们设计的神经网络的结构,我们应该知道参数的结构。下面是神经网络推理的代码:

parms=pd.read_pickle('F:/deep-learning with python/ch03/sample_weight.pkl')
type(parms)  #字典
parms.keys() #['b1', 'W1', 'b2', 'W2', 'b3', 'W3']
parms['b1'].shape #(50,)
parms['W1'].shape #(784,50)
parms['b2'].shape #(100,)
parms['W2'].shape #(50,100)
parms['b3'].shape #(10,)
parms['W3'].shape #(100,10)     
def predict(parms,x):"""代码同三层神经网络的实现一样,只是将随机参数改为实际学到的参数 6激活函数在前面专题已给出"""W1,W2,W3=parms['W1'],parms['W2'],parms['W3']b1,b2,b3=parms['b1'],parms['b2'],parms['b3']a1=np.dot(x,W1)+b1z1=sigmoid(a1)  #激活函数可在以前的文章中找到a2=np.dot(z1,W2)+b2z2=sigmoid(a2)a3=np.dot(z2,W3)+b3y=softmax(a3)   #激活函数可在以前的文章中找到return ytest_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签
accuracy_cnt=0
for i in range(len(test_img)):y=predict(parms,test_img[i])p=np.argmax(y)#获取概率最高的元素的索引if p==test_label[i]:accuracy_cnt+=1
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))  #0.9352

下面我们对代码做简单介绍,首先提取测试数据和测试标签。接着用for循环逐一取出测试数据中的图像数据,然后用predict()函数进行分类,该函数输出各个标签对应的概率,比如输出[0.1,0.2,0.4…,0.03],表示“0”的概率为0.1,1的概率为0.2,9的概率为0.03。然后我们取出这个概率列表中的最大值的索引即为分类结果。最后比较神经网络预测的结果和正确解标签(test_label),对1万张图预测正确的概率作为识别精度(93.52%)。         

在机器学习领域中,一般需要考虑数据预处理,这里我们可将像素0至255可缩小到0至1的范围内(即对所有数据均除以255),然后再输入至神经网络中,这种将数据限制在某个范围内的处理称为正则化,一般情况下,预处理会改善机器学习模型。读者可比较一下图像数据正则化后神经网络的识别精度。


  • 批处理

 上面只介绍了输入一张图像数据时的处理流程。即每次向神经网络中输入一个由784个元素(原本是一个28·28的二维数组)构成的一维数组后,输出一个有10个元素的一维数组。其数据形状如下:

现在我们想predict()函数一次性打包处理100张图像。为此可把X的形状改为100×784,将100张图片打包作为输入数据。数据形状如下:

批处理数据形状

可见,输入数据的形状为100×784,输出数据的形状为100×100,这说明了输入的100张图像的推理结果被一次性输出了。例如x[0]、x[1]....x[99]和y[0]、y[1]....y[99]保存了第1、2....到100张图像的图像数据及其推理结果。这种被打包的输入数据被称为(batch),批处理主要集中在数据计算上,而不是数据读入,因此批处理可缩短时间开销。下面用代码实现如下:

test_img=network['test_img']  #测试数据
test_label=network['test_label'] #测试标签batch_size=100 #批数量
accuracy_cnt=0 #初始识别精度for i in range(0,len(test_img),batch_size):x_batch=test_img[i:i+batch_size]y_batch=predict(parms,x_batch)p=np.argmax(y_batch,axis=1)#取每列最大值accuracy_cnt+=np.sum(p==t[i:i+batch_size])
print("识别精度:"+str(float(accuracy_cnt)/len(test_img)))

批处理代码核心在于for循环语句添加了步数batch_size,输入predict()函数的参数x由以前的单条数据变为x_batch表示的100条数据,寻找二维数组中每行的最大值所在的列位置使用了参数axis=1。

至此,神经网络的基本知识已经讲解完了,后面的内容主要讲解权重参数的学习!欢迎关注微信公众号“Python生态智联”,学知识,享生活!

这篇关于神经网络第四篇:推理处理之手写数字识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1089999

相关文章

Python中图片与PDF识别文本(OCR)的全面指南

《Python中图片与PDF识别文本(OCR)的全面指南》在数据爆炸时代,80%的企业数据以非结构化形式存在,其中PDF和图像是最主要的载体,本文将深入探索Python中OCR技术如何将这些数字纸张转... 目录一、OCR技术核心原理二、python图像识别四大工具库1. Pytesseract - 经典O

电脑提示xlstat4.dll丢失怎么修复? xlstat4.dll文件丢失处理办法

《电脑提示xlstat4.dll丢失怎么修复?xlstat4.dll文件丢失处理办法》长时间使用电脑,大家多少都会遇到类似dll文件丢失的情况,不过,解决这一问题其实并不复杂,下面我们就来看看xls... 在Windows操作系统中,xlstat4.dll是一个重要的动态链接库文件,通常用于支持各种应用程序

SQL Server数据库死锁处理超详细攻略

《SQLServer数据库死锁处理超详细攻略》SQLServer作为主流数据库管理系统,在高并发场景下可能面临死锁问题,影响系统性能和稳定性,这篇文章主要给大家介绍了关于SQLServer数据库死... 目录一、引言二、查询 Sqlserver 中造成死锁的 SPID三、用内置函数查询执行信息1. sp_w

Java对异常的认识与异常的处理小结

《Java对异常的认识与异常的处理小结》Java程序在运行时可能出现的错误或非正常情况称为异常,下面给大家介绍Java对异常的认识与异常的处理,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参... 目录一、认识异常与异常类型。二、异常的处理三、总结 一、认识异常与异常类型。(1)简单定义-什么是

Python基于微信OCR引擎实现高效图片文字识别

《Python基于微信OCR引擎实现高效图片文字识别》这篇文章主要为大家详细介绍了一款基于微信OCR引擎的图片文字识别桌面应用开发全过程,可以实现从图片拖拽识别到文字提取,感兴趣的小伙伴可以跟随小编一... 目录一、项目概述1.1 开发背景1.2 技术选型1.3 核心优势二、功能详解2.1 核心功能模块2.

Golang 日志处理和正则处理的操作方法

《Golang日志处理和正则处理的操作方法》:本文主要介绍Golang日志处理和正则处理的操作方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录1、logx日志处理1.1、logx简介1.2、日志初始化与配置1.3、常用方法1.4、配合defer

springboot加载不到nacos配置中心的配置问题处理

《springboot加载不到nacos配置中心的配置问题处理》:本文主要介绍springboot加载不到nacos配置中心的配置问题处理,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录springboot加载不到nacos配置中心的配置两种可能Spring Boot 版本Nacos

Python验证码识别方式(使用pytesseract库)

《Python验证码识别方式(使用pytesseract库)》:本文主要介绍Python验证码识别方式(使用pytesseract库),具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全... 目录1、安装Tesseract-OCR2、在python中使用3、本地图片识别4、结合playwrigh

python web 开发之Flask中间件与请求处理钩子的最佳实践

《pythonweb开发之Flask中间件与请求处理钩子的最佳实践》Flask作为轻量级Web框架,提供了灵活的请求处理机制,中间件和请求钩子允许开发者在请求处理的不同阶段插入自定义逻辑,实现诸如... 目录Flask中间件与请求处理钩子完全指南1. 引言2. 请求处理生命周期概述3. 请求钩子详解3.1

Python处理大量Excel文件的十个技巧分享

《Python处理大量Excel文件的十个技巧分享》每天被大量Excel文件折磨的你看过来!这是一份Python程序员整理的实用技巧,不说废话,直接上干货,文章通过代码示例讲解的非常详细,需要的朋友可... 目录一、批量读取多个Excel文件二、选择性读取工作表和列三、自动调整格式和样式四、智能数据清洗五、