利用Eager写的自定义模型来训练神经网络(mnist示例)

2024-03-29 13:48

本文主要是介绍利用Eager写的自定义模型来训练神经网络(mnist示例),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

利用Eager写的自定义循环来训练神经网络(mnist示例)

文章目录

  • 利用Eager写的自定义循环来训练神经网络(mnist示例)
    • 1.为什么
    • 2.确认eager是否可以使用
    • 3.数据集选择
    • 4.数据预处理
    • 5.模型的建立
    • 6.对于每次训练的自定义
    • 7.模型的验证
    • 8.结语

1.为什么

tensorflow的keras提供了非常具体的舒服封装的神经网络和各项功能,但是越封装越影响我们对他的自定义的神经网络,所以tensorflow提供了一个 eager 自定义模式,可以自定义内层的网络循环,,这样可以在同时拥有高度封装的网络模型的情况下对训练的每一步进行调试,控制每一步的流程,也有利于我们搞懂神经网络的内部循环。

2.确认eager是否可以使用

我使用的版本是tensorflow2.4.0,默认支持eager模式,但是我们也可以用个功能来查看

tf.executing_eagerly()
True

这样就说明我们的电脑可以支持eager的运行模式,接下来我们就可以使用eager来支持对于训练每一步的自定义模式。

3.数据集选择

我们在这里选择深度学习的hello world测试集mnist,来测试,他是由60000张手写体图片组成的,每张图片大小为(28,28)灰度图,每张图片读取后可以这样展示出来

import matplotlib.pyplot as plt
import random 
ch=random.choice(range(len(train_image)))
%matplotlib inline
plt.imshow(train_image[ch])

在这里插入图片描述

4.数据预处理

对于该数据集,我们打算使用CNN卷积神经网络来解决该问题,在使用卷积神经网络需要注意使用对于图片的卷积神经网络的话。需要输入的具有四个维度分别是(batch,width,length,channel)但是我们输入的图片目前只有三个维度,最后一个维度需要我们去扩充,并且图片数据现在是无符号整型数,我们还要转换成浮点型,以及将数据归一化,代码如下:

import tensorflow as tf
from tensorflow import keras
(train_image,train_label),(test_image,test_label)=keras.datasets.mnist.load_data()#导入数据
train_images=tf.expand_dims(train_image,-1)#扩展维度
train_images=tf.cast(train_images/255,tf.float32)#归一化处理数据,然后转化数据格式
train_label=tf.cast(train_label,tf.int64)#处理标签数据
train_data=tf.data.Dataset.from_tensor_slices((train_images,train_label))#制作数据
BATCH_SIZE=32
train_data=train_data.shuffle(10000).batch(BATCH_SIZE)
#制作训练数据

5.模型的建立

​ 对于mnist这样的数据集,我们就按照正常的卷积神经网络的布局去解决就好了,唯一要注意的就是我们在这里最后不选择激活(最后返回一个十维的张量,我们可以根据哪个维度的值比较大来确定该图是哪个值,这样的值我们称为logits,同时我们也要注意如果不激活的话在计算损失的话也需要注意要声明我们最后没有激活)

model=tf.keras.Sequential()
model.add(keras.layers.Conv2D(16,(3,3),input_shape=(28,28,1),activation='relu'))
model.add(keras.layers.Conv2D(32,(3,3),activation='relu'))
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dense(10))

6.对于每次训练的自定义

​ 接下来才是我们eager的自定义循环最重要的部分,我们先缕清数据到底是怎么训练的,每一批次的数据输入模型,模型返回值,计算与真值的损失函数,从而来修改模型的可训练参数,如果不用eager自定义模式,那么其实这些每一步都由tensorflow自己帮我们完成,那么接下来这些由我们自己来完成,同时还要注意每一epoch训练完要输出准确率和loss那么接下来让我们来用代码实现

train_loss = tf.keras.metrics.Mean('train_loss')#meteics计算损失对象
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
#计算每一步的训练准确率#在每一批次的训练中我们规定如何去训练
def train_step(model, images, labels):#利用t来记录每个变量的变化with tf.GradientTape() as t:pred = model(images)loss_step = loss_func(labels, pred)#求解损失函数对于可训练参数的微分grads = t.gradient(loss_step, model.trainable_variables)#对于优化器来应用这样求解的梯度对于可训练参数的修改呢optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss_step)#返回loss的均值train_accuracy(labels, pred)#返回准确率均值
def train():for epoch in range(10):for (batch, (images, labels)) in enumerate(dataset):train_step(model, images, labels)print('Epoch{} loss is {}, accuracy is {}'.format(epoch,train_loss.result(),train_accuracy.result()))train_loss.reset_states()#重新设置训练参数的状态train_accuracy.reset_states()
train()#最后我们直接运行参数开始训练

输出以下内容

Epoch0 loss is 0.9328579902648926, accuracy is 0.7062000036239624

Epoch1 loss is 0.3824250400066376, accuracy is 0.8821166753768921

Epoch2 loss is 0.3067557215690613, accuracy is 0.90461665391922

Epoch3 loss is 0.27012768387794495, accuracy is 0.914900004863739

Epoch4 loss is 0.24741436541080475, accuracy is 0.9226499795913696

Epoch5 loss is 0.22932228446006775, accuracy is 0.9286666512489319

