利用Eager写的自定义模型来训练神经网络(mnist示例)

2024-03-29 13:48

本文主要是介绍利用Eager写的自定义模型来训练神经网络(mnist示例),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

利用Eager写的自定义循环来训练神经网络(mnist示例)

文章目录

  • 利用Eager写的自定义循环来训练神经网络(mnist示例)
    • 1.为什么
    • 2.确认eager是否可以使用
    • 3.数据集选择
    • 4.数据预处理
    • 5.模型的建立
    • 6.对于每次训练的自定义
    • 7.模型的验证
    • 8.结语

1.为什么

tensorflow的keras提供了非常具体的舒服封装的神经网络和各项功能,但是越封装越影响我们对他的自定义的神经网络,所以tensorflow提供了一个 eager 自定义模式,可以自定义内层的网络循环,,这样可以在同时拥有高度封装的网络模型的情况下对训练的每一步进行调试,控制每一步的流程,也有利于我们搞懂神经网络的内部循环。

2.确认eager是否可以使用

我使用的版本是tensorflow2.4.0,默认支持eager模式,但是我们也可以用个功能来查看

tf.executing_eagerly()
True

这样就说明我们的电脑可以支持eager的运行模式,接下来我们就可以使用eager来支持对于训练每一步的自定义模式。

3.数据集选择

我们在这里选择深度学习的hello world测试集mnist,来测试,他是由60000张手写体图片组成的,每张图片大小为(28,28)灰度图,每张图片读取后可以这样展示出来

import matplotlib.pyplot as plt
import random 
ch=random.choice(range(len(train_image)))
%matplotlib inline
plt.imshow(train_image[ch])

在这里插入图片描述

4.数据预处理

对于该数据集,我们打算使用CNN卷积神经网络来解决该问题,在使用卷积神经网络需要注意使用对于图片的卷积神经网络的话。需要输入的具有四个维度分别是(batch,width,length,channel)但是我们输入的图片目前只有三个维度,最后一个维度需要我们去扩充,并且图片数据现在是无符号整型数,我们还要转换成浮点型,以及将数据归一化,代码如下:

import tensorflow as tf
from tensorflow import keras
(train_image,train_label),(test_image,test_label)=keras.datasets.mnist.load_data()#导入数据
train_images=tf.expand_dims(train_image,-1)#扩展维度
train_images=tf.cast(train_images/255,tf.float32)#归一化处理数据,然后转化数据格式
train_label=tf.cast(train_label,tf.int64)#处理标签数据
train_data=tf.data.Dataset.from_tensor_slices((train_images,train_label))#制作数据
BATCH_SIZE=32
train_data=train_data.shuffle(10000).batch(BATCH_SIZE)
#制作训练数据

5.模型的建立

​ 对于mnist这样的数据集,我们就按照正常的卷积神经网络的布局去解决就好了,唯一要注意的就是我们在这里最后不选择激活(最后返回一个十维的张量,我们可以根据哪个维度的值比较大来确定该图是哪个值,这样的值我们称为logits,同时我们也要注意如果不激活的话在计算损失的话也需要注意要声明我们最后没有激活)

model=tf.keras.Sequential()
model.add(keras.layers.Conv2D(16,(3,3),input_shape=(28,28,1),activation='relu'))
model.add(keras.layers.Conv2D(32,(3,3),activation='relu'))
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dense(10))

6.对于每次训练的自定义

​ 接下来才是我们eager的自定义循环最重要的部分,我们先缕清数据到底是怎么训练的,每一批次的数据输入模型,模型返回值,计算与真值的损失函数,从而来修改模型的可训练参数,如果不用eager自定义模式,那么其实这些每一步都由tensorflow自己帮我们完成,那么接下来这些由我们自己来完成,同时还要注意每一epoch训练完要输出准确率和loss那么接下来让我们来用代码实现

train_loss = tf.keras.metrics.Mean('train_loss')#meteics计算损失对象
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
#计算每一步的训练准确率#在每一批次的训练中我们规定如何去训练
def train_step(model, images, labels):#利用t来记录每个变量的变化with tf.GradientTape() as t:pred = model(images)loss_step = loss_func(labels, pred)#求解损失函数对于可训练参数的微分grads = t.gradient(loss_step, model.trainable_variables)#对于优化器来应用这样求解的梯度对于可训练参数的修改呢optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss_step)#返回loss的均值train_accuracy(labels, pred)#返回准确率均值
def train():for epoch in range(10):for (batch, (images, labels)) in enumerate(dataset):train_step(model, images, labels)print('Epoch{} loss is {}, accuracy is {}'.format(epoch,train_loss.result(),train_accuracy.result()))train_loss.reset_states()#重新设置训练参数的状态train_accuracy.reset_states()
train()#最后我们直接运行参数开始训练

