罗斯基白话:TensorFlow + 实战系列(五)实战MNIST

2023-10-15 12:20

本文主要是介绍罗斯基白话:TensorFlow + 实战系列(五)实战MNIST,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 

白话TensorFlow +实战系列(五)
实战MNIST

 

       这篇文章主要用全连接神经网络来实现MNIST手写数字识别的问题。首先介绍下MNIST数据集。

       1)MNIST数据集

       MNIST数据集是一个非常有名的手写数字识别数据集,它包含了60000张图片作为训练集,10000张图片为测试集,每张图为一个手写的0~9数字。如图:




其中每张图的大小均为28*28,这里大小指的的是像素。例如数字1所对应的像素矩阵为:




而我们要做的就是教会电脑识别每个手写数字。这个数据集非常经典,常作为学习神经网络的入门教材,一如每个程序员的第一个程序都是“helloword!”一样。

 

       2)数据处理

       数据集下载下来后有四个文件,分别为训练集图片,训练集答案,测试集图片,测试集答案。TensorFlow提供了一个类来处理MNIST数据,这个类会自动的将MNIST数据分为训练集,验证集与测试集,并且这些数据都是可以直接喂给神经网络作为输入用的。示例代码如下:



      

 其中input_data.read_data_sets会自动将数据集进行处理,one_hot = True用独热方式表示,意思是每个数字由one_hot方式表,例如数字0 = [1,0,0,0,0,0,0,0,0,0],1 = [0,1,0,0,0,0,0,0,0,0]。运行结果如下:




接下来就用一个全连接神经网络来识别数字。

 

       3)全连接神经网络

       首先定义超参数与参数,没啥好解释的,代码如下:


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_databatch_size = 100
learning_rate = 0.8
trainig_step = 30000n_input = 784
n_hidden = 500
n_labels = 10

 接着定义网络的结构,构建的网络只有一个隐藏层,隐藏层节点为500。代码如下:


def inference(x_input):with tf.variable_scope("hidden"):weights = tf.get_variable("weights", [n_input, n_hidden], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_hidden], initializer = tf.constant_initializer(0.0))hidden = tf.nn.relu(tf.matmul(x_input, weights) + biases)with tf.variable_scope("out"):weights  = tf.get_variable("weights", [n_hidden, n_labels], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_labels], initializer = tf.constant_initializer(0.0))output = tf.matmul(hidden, weights) + biasesreturn output

在输出层中,output并没有用到relu函数,因为在之后的softmax层中也是非线性激励,所以可以不用。

 

接着定义训练过程,代码如下:


def train(mnist):x = tf.placeholder("float", [None, n_input])y = tf.placeholder("float", [None, n_labels])pred = inference(x)#计算损失函数cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y))#定义优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(cross_entropy)#定义准确率计算correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)#定义验证集与测试集validate_data = {x: mnist.validation.images, y: mnist.validation.labels}test_data = {x: mnist.test.images, y: mnist.test.labels}for i in range(trainig_step):#xs,ys为每个batch_size的训练数据与对应的标签xs, ys = mnist.train.next_batch(batch_size)_, loss = sess.run([optimizer, cross_entropy], feed_dict={x: xs, y:ys})#每1000次训练打印一次损失值与验证准确率if i % 1000 == 0:validate_accuracy = sess.run(accuracy, feed_dict=validate_data)print("after %d training steps, the loss is %g, the validation accuracy is %g" % (i, loss, validate_accuracy))print("the training is finish!")#最终的测试准确率acc = sess.run(accuracy, feed_dict=test_data)print("the test accuarcy is:", acc)


其中每一步的函数作用可以参考我的第二篇博客: 罗斯基白话:TensorFlow+实战系列(二)从零构建传统神经网络

里面有详细的解释。


完整代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_databatch_size = 100
learning_rate = 0.8
trainig_step = 30000n_input = 784
n_hidden = 500
n_labels = 10def inference(x_input):with tf.variable_scope("hidden"):weights = tf.get_variable("weights", [n_input, n_hidden], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_hidden], initializer = tf.constant_initializer(0.0))hidden = tf.nn.relu(tf.matmul(x_input, weights) + biases)with tf.variable_scope("out"):weights  = tf.get_variable("weights", [n_hidden, n_labels], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_labels], initializer = tf.constant_initializer(0.0))output = tf.matmul(hidden, weights) + biasesreturn outputdef train(mnist):x = tf.placeholder("float", [None, n_input])y = tf.placeholder("float", [None, n_labels])pred = inference(x)#计算损失函数cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y))#定义优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(cross_entropy)#定义准确率计算correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)#定义验证集与测试集validate_data = {x: mnist.validation.images, y: mnist.validation.labels}test_data = {x: mnist.test.images, y: mnist.test.labels}for i in range(trainig_step):#xs,ys为每个batch_size的训练数据与对应的标签xs, ys = mnist.train.next_batch(batch_size)_, loss = sess.run([optimizer, cross_entropy], feed_dict={x: xs, y:ys})#每1000次训练打印一次损失值与验证准确率if i % 1000 == 0:validate_accuracy = sess.run(accuracy, feed_dict=validate_data)print("after %d training steps, the loss is %g, the validation accuracy is %g" % (i, loss, validate_accuracy))print("the training is finish!")#最终的测试准确率acc = sess.run(accuracy, feed_dict=test_data)print("the test accuarcy is:", acc)def main(argv = None):mnist = input_data.read_data_sets("/tensorflow/mnst_data", one_hot=True)train(mnist)if __name__ == "__main__":tf.app.run()

 

