paddle实现手写数字模型(一)

2024-04-09 00:04

本文主要是介绍paddle实现手写数字模型(一),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  1. 参考文档:paddle官网文档
  2. 环境:Python 3.12.2 ,pip 24.0 ,paddlepaddle 2.6.0
    python -m pip install paddlepaddle==2.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 调试代码如下:
    LeNet.py
import paddle
import paddle.nn.functional as Fclass LeNet(paddle.nn.Layer):def __init__(self):super().__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=2)self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1,stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

train.py


import paddle
from paddle.vision.transforms import Compose,Normalize,ToTensor
import paddle.vision.transforms as T  import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracyfrom LeNet import LeNet
from PIL import Imageprint(paddle.__version__)
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
print('下载和加载训练数据...')
train_dataset = paddle.vision.datasets.MNIST(mode='train',transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test',transform=transform)
print('load finished')train_data0,train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0,cmap=plt.cm.binary)
#plt.show()
print('train_data0 label is: '+str(train_label_0))model = paddle.Model(LeNet())   # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
print('配置模型...')
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())
# 训练模型
print('训练模型...')
model.fit(train_dataset,epochs=2,batch_size=64,verbose=1)
# 保存模型  
model.save('./model/mnist_model')  # 默认保存模型结构和参数 #预测模型
print('预测模型...')
model.evaluate(test_dataset, batch_size=64, verbose=1)

predicted.py


import paddleimport numpy as npfrom LeNet import LeNet
from PIL import Image# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')#plt.imshow(im,cmap='gray')# print(np.array(im))im = im.resize((28, 28), Image.Resampling.LANCZOS)im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255 return im# 加载训练好的模型参数
model = LeNet()
model.load_dict(paddle.load('./model/mnist_model.pdparams'))# 设置模型为评估模式
model.eval()# 准备一个MNIST样例图像
example_image = load_image("d:/8.png")# 转换为Tensor并进行推理
with paddle.no_grad():example_tensor = paddle.to_tensor(example_image)prediction = model(example_tensor)print(prediction)# 获取预测类别
predicted_class = np.argmax(prediction.numpy(), axis=1)[0]
print(f"Predicted class: {predicted_class}")

说明:先通过执行train.py训练数据集,将模型保存在model文件夹中,
然后运行predicted.py加载训练出来的数据集,推理出d:/8.png图片的结果。
结果图片如下:
在这里插入图片描述

这篇关于paddle实现手写数字模型(一)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/886717

相关文章

SpringBoot3实现Gzip压缩优化的技术指南

《SpringBoot3实现Gzip压缩优化的技术指南》随着Web应用的用户量和数据量增加,网络带宽和页面加载速度逐渐成为瓶颈,为了减少数据传输量,提高用户体验,我们可以使用Gzip压缩HTTP响应,... 目录1、简述2、配置2.1 添加依赖2.2 配置 Gzip 压缩3、服务端应用4、前端应用4.1 N

SpringBoot实现数据库读写分离的3种方法小结

《SpringBoot实现数据库读写分离的3种方法小结》为了提高系统的读写性能和可用性,读写分离是一种经典的数据库架构模式,在SpringBoot应用中,有多种方式可以实现数据库读写分离,本文将介绍三... 目录一、数据库读写分离概述二、方案一:基于AbstractRoutingDataSource实现动态

Python FastAPI+Celery+RabbitMQ实现分布式图片水印处理系统

《PythonFastAPI+Celery+RabbitMQ实现分布式图片水印处理系统》这篇文章主要为大家详细介绍了PythonFastAPI如何结合Celery以及RabbitMQ实现简单的分布式... 实现思路FastAPI 服务器Celery 任务队列RabbitMQ 作为消息代理定时任务处理完整

Java枚举类实现Key-Value映射的多种实现方式

《Java枚举类实现Key-Value映射的多种实现方式》在Java开发中,枚举(Enum)是一种特殊的类,本文将详细介绍Java枚举类实现key-value映射的多种方式,有需要的小伙伴可以根据需要... 目录前言一、基础实现方式1.1 为枚举添加属性和构造方法二、http://www.cppcns.co

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

MySQL双主搭建+keepalived高可用的实现

《MySQL双主搭建+keepalived高可用的实现》本文主要介绍了MySQL双主搭建+keepalived高可用的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、测试环境准备二、主从搭建1.创建复制用户2.创建复制关系3.开启复制,确认复制是否成功4.同

Java实现文件图片的预览和下载功能

《Java实现文件图片的预览和下载功能》这篇文章主要为大家详细介绍了如何使用Java实现文件图片的预览和下载功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... Java实现文件(图片)的预览和下载 @ApiOperation("访问文件") @GetMapping("

使用Sentinel自定义返回和实现区分来源方式

《使用Sentinel自定义返回和实现区分来源方式》:本文主要介绍使用Sentinel自定义返回和实现区分来源方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Sentinel自定义返回和实现区分来源1. 自定义错误返回2. 实现区分来源总结Sentinel自定

Java实现时间与字符串互相转换详解

《Java实现时间与字符串互相转换详解》这篇文章主要为大家详细介绍了Java中实现时间与字符串互相转换的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、日期格式化为字符串(一)使用预定义格式(二)自定义格式二、字符串解析为日期(一)解析ISO格式字符串(二)解析自定义

opencv图像处理之指纹验证的实现

《opencv图像处理之指纹验证的实现》本文主要介绍了opencv图像处理之指纹验证的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学... 目录一、简介二、具体案例实现1. 图像显示函数2. 指纹验证函数3. 主函数4、运行结果三、总结一、