This post walks through GAPNet for handwritten digit recognition on MNIST: a LeNet-5-style convolutional network whose classifier head is replaced by global average pooling. It is meant as a practical reference; interested developers can follow along with the code below.
Shared here just for friends' reference.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms  # datasets provides common datasets; transforms preprocesses the images

# training settings
batch_size = 60

# MNIST Dataset. The key tool here is torch.utils.data.DataLoader, which reads the data
# efficiently and acts as a generator that yields mini-batches.
# Load the MNIST dataset through the torchvision.datasets package.
# root is where the data is stored, train=True selects the training split, and download decides
# whether to fetch the data; if the dataset is already on disk, set download=False.
train_dataset = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

# Data Loader (input pipeline): torch.utils.data.DataLoader is an iterator that draws samples
# to form small batches; shuffle decides whether the data is shuffled.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)
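
As a quick sanity check, you can pull one batch from the loader to confirm the shapes and value range produced by transforms.ToTensor(). This is only a minimal sketch; the variable names images and labels are illustrative and not part of the original script.

images, labels = next(iter(train_loader))        # one mini-batch from the generator
print(images.shape)                              # torch.Size([60, 1, 28, 28]) -> (batch, channels, height, width)
print(labels.shape)                              # torch.Size([60])
print(images.min().item(), images.max().item())  # ToTensor() scales pixel values into [0.0, 1.0]
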
# Visualize a few data images
# for i in range(5):
# plt.figure()
# plt.imshow(train_loader.dataset.train_data[i].numpy())

# x = torch.randn(2, 2, 2)
# first change the data into the desired dimension, then reshape the tensor as needed
# x.view(-1, 1, 4)

# To understand what the iterator really does, step through torch.utils.data.DataLoader:
# for (data, target) in train_loader:
# for i in range(4):
# plt.figure()
# print(target[1])
# plt.imshow(data[i].numpy()[0])
# break

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # see the PyTorch docs for the torch.nn.Conv2d parameter definitions
        self.conv1 = nn.Conv2d(1, 24, 5)
        self.conv2 = nn.Conv2d(24, 48, 5)
        self.conv3 = nn.Conv2d(in_channels=48, out_channels=32, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=10, kernel_size=5, padding=2)

    def forward(self, x):
        x = F.max_pool2d(torch.tanh(self.conv1(x)), (2, 2))
        x = F.dropout(x, p=0.25, training=self.training)
        x = F.max_pool2d(torch.tanh(self.conv2(x)), (2, 2))
        x = F.dropout(x, p=0.25, training=self.training)
        x = torch.tanh(self.conv3(x))
        x = F.dropout(x, p=0.25, training=self.training)
        # global average pooling over the 4x4 feature maps: one value per class channel
        x = F.avg_pool2d(torch.tanh(self.conv4(x)), (4, 4))
        x = x.view(-1, self.num_flat_features(x))
        return x

    # num_flat_features computes the flattened feature count used for the reshape above
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
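
To see the global-average-pooling head at work, a dummy forward pass should map a 1x28x28 input to a 10-dimensional vector of class scores (28 -> 24 -> 12 -> 8 -> 4 spatially, then a 4x4 average pool). This is just a sketch; the names _probe and _dummy are illustrative.

_probe = LeNet5()
_dummy = torch.randn(1, 1, 28, 28)   # one fake MNIST-sized image
print(_probe(_dummy).shape)          # expected: torch.Size([1, 10])
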
# def pca(self,x):
# print(x)
# pca = PCA(n_components=2) # reduce to 2 components
# pca.fit(x) # fit
# x = pca.fit_transform(x) # data after dimensionality reduction
# return x

model = LeNet5()
# state_dict = torch.load('1.pth')
# model.load_state_dict(state_dict=state_dict)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def train(epoch):
    model.train()  # put the network in training mode (dropout active)
    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()  # zero the gradients
        output = model(data)
        loss = criterion(output, target)
        train_loss += loss.item()
        loss.backward()  # back-propagation
        optimizer.step()
        # if batch_idx % 10 == 0:
        #     print('Train Epoch:{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #         epoch, batch_idx * len(data), len(train_loader.dataset),
        #         100. * batch_idx / len(train_loader), loss.item()))
    train_loss /= len(train_loader.dataset)
    print('Train Epoch:{} \tAverage Loss: {:.4f}'.format(epoch, train_loss))
    return train_loss

def evaluate(data_loader):
    model.eval()  # put the network in evaluation mode (dropout disabled)
    loss = 0
    correct = 0
    # run through the given data set without tracking gradients
    with torch.no_grad():
        for data, target in data_loader:
            output = model(data)
            # sum up batch loss
            loss += criterion(output, target).item()
            # get the index of the max log-probability as the predicted class
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    # loss /= len(data_loader.dataset)
    return correct
Loss = []
accuracy1=[]
accuracy2=[]
for epoch in range(30):
    loss = train(epoch + 1)
    Loss.append(loss)
    correct1 = evaluate(train_loader)
    accuracy1.append(100. * correct1 / len(train_loader.dataset))
    print('\nTrain set: Accuracy: {}/{}({:.1f}%)\n'.format(
        correct1, len(train_loader.dataset),
        100. * correct1 / len(train_loader.dataset)))
    correct2 = evaluate(test_loader)
    accuracy2.append(100. * correct2 / len(test_loader.dataset))
    print('\nTest set: Accuracy: {}/{}({:.1f}%)\n'.format(
        correct2, len(test_loader.dataset),
        100. * correct2 / len(test_loader.dataset)))

# plot the accuracy curves
plt.plot(accuracy1,label='Train Set')
plt.plot(accuracy2,label='Test Set')
plt.legend(loc=4,ncol=1)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('GAPNet Recognition Accuracy')
plt.show()
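
The Loss values returned by train() are collected but never drawn, and the commented-out torch.load('1.pth') near the top expects a saved checkpoint. A minimal sketch of both follow-ups; the file name '1.pth' is taken from that commented-out load, and the plot labels are illustrative.

# plot the average training loss collected in Loss
plt.figure()
plt.plot(Loss, label='Train Set')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.title('GAPNet Training Loss')
plt.legend()
plt.show()

# save the trained weights so the commented-out load_state_dict call can restore them later
torch.save(model.state_dict(), '1.pth')
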
That concludes this introduction to GAPNet handwritten digit recognition; hopefully it is helpful to other developers.