罗斯基白话:TensorFlow + 实战系列(五)实战MNIST

2023-10-15 12:20

本文主要是介绍罗斯基白话:TensorFlow + 实战系列(五)实战MNIST,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 

白话TensorFlow +实战系列(五)
实战MNIST

 

       这篇文章主要用全连接神经网络来实现MNIST手写数字识别的问题。首先介绍下MNIST数据集。

       1)MNIST数据集

       MNIST数据集是一个非常有名的手写数字识别数据集,它包含了60000张图片作为训练集,10000张图片为测试集,每张图为一个手写的0~9数字。如图:




其中每张图的大小均为28*28,这里大小指的的是像素。例如数字1所对应的像素矩阵为:




而我们要做的就是教会电脑识别每个手写数字。这个数据集非常经典,常作为学习神经网络的入门教材,一如每个程序员的第一个程序都是“helloword!”一样。

 

       2)数据处理

       数据集下载下来后有四个文件,分别为训练集图片,训练集答案,测试集图片,测试集答案。TensorFlow提供了一个类来处理MNIST数据,这个类会自动的将MNIST数据分为训练集,验证集与测试集,并且这些数据都是可以直接喂给神经网络作为输入用的。示例代码如下:



      

 其中input_data.read_data_sets会自动将数据集进行处理,one_hot = True用独热方式表示,意思是每个数字由one_hot方式表,例如数字0 = [1,0,0,0,0,0,0,0,0,0],1 = [0,1,0,0,0,0,0,0,0,0]。运行结果如下:




接下来就用一个全连接神经网络来识别数字。

 

       3)全连接神经网络

       首先定义超参数与参数,没啥好解释的,代码如下:


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_databatch_size = 100
learning_rate = 0.8
trainig_step = 30000n_input = 784
n_hidden = 500
n_labels = 10

 接着定义网络的结构,构建的网络只有一个隐藏层,隐藏层节点为500。代码如下:


def inference(x_input):with tf.variable_scope("hidden"):weights = tf.get_variable("weights", [n_input, n_hidden], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_hidden], initializer = tf.constant_initializer(0.0))hidden = tf.nn.relu(tf.matmul(x_input, weights) + biases)with tf.variable_scope("out"):weights  = tf.get_variable("weights", [n_hidden, n_labels], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_labels], initializer = tf.constant_initializer(0.0))output = tf.matmul(hidden, weights) + biasesreturn output

在输出层中,output并没有用到relu函数,因为在之后的softmax层中也是非线性激励,所以可以不用。

 

接着定义训练过程,代码如下:


def train(mnist):x = tf.placeholder("float", [None, n_input])y = tf.placeholder("float", [None, n_labels])pred = inference(x)#计算损失函数cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y))#定义优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(cross_entropy)#定义准确率计算correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)#定义验证集与测试集validate_data = {x: mnist.validation.images, y: mnist.validation.labels}test_data = {x: mnist.test.images, y: mnist.test.labels}for i in range(trainig_step):#xs,ys为每个batch_size的训练数据与对应的标签xs, ys = mnist.train.next_batch(batch_size)_, loss = sess.run([optimizer, cross_entropy], feed_dict={x: xs, y:ys})#每1000次训练打印一次损失值与验证准确率if i % 1000 == 0:validate_accuracy = sess.run(accuracy, feed_dict=validate_data)print("after %d training steps, the loss is %g, the validation accuracy is %g" % (i, loss, validate_accuracy))print("the training is finish!")#最终的测试准确率acc = sess.run(accuracy, feed_dict=test_data)print("the test accuarcy is:", acc)


其中每一步的函数作用可以参考我的第二篇博客: 罗斯基白话:TensorFlow+实战系列(二)从零构建传统神经网络

里面有详细的解释。


