GAPNet手写体数字识别

2024-04-20 03:38
文章标签 识别 数字 手写体 gapnet

本文主要是介绍GAPNet手写体数字识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

仅作与好友分享 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision import datasets, transforms    # datasets包含常用的数据集,transform 对图像进行预处理# training settings
batch_size = 60# MNIST Dataset,注意这里的关键工具,torch.utils, data.Dataloader,这个可以有效的读取数据,是一个得到batch的生成器
# 引入MNIST数据集通过datasets函数包进行导入
# root是数据的位置,train=True是下载训练有关的集合,download是决定下不下载数据,一斤固有数据集就download=Falsetrain_dataset = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())# Data Loader(Input Pipeline)是一个迭代器,torch.utils.data.DataLoader作用就是随机的在样本中选取数据组成一个小的batch。shuffle决定数据是否打乱
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
# 可视化数据图像
# for i in range(5):
#     plt.figure()
#     plt.imshow(train_loader.dataset.train_data[i].numpy())# # x = torch.randn(2, 2, 2)
# firstly change the data into diresed dimension, then reshape the tensor according to what I want
# x.view(-1, 1, 4)# 理解迭代器的深层含义,torch.utils.data.DataLoader的作用理解
# for (data, target) in train_loader:
#     for i in range(4):
#         plt.figure()
#         print(target[1])
#         plt.imshow(data[i].numpy()[0])
#     breakclass LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(1, 24, 5) #pytorch文档,torch.nn.Conv2d函数参数定义self.conv2 = nn.Conv2d(24, 48, 5)self.conv3 = nn.Conv2d(in_channels=48,out_channels=32,kernel_size=5,padding=2)self.conv4 = nn.Conv2d(in_channels=32,out_channels=10,kernel_size=5,padding=2)def forward(self, x):x = F.max_pool2d(F.tanh(self.conv1(x)), (2, 2))x = F.dropout(x, p = 0.25, training=self.training)x = F.max_pool2d(F.tanh(self.conv2(x)), (2, 2))x = F.dropout(x, p = 0.25, training=self.training)x = F.tanh(self.conv3(x))x = F.dropout(x, p = 0.25, training=self.training)x = F.avg_pool2d(F.tanh(self.conv4(x)), (4, 4))x = x.view(-1, self.num_flat_features(x))return x# 定义num_flat_features函数进行尺度的变换def num_flat_features(self, x):size = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_features
#     def pca(self,x):
#         print(x)
#         pca = PCA(n_components=2)   #降到1维
#         pca.fit(x)                  #训练
#         x=pca.fit_transform(x)   #降维后的数据
#         return xmodel = LeNet5()
# state_dict = torch.load('1.pth')
# model.load_state_dict(state_dict=state_dict)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()def train(epoch):model.train()   # 第一行固定,model.train是用来实现训练期间用的网络train_loss = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data), Variable(target)optimizer.zero_grad()   # tidings清零output = model(data)loss = criterion(output, target)train_loss +=lossloss.backward() # 反向传播optimizer.step()
#         if batch_idx % 10 == 0:
#             print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                 epoch, batch_idx * len(data), len(train_loader.dataset),
#             100*batch_idx / len(train_loader), loss.item()))train_loss /= len(train_loader.dataset)print('Train Epoch:{} \tAverage Loss: {:.4f}'.format(epoch,train_loss.item()))return train_loss.item()def evaluate(data_loader):model.eval()    # 测试期间用的网络loss = 0correct = 0# test数据集进行测试for data, target in data_loader:data, target = Variable(data, volatile=True), Variable(target)output = model(data)# sum up batch lossloss += criterion(output, target).item()# get the index of the max log-probabilitypred = output.data.max(1, keepdim=True)[1]  # 预测输出的结果correct += pred.eq(target.data.view_as(pred)).cpu().sum()
#     loss /= len(data_loader.dataset)return correct
Loss = []
accuracy1=[]
accuracy2=[]
for epoch in range(30):loss = train(epoch+1)Loss.append(loss)correct1=evaluate(train_loader)accuracy1.append(100. * correct1 / len(train_loader.dataset))print('\nTrain set: Accuracy: {}/{}({:.1f}%)\n'.format(correct1, len(train_loader.dataset),100. * correct1 / len(train_loader.dataset)))correct2=evaluate(test_loader)accuracy2.append(100. * correct2 / len(test_loader.dataset))print('\nTest set: Accuracy: {}/{}({:.1f}%)\n'.format(correct2, len(test_loader.dataset),100. * correct2 / len(test_loader.dataset)))#画损失函数图
plt.plot(accuracy1,label='Train Set')
plt.plot(accuracy2,label='Test Set')
plt.legend(loc=4,ncol=1)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('GAPNet Recognition Accuracy')
plt.show()  

 

