基于tensorflow2、CNN的手写数字识别项目

2024-03-16 17:10

本文主要是介绍基于tensorflow2、CNN的手写数字识别项目,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

手写数字识别实战——基于tensorflow2、CNN

项目说明

该手写数字识别实战是基于tensorflow2的深度学习项目,使用tensorflow自带的MNIST手写数据集作为数据集,使用了CNN网络,最后使用模型预测手写图片。

项目环境

基础环境:python+anaconda
框架:tensorflow2

实现步骤

一、数据处理


import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt"""
数据处理
"""
# 加载MNIST
mnist = tf.keras.datasets.mnist
# 加载MNIST数据集为训练集和测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 归一化操作
x_train, x_test = x_train / 255., x_test / 255.
# 增加维度
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)
# 转换为one-hot编码
y_train = np.float32(tf.keras.utils.to_categorical(y_train, num_classes=10))
y_test = np.float32(tf.keras.utils.to_categorical(y_test, num_classes=10))
# 设置批量大小
batch_size = 256
# 载入数据为dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size).shuffle(batch_size * 10)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

二、搭建网络

"""
搭建网络
"""
# 输入
input_img=tf.keras.Input([28,28,1])
# 第一层卷积
conv1=tf.keras.layers.Conv2D(filters=32,kernel_size=3,padding='SAME',activation=tf.nn.relu)(input_img)
# 第二层卷积
conv2=tf.keras.layers.Conv2D(filters=64,kernel_size=3,padding='SAME',activation=tf.nn.relu)(conv1)
# 最大池化
pool=tf.keras.layers.MaxPool2D(pool_size=2,strides=2)(conv2)
# 第三层卷积
conv3=tf.keras.layers.Conv2D(filters=128,kernel_size=3,padding='SAME',activation=tf.nn.relu)(pool)
# flatten拉平
flat=tf.keras.layers.Flatten()(conv3)
# 全连接层
dense1=tf.keras.layers.Dense(units=512,activation=tf.nn.relu)(flat)
# 全连接层
dense2=tf.keras.layers.Dense(units=10,activation=tf.nn.softmax)(dense1)
# 指定模型的输入和输出
model=tf.keras.Model(inputs=input_img,outputs=dense2)
model.summary()	#查看网络结构

三、模型训练及评估

"""
模型训练及评估
"""
# 配置训练方法
model.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-3),loss='categorical_crossentropy',metrics=['accuracy'])
# 执行训练过程
model.fit(train_dataset,epochs=10)
# 模型评估
score=model.evaluate(test_dataset)
print('last score:',score)
# 保存模型
model.save('model.h5')

四、预测单张手写数字

import tensorflow as tf
import numpy as np
import cv2def img_show(img):          # 展示图片cv2.imshow('img',img)cv2.waitKey(0)"""
单张数字图片预测
"""# 读取图片
img=cv2.imread('./detect_img/6.png')	# 传入待预测图片
# print(img.shape)
# img_show(img)
# 转灰度图
img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
# img_show(img)
# 改变尺寸
img=cv2.resize(img,(28,28))
# img_show(img)
# 转黑底白字、归一化
img=(255-img)/255
# img_show(img)
# 转为4维
img= img.reshape((1,28,28,1))
# print(img.shape)
# 加载模型
model = tf.keras.models.load_model('model.h5')
# 预测
probabilities = model.predict(img)
print(probabilities)
prediction = np.argmax(probabilities)
prediction_values =np.max(probabilities)
print('预测:  结果:{}  概率:{:.2%}'.format(prediction,prediction_values))

最终效果

待预测手写数字图片:
在这里插入图片描述

预测结果:
在这里插入图片描述

多张预测:
在这里插入图片描述

存在问题

有些时候预测不准,尤其是0、8、6;
有大佬希望可以帮忙看看!!万分感谢!

参考资料:
MNSIT:https://zhuanlan.zhihu.com/p/36592188
代码参考:https://blog.csdn.net/woshinierye/article/details/105141631

这篇关于基于tensorflow2、CNN的手写数字识别项目的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/816143

相关文章

这15个Vue指令,让你的项目开发爽到爆

1. V-Hotkey 仓库地址: github.com/Dafrok/v-ho… Demo: 戳这里 https://dafrok.github.io/v-hotkey 安装: npm install --save v-hotkey 这个指令可以给组件绑定一个或多个快捷键。你想要通过按下 Escape 键后隐藏某个组件,按住 Control 和回车键再显示它吗?小菜一碟: <template

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

如何用Docker运行Django项目

本章教程,介绍如何用Docker创建一个Django,并运行能够访问。 一、拉取镜像 这里我们使用python3.11版本的docker镜像 docker pull python:3.11 二、运行容器 这里我们将容器内部的8080端口,映射到宿主机的80端口上。 docker run -itd --name python311 -p

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推

usaco 1.2 Name That Number(数字字母转化)

巧妙的利用code[b[0]-'A'] 将字符ABC...Z转换为数字 需要注意的是重新开一个数组 c [ ] 存储字符串 应人为的在末尾附上 ‘ \ 0 ’ 详见代码: /*ID: who jayLANG: C++TASK: namenum*/#include<stdio.h>#include<string.h>int main(){FILE *fin = fopen (

在cscode中通过maven创建java项目

在cscode中创建java项目 可以通过博客完成maven的导入 建立maven项目 使用快捷键 Ctrl + Shift + P 建立一个 Maven 项目 1 Ctrl + Shift + P 打开输入框2 输入 "> java create"3 选择 maven4 选择 No Archetype5 输入 域名6 输入项目名称7 建立一个文件目录存放项目,文件名一般为项目名8 确定

Vue3项目开发——新闻发布管理系统(六)

文章目录 八、首页设计开发1、页面设计2、登录访问拦截实现3、用户基本信息显示①封装用户基本信息获取接口②用户基本信息存储③用户基本信息调用④用户基本信息动态渲染 4、退出功能实现①注册点击事件②添加退出功能③数据清理 5、代码下载 八、首页设计开发 登录成功后,系统就进入了首页。接下来,也就进行首页的开发了。 1、页面设计 系统页面主要分为三部分,左侧为系统的菜单栏,右侧

SpringBoot项目是如何启动

启动步骤 概念 运行main方法,初始化SpringApplication 从spring.factories读取listener ApplicationContentInitializer运行run方法读取环境变量,配置信息创建SpringApplication上下文预初始化上下文,将启动类作为配置类进行读取调用 refresh 加载 IOC容器,加载所有的自动配置类,创建容器在这个过程

Maven创建项目中的groupId, artifactId, 和 version的意思

文章目录 groupIdartifactIdversionname groupId 定义:groupId 是 Maven 项目坐标的第一个部分,它通常表示项目的组织或公司的域名反转写法。例如,如果你为公司 example.com 开发软件,groupId 可能是 com.example。作用:groupId 被用来组织和分组相关的 Maven artifacts,这样可以避免

2. 下载rknn-toolkit2项目

官网链接: https://github.com/airockchip/rknn-toolkit2 安装好git:[[1. Git的安装]] 下载项目: git clone https://github.com/airockchip/rknn-toolkit2.git 或者直接去github下载压缩文件,解压即可。