利用Eager写的自定义模型来训练神经网络(mnist示例)

2024-03-29 13:48

本文主要是介绍利用Eager写的自定义模型来训练神经网络(mnist示例),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

利用Eager写的自定义循环来训练神经网络(mnist示例)

文章目录

  • 利用Eager写的自定义循环来训练神经网络(mnist示例)
    • 1.为什么
    • 2.确认eager是否可以使用
    • 3.数据集选择
    • 4.数据预处理
    • 5.模型的建立
    • 6.对于每次训练的自定义
    • 7.模型的验证
    • 8.结语

1.为什么

tensorflow的keras提供了非常具体的舒服封装的神经网络和各项功能,但是越封装越影响我们对他的自定义的神经网络,所以tensorflow提供了一个 eager 自定义模式,可以自定义内层的网络循环,,这样可以在同时拥有高度封装的网络模型的情况下对训练的每一步进行调试,控制每一步的流程,也有利于我们搞懂神经网络的内部循环。

2.确认eager是否可以使用

我使用的版本是tensorflow2.4.0,默认支持eager模式,但是我们也可以用个功能来查看

tf.executing_eagerly()
True

这样就说明我们的电脑可以支持eager的运行模式,接下来我们就可以使用eager来支持对于训练每一步的自定义模式。

3.数据集选择

我们在这里选择深度学习的hello world测试集mnist,来测试,他是由60000张手写体图片组成的,每张图片大小为(28,28)灰度图,每张图片读取后可以这样展示出来

import matplotlib.pyplot as plt
import random 
ch=random.choice(range(len(train_image)))
%matplotlib inline
plt.imshow(train_image[ch])

在这里插入图片描述

4.数据预处理

对于该数据集,我们打算使用CNN卷积神经网络来解决该问题,在使用卷积神经网络需要注意使用对于图片的卷积神经网络的话。需要输入的具有四个维度分别是(batch,width,length,channel)但是我们输入的图片目前只有三个维度,最后一个维度需要我们去扩充,并且图片数据现在是无符号整型数,我们还要转换成浮点型,以及将数据归一化,代码如下:

import tensorflow as tf
from tensorflow import keras
(train_image,train_label),(test_image,test_label)=keras.datasets.mnist.load_data()#导入数据
train_images=tf.expand_dims(train_image,-1)#扩展维度
train_images=tf.cast(train_images/255,tf.float32)#归一化处理数据,然后转化数据格式
train_label=tf.cast(train_label,tf.int64)#处理标签数据
train_data=tf.data.Dataset.from_tensor_slices((train_images,train_label))#制作数据
BATCH_SIZE=32
train_data=train_data.shuffle(10000).batch(BATCH_SIZE)
#制作训练数据

5.模型的建立

​ 对于mnist这样的数据集,我们就按照正常的卷积神经网络的布局去解决就好了,唯一要注意的就是我们在这里最后不选择激活(最后返回一个十维的张量,我们可以根据哪个维度的值比较大来确定该图是哪个值,这样的值我们称为logits,同时我们也要注意如果不激活的话在计算损失的话也需要注意要声明我们最后没有激活)

model=tf.keras.Sequential()
model.add(keras.layers.Conv2D(16,(3,3),input_shape=(28,28,1),activation='relu'))
model.add(keras.layers.Conv2D(32,(3,3),activation='relu'))
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dense(10))

6.对于每次训练的自定义

​ 接下来才是我们eager的自定义循环最重要的部分,我们先缕清数据到底是怎么训练的,每一批次的数据输入模型,模型返回值,计算与真值的损失函数,从而来修改模型的可训练参数,如果不用eager自定义模式,那么其实这些每一步都由tensorflow自己帮我们完成,那么接下来这些由我们自己来完成,同时还要注意每一epoch训练完要输出准确率和loss那么接下来让我们来用代码实现

train_loss = tf.keras.metrics.Mean('train_loss')#meteics计算损失对象
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
#计算每一步的训练准确率#在每一批次的训练中我们规定如何去训练
def train_step(model, images, labels):#利用t来记录每个变量的变化with tf.GradientTape() as t:pred = model(images)loss_step = loss_func(labels, pred)#求解损失函数对于可训练参数的微分grads = t.gradient(loss_step, model.trainable_variables)#对于优化器来应用这样求解的梯度对于可训练参数的修改呢optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss_step)#返回loss的均值train_accuracy(labels, pred)#返回准确率均值
def train():for epoch in range(10):for (batch, (images, labels)) in enumerate(dataset):train_step(model, images, labels)print('Epoch{} loss is {}, accuracy is {}'.format(epoch,train_loss.result(),train_accuracy.result()))train_loss.reset_states()#重新设置训练参数的状态train_accuracy.reset_states()
train()#最后我们直接运行参数开始训练

输出以下内容

Epoch0 loss is 0.9328579902648926, accuracy is 0.7062000036239624

Epoch1 loss is 0.3824250400066376, accuracy is 0.8821166753768921

Epoch2 loss is 0.3067557215690613, accuracy is 0.90461665391922

Epoch3 loss is 0.27012768387794495, accuracy is 0.914900004863739

Epoch4 loss is 0.24741436541080475, accuracy is 0.9226499795913696

Epoch5 loss is 0.22932228446006775, accuracy is 0.9286666512489319