完整代码如下:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_databatch_size = 100
learning_rate = 0.8
trainig_step = 30000n_input = 784
n_hidden = 500
n_labels = 10def inference(x_input):with tf.variable_scope("hidden"):weights = tf.get_variable("weights", [n_input, n_hidden], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_hidden], initializer = tf.constant_initializer(0.0))hidden = tf.nn.relu(tf.matmul(x_input, weights) + biases)with tf.variable_scope("out"):weights  = tf.get_variable("weights", [n_hidden, n_labels], initializer = tf.random_normal_initializer(stddev = 0.1))biases = tf.get_variable("biases", [n_labels], initializer = tf.constant_initializer(0.0))output = tf.matmul(hidden, weights) + biasesreturn outputdef train(mnist):x = tf.placeholder("float", [None, n_input])y = tf.placeholder("float", [None, n_labels])pred = inference(x)#计算损失函数cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y))#定义优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate).minimize(cross_entropy)#定义准确率计算correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init)#定义验证集与测试集validate_data = {x: mnist.validation.images, y: mnist.validation.labels}test_data = {x: mnist.test.images, y: mnist.test.labels}for i in range(trainig_step):#xs,ys为每个batch_size的训练数据与对应的标签xs, ys = mnist.train.next_batch(batch_size)_, loss = sess.run([optimizer, cross_entropy], feed_dict={x: xs, y:ys})#每1000次训练打印一次损失值与验证准确率if i % 1000 == 0:validate_accuracy = sess.run(accuracy, feed_dict=validate_data)print("after %d training steps, the loss is %g, the validation accuracy is %g" % (i, loss, validate_accuracy))print("the training is finish!")#最终的测试准确率acc = sess.run(accuracy, feed_dict=test_data)print("the test accuarcy is:", acc)def main(argv = None):mnist = input_data.read_data_sets("/tensorflow/mnst_data", one_hot=True)train(mnist)if __name__ == "__main__":tf.app.run()

 

最后执行的结果如图:




可以看到最终的准确率能达到98.19%,看来效果还是很不错的。嘿嘿。

       

这篇关于罗斯基白话:TensorFlow + 实战系列(五)实战MNIST的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/217665

相关文章

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

MyBatis 动态 SQL 优化之标签的实战与技巧(常见用法)

《MyBatis动态SQL优化之标签的实战与技巧(常见用法)》本文通过详细的示例和实际应用场景,介绍了如何有效利用这些标签来优化MyBatis配置,提升开发效率,确保SQL的高效执行和安全性,感... 目录动态SQL详解一、动态SQL的核心概念1.1 什么是动态SQL?1.2 动态SQL的优点1.3 动态S

Pandas使用SQLite3实战

《Pandas使用SQLite3实战》本文主要介绍了Pandas使用SQLite3实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录1 环境准备2 从 SQLite3VlfrWQzgt 读取数据到 DataFrame基础用法:读

Python实战之屏幕录制功能的实现

《Python实战之屏幕录制功能的实现》屏幕录制,即屏幕捕获,是指将计算机屏幕上的活动记录下来,生成视频文件,本文主要为大家介绍了如何使用Python实现这一功能,希望对大家有所帮助... 目录屏幕录制原理图像捕获音频捕获编码压缩输出保存完整的屏幕录制工具高级功能实时预览增加水印多平台支持屏幕录制原理屏幕

最新Spring Security实战教程之Spring Security安全框架指南

《最新SpringSecurity实战教程之SpringSecurity安全框架指南》SpringSecurity是Spring生态系统中的核心组件,提供认证、授权和防护机制,以保护应用免受各种安... 目录前言什么是Spring Security?同类框架对比Spring Security典型应用场景传统

最新Spring Security实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)

《最新SpringSecurity实战教程之表单登录定制到处理逻辑的深度改造(最新推荐)》本章节介绍了如何通过SpringSecurity实现从配置自定义登录页面、表单登录处理逻辑的配置,并简单模拟... 目录前言改造准备开始登录页改造自定义用户名密码登陆成功失败跳转问题自定义登出前后端分离适配方案结语前言

OpenManus本地部署实战亲测有效完全免费(最新推荐)

《OpenManus本地部署实战亲测有效完全免费(最新推荐)》文章介绍了如何在本地部署OpenManus大语言模型,包括环境搭建、LLM编程接口配置和测试步骤,本文给大家讲解的非常详细,感兴趣的朋友一... 目录1.概况2.环境搭建2.1安装miniconda或者anaconda2.2 LLM编程接口配置2

基于Canvas的Html5多时区动态时钟实战代码

《基于Canvas的Html5多时区动态时钟实战代码》:本文主要介绍了如何使用Canvas在HTML5上实现一个多时区动态时钟的web展示,通过Canvas的API,可以绘制出6个不同城市的时钟,并且这些时钟可以动态转动,每个时钟上都会标注出对应的24小时制时间,详细内容请阅读本文,希望能对你有所帮助...

Spring AI与DeepSeek实战一之快速打造智能对话应用

《SpringAI与DeepSeek实战一之快速打造智能对话应用》本文详细介绍了如何通过SpringAI框架集成DeepSeek大模型,实现普通对话和流式对话功能,步骤包括申请API-KEY、项目搭... 目录一、概述二、申请DeepSeek的API-KEY三、项目搭建3.1. 开发环境要求3.2. mav