最后执行的结果如图:




可以看到最终的准确率能达到98.19%,看来效果还是很不错的。嘿嘿。

       

这篇关于罗斯基白话:TensorFlow + 实战系列(五)实战MNIST的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/217665

相关文章

Python与DeepSeek的深度融合实战

《Python与DeepSeek的深度融合实战》Python作为最受欢迎的编程语言之一,以其简洁易读的语法、丰富的库和广泛的应用场景,成为了无数开发者的首选,而DeepSeek,作为人工智能领域的新星... 目录一、python与DeepSeek的结合优势二、模型训练1. 数据准备2. 模型架构与参数设置3

Java实战之利用POI生成Excel图表

《Java实战之利用POI生成Excel图表》ApachePOI是Java生态中处理Office文档的核心工具,这篇文章主要为大家详细介绍了如何在Excel中创建折线图,柱状图,饼图等常见图表,需要的... 目录一、环境配置与依赖管理二、数据源准备与工作表构建三、图表生成核心步骤1. 折线图(Line Ch

Java使用Tesseract-OCR实战教程

《Java使用Tesseract-OCR实战教程》本文介绍了如何在Java中使用Tesseract-OCR进行文本提取,包括Tesseract-OCR的安装、中文训练库的配置、依赖库的引入以及具体的代... 目录Java使用Tesseract-OCRTesseract-OCR安装配置中文训练库引入依赖代码实

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

在Java中使用ModelMapper简化Shapefile属性转JavaBean实战过程

《在Java中使用ModelMapper简化Shapefile属性转JavaBean实战过程》本文介绍了在Java中使用ModelMapper库简化Shapefile属性转JavaBean的过程,对比... 目录前言一、原始的处理办法1、使用Set方法来转换2、使用构造方法转换二、基于ModelMapper

Java实战之自助进行多张图片合成拼接

《Java实战之自助进行多张图片合成拼接》在当今数字化时代,图像处理技术在各个领域都发挥着至关重要的作用,本文为大家详细介绍了如何使用Java实现多张图片合成拼接,需要的可以了解下... 目录前言一、图片合成需求描述二、图片合成设计与实现1、编程语言2、基础数据准备3、图片合成流程4、图片合成实现三、总结前

nginx-rtmp-module构建流媒体直播服务器实战指南

《nginx-rtmp-module构建流媒体直播服务器实战指南》本文主要介绍了nginx-rtmp-module构建流媒体直播服务器实战指南,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. RTMP协议介绍与应用RTMP协议的原理RTMP协议的应用RTMP与现代流媒体技术的关系2

C语言小项目实战之通讯录功能

《C语言小项目实战之通讯录功能》:本文主要介绍如何设计和实现一个简单的通讯录管理系统,包括联系人信息的存储、增加、删除、查找、修改和排序等功能,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录功能介绍:添加联系人模块显示联系人模块删除联系人模块查找联系人模块修改联系人模块排序联系人模块源代码如下

Golang操作DuckDB实战案例分享

《Golang操作DuckDB实战案例分享》DuckDB是一个嵌入式SQL数据库引擎,它与众所周知的SQLite非常相似,但它是为olap风格的工作负载设计的,DuckDB支持各种数据类型和SQL特性... 目录DuckDB的主要优点环境准备初始化表和数据查询单行或多行错误处理和事务完整代码最后总结Duck

Python中的随机森林算法与实战

《Python中的随机森林算法与实战》本文详细介绍了随机森林算法,包括其原理、实现步骤、分类和回归案例,并讨论了其优点和缺点,通过面向对象编程实现了一个简单的随机森林模型,并应用于鸢尾花分类和波士顿房... 目录1、随机森林算法概述2、随机森林的原理3、实现步骤4、分类案例:使用随机森林预测鸢尾花品种4.1