MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)

2024-05-04 06:12

本文主要是介绍MLP手写数字识别(2)-模型构建、训练与识别(tensorflow),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

1.MNIST的数据集下载与预处理

import tensorflow as tf
from keras.datasets import mnist
from keras.utils import to_categorical(train_x,train_y),(test_x,test_y) = mnist.load_data()
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) # 归一化
y_train,y_test = to_categorical(train_y),to_categorical(test_y) # onehot
print(X_train[:5])
print(y_train[:5])

2.搭建MLP模型

from keras import Sequential
from keras.layers import Flatten,Dense
from keras import Inputmodel = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

3.模型训练

3.1 调用model.compile()函数对训练模型进行设置

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
  • loss=‘categorical_crossentropy’: 损失函数设置为交叉熵损失函数,在深度学习中用交叉熵模式训练效果会比较好。
  • optimizer=‘adam’: 优化器设置为adam, 在深度学习中可以让训练更快收敛,并提高准确率。
  • metrics=[‘accuracy’]:评估模式设置为准确度评估模式。

loss参数常用的损失函数

  • binary_crossentropy: 亦称作对数损失,logloss
  • categorical_crossentropy: 交叉熵损失函数,亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为onehot形式
  • sparse_categorical_crossentropy:稀疏交叉熵损失函数。
  • kullback_leibler_divergence: 从预测值概率分布Q到真值概率分布P的信息增益,用以度量两个分布的差异
  • poisson: 即(pred-target*log(pred))的均值
  • cosine_proximity:预测值与真实标签的余弦距离平均值的相反数

优化器

  • SGD
  • RMSprop
  • Adagrad
  • Adadelta
  • Adam
  • Adamax
  • Nadam
  • TFOptimizer

评估模式

  • binary_accuracy: 对二分类问题,计算在所有预测值上的平均正确率
  • categorical_accuracy: 对多分类问题,计算在所有预测值上的平均正确率
  • sparse_categorical_accuracy:与categorical_accuracy相同,在对稀疏的目标值预测时有用
  • top_k_categorical_accuracy: 计算top-k正确率,当预测值的前K个值中存在目标类别即认为预测正确
  • sparse_top_k_categorical_accuracy: 与top_k_categorical_accuracy作用相同,但适用于稀疏情况

3.2 调用model.fit()配置训练参数,开始训练,并保存训练结果。

H = model.fit(x=X_train,y=y_train,validation_split=0.2,epochs=20,batch_size=128,verbose=1)

在这里插入图片描述

4.显示模型准确率和误差

import matplotlib.pyplot as pltdef show_train(history,train,validation):plt.plot(history.epoch, history.history[train],label=train)plt.plot(history.epoch, history.history[validation],label=validation)plt.title(train)plt.legend()plt.show()show_train(H,'loss','val_loss')
show_train(H,'accuracy','val_accuracy')

在这里插入图片描述

5.使用测试数据进行识别

import numpy as np
import matplotlib.pyplot as pltdef pred_plot_images_lables(images,labels,start_idx,num=5):# 预测res = model.predict(images[start_idx:start_idx+num])res = np.argmax(res,axis=1)# 画图fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()pred_plot_images_lables(X_test,test_y,0,5)

在这里插入图片描述

这篇关于MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/958490

相关文章

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

一文详解如何从零构建Spring Boot Starter并实现整合

《一文详解如何从零构建SpringBootStarter并实现整合》SpringBoot是一个开源的Java基础框架,用于创建独立、生产级的基于Spring框架的应用程序,:本文主要介绍如何从... 目录一、Spring Boot Starter的核心价值二、Starter项目创建全流程2.1 项目初始化(

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

java字符串数字补齐位数详解

《java字符串数字补齐位数详解》:本文主要介绍java字符串数字补齐位数,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Java字符串数字补齐位数一、使用String.format()方法二、Apache Commons Lang库方法三、Java 11+的St

使用Python和python-pptx构建Markdown到PowerPoint转换器

《使用Python和python-pptx构建Markdown到PowerPoint转换器》在这篇博客中,我们将深入分析一个使用Python开发的应用程序,该程序可以将Markdown文件转换为Pow... 目录引言应用概述代码结构与分析1. 类定义与初始化2. 事件处理3. Markdown 处理4. 转

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

讯飞webapi语音识别接口调用示例代码(python)

《讯飞webapi语音识别接口调用示例代码(python)》:本文主要介绍如何使用Python3调用讯飞WebAPI语音识别接口,重点解决了在处理语音识别结果时判断是否为最后一帧的问题,通过运行代... 目录前言一、环境二、引入库三、代码实例四、运行结果五、总结前言基于python3 讯飞webAPI语音

Java使用Mail构建邮件功能的完整指南

《Java使用Mail构建邮件功能的完整指南》JavaMailAPI是一个功能强大的工具,它可以帮助开发者轻松实现邮件的发送与接收功能,本文将介绍如何使用JavaMail发送和接收邮件,希望对大家有所... 目录1、简述2、主要特点3、发送样例3.1 发送纯文本邮件3.2 发送 html 邮件3.3 发送带