Epoch6 loss is 0.2168780267238617, accuracy is 0.9328500032424927

Epoch7 loss is 0.20622499287128448, accuracy is 0.9351166486740112

Epoch8 loss is 0.197651669383049, accuracy is 0.9382666945457458

Epoch9 loss is 0.19150032103061676, accuracy is 0.940500020980835

这样我们就完成了每一步的训练(但我在实际跑的时候,由于我不了解对于缓存的处理,我把CPU跑爆了,即使我将batch,和shuffle的数据降低,还是爆了,应该不是计算资源吧,我调用GPU跑也还是这样,希望有知道的人可以在评论区指点迷津),我们来测试一下预测值

7.模型的验证

features,label=next(iter(train_data))#获取一个批次的数据
features.shape#每一批次的数据为三十二个
prediction=model(features)
prediction.shape
tf.argmax(prediction, axis=1)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 8, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

那么我们接下来查看真值

print(label)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 2, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

我们可以看到我们获得的模型虽然有误差但是整体来说准确率还是挺高的。

8.结语

​ 在本例中,我们完成了对于模型每次训练的自定义完成了在我们自己定义下的循环中去不断训练模型的实例,但在本次训练中对于CPU的使用还需优化,暂时没找到解决的办法希望有懂得人,或者对博客中指出的错误都可以在博客中讨论。

这篇关于利用Eager写的自定义模型来训练神经网络(mnist示例)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858756

相关文章

讯飞webapi语音识别接口调用示例代码(python)

《讯飞webapi语音识别接口调用示例代码(python)》:本文主要介绍如何使用Python3调用讯飞WebAPI语音识别接口,重点解决了在处理语音识别结果时判断是否为最后一帧的问题,通过运行代... 目录前言一、环境二、引入库三、代码实例四、运行结果五、总结前言基于python3 讯飞webAPI语音

MySQL中COALESCE函数示例详解

《MySQL中COALESCE函数示例详解》COALESCE是一个功能强大且常用的SQL函数,主要用来处理NULL值和实现灵活的值选择策略,能够使查询逻辑更清晰、简洁,:本文主要介绍MySQL中C... 目录语法示例1. 替换 NULL 值2. 用于字段默认值3. 多列优先级4. 结合聚合函数注意事项总结C

什么是 Java 的 CyclicBarrier(代码示例)

《什么是Java的CyclicBarrier(代码示例)》CyclicBarrier是多线程协同的利器,适合需要多次同步的场景,本文通过代码示例讲解什么是Java的CyclicBarrier,感... 你的回答(口语化,面试场景)面试官:什么是 Java 的 CyclicBarrier?你:好的,我来举个例

HTML5 data-*自定义数据属性的示例代码

《HTML5data-*自定义数据属性的示例代码》HTML5的自定义数据属性(data-*)提供了一种标准化的方法在HTML元素上存储额外信息,可以通过JavaScript访问、修改和在CSS中使用... 目录引言基本概念使用自定义数据属性1. 在 html 中定义2. 通过 JavaScript 访问3.

SpringBoot自定义注解如何解决公共字段填充问题

《SpringBoot自定义注解如何解决公共字段填充问题》本文介绍了在系统开发中,如何使用AOP切面编程实现公共字段自动填充的功能,从而简化代码,通过自定义注解和切面类,可以统一处理创建时间和修改时间... 目录1.1 问题分析1.2 实现思路1.3 代码开发1.3.1 步骤一1.3.2 步骤二1.3.3

Python使用PIL库将PNG图片转换为ICO图标的示例代码

《Python使用PIL库将PNG图片转换为ICO图标的示例代码》在软件开发和网站设计中,ICO图标是一种常用的图像格式,特别适用于应用程序图标、网页收藏夹图标等场景,本文将介绍如何使用Python的... 目录引言准备工作代码解析实践操作结果展示结语引言在软件开发和网站设计中,ICO图标是一种常用的图像

C++ Primer 标准库vector示例详解

《C++Primer标准库vector示例详解》该文章主要介绍了C++标准库中的vector类型,包括其定义、初始化、成员函数以及常见操作,文章详细解释了如何使用vector来存储和操作对象集合,... 目录3.3标准库Vector定义和初始化vector对象通列表初始化vector对象创建指定数量的元素值

MyBatis与其使用方法示例详解

《MyBatis与其使用方法示例详解》MyBatis是一个支持自定义SQL的持久层框架,通过XML文件实现SQL配置和数据映射,简化了JDBC代码的编写,本文给大家介绍MyBatis与其使用方法讲解,... 目录ORM缺优分析MyBATisMyBatis的工作流程MyBatis的基本使用环境准备MyBati

dubbo3 filter(过滤器)如何自定义过滤器

《dubbo3filter(过滤器)如何自定义过滤器》dubbo3filter(过滤器)类似于javaweb中的filter和springmvc中的intercaptor,用于在请求发送前或到达前进... 目录dubbo3 filter(过滤器)简介dubbo 过滤器运行时机自定义 filter第一种 @A

spring @EventListener 事件与监听的示例详解

《spring@EventListener事件与监听的示例详解》本文介绍了自定义Spring事件和监听器的方法,包括如何发布事件、监听事件以及如何处理异步事件,通过示例代码和日志,展示了事件的顺序... 目录1、自定义Application Event2、自定义监听3、测试4、源代码5、其他5.1 顺序执行