MNIST手写字符分类-卷积

2024-06-13 13:52

本文主要是介绍MNIST手写字符分类-卷积,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

MNIST手写字符分类-卷积

文章目录

  • MNIST手写字符分类-卷积
    • 1 模型构造
    • 2 训练
    • 3 推理
    • 4 导出
    • 5 onnx测试
    • 6 opencv部署
    • 7 总结

  在上一篇中,我们介绍了如何在pytorch中使用线性层+ReLU非线性层堆叠的网络进行手写字符识别的网络构建、训练、模型保存、导出和推理测试。本篇文章中,我们将要使用卷积层进行网络构建,并完成后续的训练、保存、导出,并使用opencv在C++中推理我们的模型,将结果可视化。

1 模型构造

  在pytorch中,卷积层的使用比较方便,需要注意的是卷积层的输入通道数、输出通道数、卷积核的大小等参数。这里直接放出构建的网络结构:

import torch
from torch import nn
from torch.utils.data import DataLoader
class ZKNNNet_Conv(nn.Module):def __init__(self):super(ZKNNNet_Conv, self).__init__()self.conv_stack = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3),nn.ReLU(),nn.Conv2d(32, 64, kernel_size=3),nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(12*12*64, 128),nn.ReLU(),nn.Linear(128, 10))def forward(self, x):logits = self.conv_stack(x)return logits

在这里插入图片描述

从图中可以看出,该模型先堆叠了两个卷积层与ReLU单元,经过最大池化之后,展开并进行后续的全连接层训练。

2 训练

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_Conv
import os
# Download training data from open datasets.
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = ZKNNNet_Conv()
if os.path.exists("./model/model_conv.pth"):model.load_state_dict(torch.load("./model/model_conv.pth"))
model = model.to(device)
print(model)# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)# Loss function
loss_fn = nn.CrossEntropyLoss()# Train
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")# Test
def test(dataloader, model):size = len(dataloader.dataset)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= sizecorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")return correctepochs = 200
maxAcc = 0
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)currentAcc = test(test_dataloader, model)if maxAcc < currentAcc:maxAcc = currentAcctorch.save(model.state_dict(), "./model/model_conv.pth")
print("Done!")

模型的训练代码与上一篇中的线性连接训练代码是一样的。
训练过程来看,使用卷积层,在相同数据集上训练,模型收敛速度比用线性层快很多。最终精度达到97.8%。

3 推理

模型训练完成之后,推理过程与上一篇一致,这里简单放一下推理代码。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from ZKNNNet import ZKNNNet_Convimport matplotlib.pyplot as plt# Get cpu or gpu device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device for inference".format(device))# Load the trained model
model = ZKNNNet_Conv()
model.load_state_dict(torch.load("./model/model_conv.pth"))
model.to(device)
model.eval()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=64)# Perform inference
with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i].cpu()label = labels[i].cpu()prediction = predicted[i].cpu()plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()accuracy = 100 * correct / totalprint("Accuracy on test set: {:.2f}%".format(accuracy))

4 导出

模型导出方式与上一篇一致。

import torch
import torch.utils
import os
from ZKNNNet import ZKNNNet_3Layer,ZKNNNet_5Layer,ZKNNNet_Conv
model_conv = ZKNNNet_Conv()
if os.path.exists('./model/model_conv.pth'):model_conv.load_state_dict(torch.load('./model/model_conv.pth'))
model_conv = model_conv.to(device)
model_conv.eval()
torch.onnx.export(model_conv,torch.randn(1,1,28,28),'./model/model_conv.onnx',verbose=True)

5 onnx测试

import onnxruntime as rt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasetsimport matplotlib.pyplot as pltfrom PIL import Imagesess = rt.InferenceSession("model/model_conv.onnx")
input_name = sess.get_inputs()[0].name
print(input_name)image = Image.open('./data/test/2.png')
image_data = np.array(image)
image_data = image_data.astype(np.float32)/255.0
image_data = image_data[None,None,:,:]
print(image_data.shape)outputs = sess.run(None,{input_name:image_data})
outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)
plt.imshow(image, cmap='gray')
plt.title(f"Predicted: {prediction}")
plt.show()# Download test data from open datasets.
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=1)with torch.no_grad():correct = 0total = 0for images, labels in test_dataloader:images = images.numpy()labels = labels.numpy()outputs = sess.run(None,{input_name:images})[0]outputs = np.array(outputs).flatten()prediction = np.argmax(outputs)# Visualize the image and its predicted resultfor i in range(len(images)):image = images[i]label = labels[i]plt.imshow(image.squeeze(), cmap='gray')plt.title(f"Label: {label}, Predicted: {prediction}")plt.show()

至此,模型已经成功的转换成onnx模型,可以用于后续各种部署环境的部署。

6 opencv部署

本例中,使用C++/opencv来尝试部署刚才训练的模型。输入为在之前的博文中提到的将MNIST测试集导出成png图片保存。

