MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)

2024-05-04 06:12

本文主要是介绍MLP手写数字识别(2)-模型构建、训练与识别(tensorflow),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

1.MNIST的数据集下载与预处理

import tensorflow as tf
from keras.datasets import mnist
from keras.utils import to_categorical(train_x,train_y),(test_x,test_y) = mnist.load_data()
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) # 归一化
y_train,y_test = to_categorical(train_y),to_categorical(test_y) # onehot
print(X_train[:5])
print(y_train[:5])

2.搭建MLP模型

from keras import Sequential
from keras.layers import Flatten,Dense
from keras import Inputmodel = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

3.模型训练

3.1 调用model.compile()函数对训练模型进行设置

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
  • loss=‘categorical_crossentropy’: 损失函数设置为交叉熵损失函数,在深度学习中用交叉熵模式训练效果会比较好。
  • optimizer=‘adam’: 优化器设置为adam, 在深度学习中可以让训练更快收敛,并提高准确率。
  • metrics=[‘accuracy’]:评估模式设置为准确度评估模式。

loss参数常用的损失函数

  • binary_crossentropy: 亦称作对数损失,logloss
  • categorical_crossentropy: 交叉熵损失函数,亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为onehot形式
  • sparse_categorical_crossentropy:稀疏交叉熵损失函数。
  • kullback_leibler_divergence: 从预测值概率分布Q到真值概率分布P的信息增益,用以度量两个分布的差异
  • poisson: 即(pred-target*log(pred))的均值
  • cosine_proximity:预测值与真实标签的余弦距离平均值的相反数

优化器

  • SGD
  • RMSprop
  • Adagrad
  • Adadelta
  • Adam
  • Adamax
  • Nadam
  • TFOptimizer

评估模式

  • binary_accuracy: 对二分类问题,计算在所有预测值上的平均正确率
  • categorical_accuracy: 对多分类问题,计算在所有预测值上的平均正确率
  • sparse_categorical_accuracy:与categorical_accuracy相同,在对稀疏的目标值预测时有用
  • top_k_categorical_accuracy: 计算top-k正确率,当预测值的前K个值中存在目标类别即认为预测正确
  • sparse_top_k_categorical_accuracy: 与top_k_categorical_accuracy作用相同,但适用于稀疏情况

3.2 调用model.fit()配置训练参数,开始训练,并保存训练结果。

H = model.fit(x=X_train,y=y_train,validation_split=0.2,epochs=20,batch_size=128,verbose=1)

在这里插入图片描述

4.显示模型准确率和误差

import matplotlib.pyplot as pltdef show_train(history,train,validation):plt.plot(history.epoch, history.history[train],label=train)plt.plot(history.epoch, history.history[validation],label=validation)plt.title(train)plt.legend()plt.show()show_train(H,'loss','val_loss')
show_train(H,'accuracy','val_accuracy')

在这里插入图片描述

5.使用测试数据进行识别

import numpy as np
import matplotlib.pyplot as pltdef pred_plot_images_lables(images,labels,start_idx,num=5):# 预测res = model.predict(images[start_idx:start_idx+num])res = np.argmax(res,axis=1)# 画图fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()pred_plot_images_lables(X_test,test_y,0,5)

在这里插入图片描述

这篇关于MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/958490

相关文章

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

Java数字转换工具类NumberUtil的使用

《Java数字转换工具类NumberUtil的使用》NumberUtil是一个功能强大的Java工具类,用于处理数字的各种操作,包括数值运算、格式化、随机数生成和数值判断,下面就来介绍一下Number... 目录一、NumberUtil类概述二、主要功能介绍1. 数值运算2. 格式化3. 数值判断4. 随机

如何在本地部署 DeepSeek Janus Pro 文生图大模型

《如何在本地部署DeepSeekJanusPro文生图大模型》DeepSeekJanusPro模型在本地成功部署,支持图片理解和文生图功能,通过Gradio界面进行交互,展示了其强大的多模态处... 目录什么是 Janus Pro1. 安装 conda2. 创建 python 虚拟环境3. 克隆 janus

本地私有化部署DeepSeek模型的详细教程

《本地私有化部署DeepSeek模型的详细教程》DeepSeek模型是一种强大的语言模型,本地私有化部署可以让用户在自己的环境中安全、高效地使用该模型,避免数据传输到外部带来的安全风险,同时也能根据自... 目录一、引言二、环境准备(一)硬件要求(二)软件要求(三)创建虚拟环境三、安装依赖库四、获取 Dee

nginx-rtmp-module构建流媒体直播服务器实战指南

《nginx-rtmp-module构建流媒体直播服务器实战指南》本文主要介绍了nginx-rtmp-module构建流媒体直播服务器实战指南,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有... 目录1. RTMP协议介绍与应用RTMP协议的原理RTMP协议的应用RTMP与现代流媒体技术的关系2

DeepSeek模型本地部署的详细教程

《DeepSeek模型本地部署的详细教程》DeepSeek作为一款开源且性能强大的大语言模型,提供了灵活的本地部署方案,让用户能够在本地环境中高效运行模型,同时保护数据隐私,在本地成功部署DeepSe... 目录一、环境准备(一)硬件需求(二)软件依赖二、安装Ollama三、下载并部署DeepSeek模型选

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe