MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)

2024-05-04 06:12

本文主要是介绍MLP手写数字识别(2)-模型构建、训练与识别(tensorflow),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

1.MNIST的数据集下载与预处理

import tensorflow as tf
from keras.datasets import mnist
from keras.utils import to_categorical(train_x,train_y),(test_x,test_y) = mnist.load_data()
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) # 归一化
y_train,y_test = to_categorical(train_y),to_categorical(test_y) # onehot
print(X_train[:5])
print(y_train[:5])

2.搭建MLP模型

from keras import Sequential
from keras.layers import Flatten,Dense
from keras import Inputmodel = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

3.模型训练

3.1 调用model.compile()函数对训练模型进行设置

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
  • loss=‘categorical_crossentropy’: 损失函数设置为交叉熵损失函数,在深度学习中用交叉熵模式训练效果会比较好。
  • optimizer=‘adam’: 优化器设置为adam, 在深度学习中可以让训练更快收敛,并提高准确率。
  • metrics=[‘accuracy’]:评估模式设置为准确度评估模式。

loss参数常用的损失函数

  • binary_crossentropy: 亦称作对数损失,logloss
  • categorical_crossentropy: 交叉熵损失函数,亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为onehot形式
  • sparse_categorical_crossentropy:稀疏交叉熵损失函数。
  • kullback_leibler_divergence: 从预测值概率分布Q到真值概率分布P的信息增益,用以度量两个分布的差异
  • poisson: 即(pred-target*log(pred))的均值
  • cosine_proximity:预测值与真实标签的余弦距离平均值的相反数

优化器

  • SGD
  • RMSprop
  • Adagrad
  • Adadelta
  • Adam
  • Adamax
  • Nadam
  • TFOptimizer

评估模式

  • binary_accuracy: 对二分类问题,计算在所有预测值上的平均正确率
  • categorical_accuracy: 对多分类问题,计算在所有预测值上的平均正确率
  • sparse_categorical_accuracy:与categorical_accuracy相同,在对稀疏的目标值预测时有用
  • top_k_categorical_accuracy: 计算top-k正确率,当预测值的前K个值中存在目标类别即认为预测正确
  • sparse_top_k_categorical_accuracy: 与top_k_categorical_accuracy作用相同,但适用于稀疏情况

3.2 调用model.fit()配置训练参数,开始训练,并保存训练结果。

H = model.fit(x=X_train,y=y_train,validation_split=0.2,epochs=20,batch_size=128,verbose=1)

在这里插入图片描述

4.显示模型准确率和误差

import matplotlib.pyplot as pltdef show_train(history,train,validation):plt.plot(history.epoch, history.history[train],label=train)plt.plot(history.epoch, history.history[validation],label=validation)plt.title(train)plt.legend()plt.show()show_train(H,'loss','val_loss')
show_train(H,'accuracy','val_accuracy')

在这里插入图片描述

5.使用测试数据进行识别

import numpy as np
import matplotlib.pyplot as pltdef pred_plot_images_lables(images,labels,start_idx,num=5):# 预测res = model.predict(images[start_idx:start_idx+num])res = np.argmax(res,axis=1)# 画图fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()pred_plot_images_lables(X_test,test_y,0,5)

在这里插入图片描述

这篇关于MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/958490

相关文章

Python中构建终端应用界面利器Blessed模块的使用

《Python中构建终端应用界面利器Blessed模块的使用》Blessed库作为一个轻量级且功能强大的解决方案,开始在开发者中赢得口碑,今天,我们就一起来探索一下它是如何让终端UI开发变得轻松而高... 目录一、安装与配置:简单、快速、无障碍二、基本功能:从彩色文本到动态交互1. 显示基本内容2. 创建链

Golang使用etcd构建分布式锁的示例分享

《Golang使用etcd构建分布式锁的示例分享》在本教程中,我们将学习如何使用Go和etcd构建分布式锁系统,分布式锁系统对于管理对分布式系统中共享资源的并发访问至关重要,它有助于维护一致性,防止竞... 目录引言环境准备新建Go项目实现加锁和解锁功能测试分布式锁重构实现失败重试总结引言我们将使用Go作

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

嵌入式QT开发:构建高效智能的嵌入式系统

摘要: 本文深入探讨了嵌入式 QT 相关的各个方面。从 QT 框架的基础架构和核心概念出发,详细阐述了其在嵌入式环境中的优势与特点。文中分析了嵌入式 QT 的开发环境搭建过程,包括交叉编译工具链的配置等关键步骤。进一步探讨了嵌入式 QT 的界面设计与开发,涵盖了从基本控件的使用到复杂界面布局的构建。同时也深入研究了信号与槽机制在嵌入式系统中的应用,以及嵌入式 QT 与硬件设备的交互,包括输入输出设

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推

usaco 1.2 Name That Number(数字字母转化)

巧妙的利用code[b[0]-'A'] 将字符ABC...Z转换为数字 需要注意的是重新开一个数组 c [ ] 存储字符串 应人为的在末尾附上 ‘ \ 0 ’ 详见代码: /*ID: who jayLANG: C++TASK: namenum*/#include<stdio.h>#include<string.h>int main(){FILE *fin = fopen (

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了