Epoch6 loss is 0.2168780267238617, accuracy is 0.9328500032424927

Epoch7 loss is 0.20622499287128448, accuracy is 0.9351166486740112

Epoch8 loss is 0.197651669383049, accuracy is 0.9382666945457458

Epoch9 loss is 0.19150032103061676, accuracy is 0.940500020980835

这样我们就完成了每一步的训练(但我在实际跑的时候,由于我不了解对于缓存的处理,我把CPU跑爆了,即使我将batch,和shuffle的数据降低,还是爆了,应该不是计算资源吧,我调用GPU跑也还是这样,希望有知道的人可以在评论区指点迷津),我们来测试一下预测值

7.模型的验证

features,label=next(iter(train_data))#获取一个批次的数据
features.shape#每一批次的数据为三十二个
prediction=model(features)
prediction.shape
tf.argmax(prediction, axis=1)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 8, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

那么我们接下来查看真值

print(label)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 2, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

我们可以看到我们获得的模型虽然有误差但是整体来说准确率还是挺高的。

8.结语

​ 在本例中,我们完成了对于模型每次训练的自定义完成了在我们自己定义下的循环中去不断训练模型的实例,但在本次训练中对于CPU的使用还需优化,暂时没找到解决的办法希望有懂得人,或者对博客中指出的错误都可以在博客中讨论。

这篇关于利用Eager写的自定义模型来训练神经网络(mnist示例)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858756

相关文章

SpringBoot基于MyBatis-Plus实现Lambda Query查询的示例代码

《SpringBoot基于MyBatis-Plus实现LambdaQuery查询的示例代码》MyBatis-Plus是MyBatis的增强工具,简化了数据库操作,并提高了开发效率,它提供了多种查询方... 目录引言基础环境配置依赖配置(Maven)application.yml 配置表结构设计demo_st

SpringCloud集成AlloyDB的示例代码

《SpringCloud集成AlloyDB的示例代码》AlloyDB是GoogleCloud提供的一种高度可扩展、强性能的关系型数据库服务,它兼容PostgreSQL,并提供了更快的查询性能... 目录1.AlloyDBjavascript是什么?AlloyDB 的工作原理2.搭建测试环境3.代码工程1.

Java中ArrayList的8种浅拷贝方式示例代码

《Java中ArrayList的8种浅拷贝方式示例代码》:本文主要介绍Java中ArrayList的8种浅拷贝方式的相关资料,讲解了Java中ArrayList的浅拷贝概念,并详细分享了八种实现浅... 目录引言什么是浅拷贝?ArrayList 浅拷贝的重要性方法一:使用构造函数方法二:使用 addAll(

Golang使用etcd构建分布式锁的示例分享

《Golang使用etcd构建分布式锁的示例分享》在本教程中,我们将学习如何使用Go和etcd构建分布式锁系统,分布式锁系统对于管理对分布式系统中共享资源的并发访问至关重要,它有助于维护一致性,防止竞... 目录引言环境准备新建Go项目实现加锁和解锁功能测试分布式锁重构实现失败重试总结引言我们将使用Go作

JAVA利用顺序表实现“杨辉三角”的思路及代码示例

《JAVA利用顺序表实现“杨辉三角”的思路及代码示例》杨辉三角形是中国古代数学的杰出研究成果之一,是我国北宋数学家贾宪于1050年首先发现并使用的,:本文主要介绍JAVA利用顺序表实现杨辉三角的思... 目录一:“杨辉三角”题目链接二:题解代码:三:题解思路:总结一:“杨辉三角”题目链接题目链接:点击这里

SpringBoot使用注解集成Redis缓存的示例代码

《SpringBoot使用注解集成Redis缓存的示例代码》:本文主要介绍在SpringBoot中使用注解集成Redis缓存的步骤,包括添加依赖、创建相关配置类、需要缓存数据的类(Tes... 目录一、创建 Caching 配置类二、创建需要缓存数据的类三、测试方法Spring Boot 熟悉后,集成一个外

Springboot使用RabbitMQ实现关闭超时订单(示例详解)

《Springboot使用RabbitMQ实现关闭超时订单(示例详解)》介绍了如何在SpringBoot项目中使用RabbitMQ实现订单的延时处理和超时关闭,通过配置RabbitMQ的交换机、队列和... 目录1.maven中引入rabbitmq的依赖:2.application.yml中进行rabbit

Python绘制土地利用和土地覆盖类型图示例详解

《Python绘制土地利用和土地覆盖类型图示例详解》本文介绍了如何使用Python绘制土地利用和土地覆盖类型图,并提供了详细的代码示例,通过安装所需的库,准备地理数据,使用geopandas和matp... 目录一、所需库的安装二、数据准备三、绘制土地利用和土地覆盖类型图四、代码解释五、其他可视化形式1.

opencv实现像素统计的示例代码

《opencv实现像素统计的示例代码》本文介绍了OpenCV中统计图像像素信息的常用方法和函数,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 统计像素值的基本信息2. 统计像素值的直方图3. 统计像素值的总和4. 统计非零像素的数量

Python使用asyncio实现异步操作的示例

《Python使用asyncio实现异步操作的示例》本文主要介绍了Python使用asyncio实现异步操作的示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋... 目录1. 基础概念2. 实现异步 I/O 的步骤2.1 定义异步函数2.2 使用 await 等待异