输出以下内容

Epoch0 loss is 0.9328579902648926, accuracy is 0.7062000036239624

Epoch1 loss is 0.3824250400066376, accuracy is 0.8821166753768921

Epoch2 loss is 0.3067557215690613, accuracy is 0.90461665391922

Epoch3 loss is 0.27012768387794495, accuracy is 0.914900004863739

Epoch4 loss is 0.24741436541080475, accuracy is 0.9226499795913696

Epoch5 loss is 0.22932228446006775, accuracy is 0.9286666512489319

Epoch6 loss is 0.2168780267238617, accuracy is 0.9328500032424927

Epoch7 loss is 0.20622499287128448, accuracy is 0.9351166486740112

Epoch8 loss is 0.197651669383049, accuracy is 0.9382666945457458

Epoch9 loss is 0.19150032103061676, accuracy is 0.940500020980835

这样我们就完成了每一步的训练(但我在实际跑的时候,由于我不了解对于缓存的处理,我把CPU跑爆了,即使我将batch,和shuffle的数据降低,还是爆了,应该不是计算资源吧,我调用GPU跑也还是这样,希望有知道的人可以在评论区指点迷津),我们来测试一下预测值

7.模型的验证

features,label=next(iter(train_data))#获取一个批次的数据
features.shape#每一批次的数据为三十二个
prediction=model(features)
prediction.shape
tf.argmax(prediction, axis=1)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 8, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

那么我们接下来查看真值

print(label)
<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([7, 6, 8, 2, 3, 6, 3, 7, 3, 9, 2, 1, 6, 0, 8, 3, 3, 4, 3, 8, 6, 5,2, 4, 1, 1, 0, 4, 8, 5, 2, 9], dtype=int64)>

我们可以看到我们获得的模型虽然有误差但是整体来说准确率还是挺高的。

8.结语

​ 在本例中,我们完成了对于模型每次训练的自定义完成了在我们自己定义下的循环中去不断训练模型的实例,但在本次训练中对于CPU的使用还需优化,暂时没找到解决的办法希望有懂得人,或者对博客中指出的错误都可以在博客中讨论。

这篇关于利用Eager写的自定义模型来训练神经网络(mnist示例)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858756

相关文章

Java字符串操作技巧之语法、示例与应用场景分析

《Java字符串操作技巧之语法、示例与应用场景分析》在Java算法题和日常开发中,字符串处理是必备的核心技能,本文全面梳理Java中字符串的常用操作语法,结合代码示例、应用场景和避坑指南,可快速掌握字... 目录引言1. 基础操作1.1 创建字符串1.2 获取长度1.3 访问字符2. 字符串处理2.1 子字

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

前端CSS Grid 布局示例详解

《前端CSSGrid布局示例详解》CSSGrid是一种二维布局系统,可以同时控制行和列,相比Flex(一维布局),更适合用在整体页面布局或复杂模块结构中,:本文主要介绍前端CSSGri... 目录css Grid 布局详解(通俗易懂版)一、概述二、基础概念三、创建 Grid 容器四、定义网格行和列五、设置行

Node.js 数据库 CRUD 项目示例详解(完美解决方案)

《Node.js数据库CRUD项目示例详解(完美解决方案)》:本文主要介绍Node.js数据库CRUD项目示例详解(完美解决方案),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考... 目录项目结构1. 初始化项目2. 配置数据库连接 (config/db.js)3. 创建模型 (models/

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

Spring LDAP目录服务的使用示例

《SpringLDAP目录服务的使用示例》本文主要介绍了SpringLDAP目录服务的使用示例... 目录引言一、Spring LDAP基础二、LdapTemplate详解三、LDAP对象映射四、基本LDAP操作4.1 查询操作4.2 添加操作4.3 修改操作4.4 删除操作五、认证与授权六、高级特性与最佳

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

CSS will-change 属性示例详解

《CSSwill-change属性示例详解》will-change是一个CSS属性,用于告诉浏览器某个元素在未来可能会发生哪些变化,本文给大家介绍CSSwill-change属性详解,感... will-change 是一个 css 属性,用于告诉浏览器某个元素在未来可能会发生哪些变化。这可以帮助浏览器优化

C++中std::distance使用方法示例

《C++中std::distance使用方法示例》std::distance是C++标准库中的一个函数,用于计算两个迭代器之间的距离,本文主要介绍了C++中std::distance使用方法示例,具... 目录语法使用方式解释示例输出:其他说明:总结std::distance&n编程bsp;是 C++ 标准

前端高级CSS用法示例详解

《前端高级CSS用法示例详解》在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交互和动态效果的关键技术之一,随着前端技术的不断发展,CSS的用法也日益丰富和高级,本文将深... 前端高级css用法在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交