这篇关于GAPNet手写体数字识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/919207

相关文章

从去中心化到智能化:Web3如何与AI共同塑造数字生态

在数字时代的演进中,Web3和人工智能(AI)正成为塑造未来互联网的两大核心力量。Web3的去中心化理念与AI的智能化技术,正相互交织,共同推动数字生态的变革。本文将探讨Web3与AI的融合如何改变数字世界,并展望这一新兴组合如何重塑我们的在线体验。 Web3的去中心化愿景 Web3代表了互联网的第三代发展,它基于去中心化的区块链技术,旨在创建一个开放、透明且用户主导的数字生态。不同于传统

阿里开源语音识别SenseVoiceWindows环境部署

SenseVoice介绍 SenseVoice 专注于高精度多语言语音识别、情感辨识和音频事件检测多语言识别: 采用超过 40 万小时数据训练,支持超过 50 种语言,识别效果上优于 Whisper 模型。富文本识别:具备优秀的情感识别,能够在测试数据上达到和超过目前最佳情感识别模型的效果。支持声音事件检测能力,支持音乐、掌声、笑声、哭声、咳嗽、喷嚏等多种常见人机交互事件进行检测。高效推

usaco 1.2 Name That Number(数字字母转化)

巧妙的利用code[b[0]-'A'] 将字符ABC...Z转换为数字 需要注意的是重新开一个数组 c [ ] 存储字符串 应人为的在末尾附上 ‘ \ 0 ’ 详见代码: /*ID: who jayLANG: C++TASK: namenum*/#include<stdio.h>#include<string.h>int main(){FILE *fin = fopen (

AIGC6: 走进腾讯数字盛会

图中是一个程序员,去参加一个技术盛会。AI大潮下,五颜六色,各种不确定。 背景 AI对各行各业的冲击越来越大,身处职场的我也能清晰的感受到。 我所在的行业为全球客服外包行业。 业务模式为: 为国际跨境公司提供不同地区不同语言的客服外包解决方案,除了人力,还有软件系统。 软件系统主要是提供了客服跟客人的渠道沟通和工单管理,内部管理跟甲方的合同对接,绩效评估,BI数据透视。 客服跟客人

Clion不识别C代码或者无法跳转C语言项目怎么办?

如果是中文会显示: 此时只需要右击项目,或者你的源代码目录,将这个项目或者源码目录标记为项目源和头文件即可。 英文如下:

NC 把数字翻译成字符串

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 描述 有一种将字母编码成数字的方式:‘a’->1, ‘b->2’, … , ‘z->26’。 现在给一串数字,返回有多少种可能的译码结果 import java.u

34465A-61/2 数字万用表(六位半)

34465A-61/2 数字万用表(六位半) 文章目录 34465A-61/2 数字万用表(六位半)前言一、测DC/AC电压二、测DC/AC电流四、测电阻五、测电容六、测二极管七、保存截图流程 前言 1、6位半数字万用表通常具有200,000个计数器,可以显示最大为199999的数值。相比普通数字万用表,6位半万用表具有更高的测量分辨率和更高的测量准确度,适用于精度比较高的测

超级 密码加密 解密 源码,支持表情,符号,数字,字母,加密

超级 密码加密 解密 源码,支持表情,符号,数字,字母,加密 可以将表情,动物,水果,表情,手势,猫语,兽语,狗语,爱语,符号,数字,字母,加密和解密 可以将文字、字母、数字、代码、标点符号等内容转换成新的文字形式,通过简单的文字以不同的排列顺序来表达不同的内容 源码截图: https://www.httple.net/152649.html

两个长数字相加

1.编程题目 题目:要实现两个百位长的数字直接相加 分析:因为数字太长所以无法直接相加,所以采用按位相加,然后组装的方式。(注意进位) 2.编程实现 package com.sino.daily.code_2019_6_29;import org.apache.commons.lang3.StringUtils;/*** create by 2019-06-29 19:03** @autho

关于字符串转化为数字的深度优化两种算法

最近在做项目,在实际操作中发现自己在VC环境下写的字符串转化为整型的函数还是太过理想化了,或者说只能在window平台下软件环境中运行,重新给大家发两种函数方法: 第一个,就是理想化的函数,在VC环境下充分利用指针的优越性,对字符串转化为整型(同时也回答了某位网友的答案吖),实验检验通过: #include <stdio.h> #include <string.h> int rayatoi(c