GAPNet手写体数字识别

2024-04-20 03:38
文章标签 识别 数字 手写体 gapnet

本文主要是介绍GAPNet手写体数字识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

仅作与好友分享 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision import datasets, transforms    # datasets包含常用的数据集,transform 对图像进行预处理# training settings
batch_size = 60# MNIST Dataset,注意这里的关键工具,torch.utils, data.Dataloader,这个可以有效的读取数据,是一个得到batch的生成器
# 引入MNIST数据集通过datasets函数包进行导入
# root是数据的位置,train=True是下载训练有关的集合,download是决定下不下载数据,一斤固有数据集就download=Falsetrain_dataset = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())# Data Loader(Input Pipeline)是一个迭代器,torch.utils.data.DataLoader作用就是随机的在样本中选取数据组成一个小的batch。shuffle决定数据是否打乱
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
# 可视化数据图像
# for i in range(5):
#     plt.figure()
#     plt.imshow(train_loader.dataset.train_data[i].numpy())# # x = torch.randn(2, 2, 2)
# firstly change the data into diresed dimension, then reshape the tensor according to what I want
# x.view(-1, 1, 4)# 理解迭代器的深层含义,torch.utils.data.DataLoader的作用理解
# for (data, target) in train_loader:
#     for i in range(4):
#         plt.figure()
#         print(target[1])
#         plt.imshow(data[i].numpy()[0])
#     breakclass LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(1, 24, 5) #pytorch文档,torch.nn.Conv2d函数参数定义self.conv2 = nn.Conv2d(24, 48, 5)self.conv3 = nn.Conv2d(in_channels=48,out_channels=32,kernel_size=5,padding=2)self.conv4 = nn.Conv2d(in_channels=32,out_channels=10,kernel_size=5,padding=2)def forward(self, x):x = F.max_pool2d(F.tanh(self.conv1(x)), (2, 2))x = F.dropout(x, p = 0.25, training=self.training)x = F.max_pool2d(F.tanh(self.conv2(x)), (2, 2))x = F.dropout(x, p = 0.25, training=self.training)x = F.tanh(self.conv3(x))x = F.dropout(x, p = 0.25, training=self.training)x = F.avg_pool2d(F.tanh(self.conv4(x)), (4, 4))x = x.view(-1, self.num_flat_features(x))return x# 定义num_flat_features函数进行尺度的变换def num_flat_features(self, x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_features
#     def pca(self,x):
#         print(x)
#         pca = PCA(n_components=2)   #降到1维
#         pca.fit(x)                  #训练
#         x=pca.fit_transform(x)   #降维后的数据
#         return xmodel = LeNet5()
# state_dict = torch.load('1.pth')
# model.load_state_dict(state_dict=state_dict)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()def train(epoch):model.train()   # 第一行固定,model.train是用来实现训练期间用的网络train_loss = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data), Variable(target)optimizer.zero_grad()   # tidings清零output = model(data)loss = criterion(output, target)train_loss +=lossloss.backward() # 反向传播optimizer.step()
#         if batch_idx % 10 == 0:
#             print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#             100*batch_idx / len(train_loader), loss.item()))train_loss /= len(train_loader.dataset)print('Train Epoch:{} \tAverage Loss: {:.4f}'.format(epoch,train_loss.item()))return train_loss.item()def evaluate(data_loader):model.eval()    # 测试期间用的网络loss = 0correct = 0# test数据集进行测试for data, target in data_loader:data, target = Variable(data, volatile=True), Variable(target)output = model(data)# sum up batch lossloss += criterion(output, target).item()# get the index of the max log-probabilitypred = output.data.max(1, keepdim=True)[1]  # 预测输出的结果correct += pred.eq(target.data.view_as(pred)).cpu().sum()
#     loss /= len(data_loader.dataset)return correct
Loss = []
accuracy1=[]
accuracy2=[]
for epoch in range(30):loss = train(epoch+1)Loss.append(loss)correct1=evaluate(train_loader)accuracy1.append(100. * correct1 / len(train_loader.dataset))print('\nTrain set: Accuracy: {}/{}({:.1f}%)\n'.format(correct1, len(train_loader.dataset),100. * correct1 / len(train_loader.dataset)))correct2=evaluate(test_loader)accuracy2.append(100. * correct2 / len(test_loader.dataset))print('\nTest set: Accuracy: {}/{}({:.1f}%)\n'.format(correct2, len(test_loader.dataset),100. * correct2 / len(test_loader.dataset)))#画损失函数图
plt.plot(accuracy1,label='Train Set')
plt.plot(accuracy2,label='Test Set')
plt.legend(loc=4,ncol=1)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('GAPNet Recognition Accuracy')
plt.show()  

 

这篇关于GAPNet手写体数字识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/919207

相关文章

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

java字符串数字补齐位数详解

《java字符串数字补齐位数详解》:本文主要介绍java字符串数字补齐位数,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Java字符串数字补齐位数一、使用String.format()方法二、Apache Commons Lang库方法三、Java 11+的St

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

讯飞webapi语音识别接口调用示例代码(python)

《讯飞webapi语音识别接口调用示例代码(python)》:本文主要介绍如何使用Python3调用讯飞WebAPI语音识别接口,重点解决了在处理语音识别结果时判断是否为最后一帧的问题,通过运行代... 目录前言一、环境二、引入库三、代码实例四、运行结果五、总结前言基于python3 讯飞webAPI语音

使用Python开发一个图像标注与OCR识别工具

《使用Python开发一个图像标注与OCR识别工具》:本文主要介绍一个使用Python开发的工具,允许用户在图像上进行矩形标注,使用OCR对标注区域进行文本识别,并将结果保存为Excel文件,感兴... 目录项目简介1. 图像加载与显示2. 矩形标注3. OCR识别4. 标注的保存与加载5. 裁剪与重置图像

Python爬虫selenium验证之中文识别点选+图片验证码案例(最新推荐)

《Python爬虫selenium验证之中文识别点选+图片验证码案例(最新推荐)》本文介绍了如何使用Python和Selenium结合ddddocr库实现图片验证码的识别和点击功能,感兴趣的朋友一起看... 目录1.获取图片2.目标识别3.背景坐标识别3.1 ddddocr3.2 打码平台4.坐标点击5.图

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为

Java数字转换工具类NumberUtil的使用

《Java数字转换工具类NumberUtil的使用》NumberUtil是一个功能强大的Java工具类,用于处理数字的各种操作,包括数值运算、格式化、随机数生成和数值判断,下面就来介绍一下Number... 目录一、NumberUtil类概述二、主要功能介绍1. 数值运算2. 格式化3. 数值判断4. 随机

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推