#include "opencv2/opencv.hpp"#include <iostream>
#include <filesystem>
#include <string>
#include <vector>int main(int argc, char** argv)
{if (argc != 3){std::cerr << "Usage: MNISTClassifier_onnx_opencv <onnx_model_path> <image_path>" << std::endl;return 1;}cv::dnn::Net net = cv::dnn::readNetFromONNX(argv[1]);if (net.empty()){std::cout << "Error: Failed to load ONNX file." << std::endl;return 1;}std::filesystem::path srcPath(argv[2]);for (auto& imgPath : std::filesystem::recursive_directory_iterator(srcPath)){if(!std::filesystem::is_regular_file(imgPath))continue;const cv::Mat image = cv::imread(imgPath.path().string(), cv::IMREAD_GRAYSCALE);if (image.empty()){std::cerr << "Error: Failed to read image file." << std::endl;continue;}const cv::Size size(28, 28);cv::Mat resized_image;cv::resize(image, resized_image, size);cv::Mat float_image;resized_image.convertTo(float_image, CV_32F, 1.0 / 255.0);cv::Mat input_blob = cv::dnn::blobFromImage(float_image);net.setInput(input_blob);cv::Mat output = net.forward();cv::Point classIdPoint;double confidence;cv::minMaxLoc(output.reshape(1, 1), nullptr, &confidence, nullptr, &classIdPoint);const int class_id = classIdPoint.x;std::cout << "Class ID: " << class_id << std::endl;std::cout << "Confidence: " << confidence << std::endl;cv::Mat bigImg;cv::resize(image,bigImg,cv::Size(128,128));auto parentPath = imgPath.path().parent_path();auto label = parentPath.filename().string()+std::string("<->")+std::to_string(class_id);cv::putText(bigImg, label, cv::Point(10, 20), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 255), 1);cv::imshow("img",bigImg);cv::waitKey();}return 0;
}

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

7 总结

使用卷积神经网络进行MNIST手写字符识别,在模型结构无明显复杂的情况下,模型收敛速度较全连接层构建的网络收敛速度快。

按照相同的套路导出成onnx模型之后,直接通过opencv可以部署,简化深度学习算法部署的难度。

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

这篇关于MNIST手写字符分类-卷积的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1057493

相关文章

Java 字符数组转字符串的常用方法

《Java字符数组转字符串的常用方法》文章总结了在Java中将字符数组转换为字符串的几种常用方法,包括使用String构造函数、String.valueOf()方法、StringBuilder以及A... 目录1. 使用String构造函数1.1 基本转换方法1.2 注意事项2. 使用String.valu

Go语言使用Buffer实现高性能处理字节和字符

《Go语言使用Buffer实现高性能处理字节和字符》在Go中,bytes.Buffer是一个非常高效的类型,用于处理字节数据的读写操作,本文将详细介绍一下如何使用Buffer实现高性能处理字节和... 目录1. bytes.Buffer 的基本用法1.1. 创建和初始化 Buffer1.2. 使用 Writ

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

string字符会调用new分配堆内存吗

gcc的string默认大小是32个字节,字符串小于等于15直接保存在栈上,超过之后才会使用new分配。

如何将一个文件里不包含某个字符的行输出到另一个文件?

第一种: grep -v 'string' filename > newfilenamegrep -v 'string' filename >> newfilename 第二种: sed -n '/string/!'p filename > newfilenamesed -n '/string/!'p filename >> newfilename

用Pytho解决分类问题_DBSCAN聚类算法模板

一:DBSCAN聚类算法的介绍 DBSCAN(Density-Based Spatial Clustering of Applications with Noise)是一种基于密度的聚类算法,DBSCAN算法的核心思想是将具有足够高密度的区域划分为簇,并能够在具有噪声的空间数据库中发现任意形状的簇。 DBSCAN算法的主要特点包括: 1. 基于密度的聚类:DBSCAN算法通过识别被低密

PMP–一、二、三模–分类–14.敏捷–技巧–看板面板与燃尽图燃起图

文章目录 技巧一模14.敏捷--方法--看板(类似卡片)1、 [单选] 根据项目的特点,项目经理建议选择一种敏捷方法,该方法限制团队成员在任何给定时间执行的任务数。此方法还允许团队提高工作过程中问题和瓶颈的可见性。项目经理建议采用以下哪种方法? 易错14.敏捷--精益、敏捷、看板(类似卡片)--敏捷、精益和看板方法共同的重点在于交付价值、尊重人、减少浪费、透明化、适应变更以及持续改善等方面。

【python计算机视觉编程——8.图像内容分类】

python计算机视觉编程——8.图像内容分类 8.图像内容分类8.1 K邻近分类法(KNN)8.1.1 一个简单的二维示例8.1.2 用稠密SIFT作为图像特征8.1.3 图像分类:手势识别 8.2贝叶斯分类器用PCA降维 8.3 支持向量机8.3.2 再论手势识别 8.4 光学字符识别8.4.2 选取特征8.4.3 多类支持向量机8.4.4 提取单元格并识别字符8.4.5 图像校正

【Python 千题 —— 算法篇】字符统计

Python 千题持续更新中 …… 脑图地址 👉:⭐https://twilight-fanyi.gitee.io/mind-map/Python千题.html⭐ 题目背景 在编程中,对字符串的字符统计是一个常见任务。这在文本处理、数据分析、词频统计、自然语言处理等领域有广泛应用。无论是统计字母出现的频率,还是分析不同字符类型的数量,字符串字符统计都是非常有用的技术。 字符统