Pytorch实战2:ResNet-18实现Cifar-10图像分类(测试集分类准确率95.170%)

本文主要是介绍Pytorch实战2:ResNet-18实现Cifar-10图像分类(测试集分类准确率95.170%),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

版权说明:此文章为本人原创内容,转载请注明出处,谢谢合作!


Pytorch实战2:ResNet-18实现Cifar-10图像分类

实验环境:

  1. Pytorch 0.4.0
  2. torchvision 0.2.1
  3. Python 3.6
  4. CUDA8+cuDNN v7 (可选)
  5. Win10+Pycharm

整个项目代码:点击这里

ResNet-18网络结构:

这里写图片描述
ResNet全名Residual Network残差网络。Kaiming He 的《Deep Residual Learning for Image Recognition》获得了CVPR最佳论文。他提出的深度残差网络在2015年可以说是洗刷了图像方面的各大比赛,以绝对优势取得了多个比赛的冠军。而且它在保证网络精度的前提下,将网络的深度达到了152层,后来又进一步加到1000的深度。论文的开篇先是说明了深度网络的好处:特征等级随着网络的加深而变高,网络的表达能力也会大大提高。因此论文中提出了一个问题:是否可以通过叠加网络层数来获得一个更好的网络呢?作者经过实验发现,单纯的把网络叠起来的深层网络的效果反而不如合适层数的较浅的网络效果。因此何恺明等人在普通平原网络的基础上增加了一个shortcut, 构成一个residual block。此时拟合目标就变为F(x),F(x)就是残差:
这里写图片描述!

Pytorch上搭建ResNet-18:

'''ResNet-18 Image classfication for cifar-10 with PyTorch Author 'Sun-qian'.'''
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ResidualBlock(nn.Module):def __init__(self, inchannel, outchannel, stride=1):super(ResidualBlock, self).__init__()self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(outchannel))self.shortcut = nn.Sequential()if stride != 1 or inchannel != outchannel:self.shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(outchannel))def forward(self, x):out = self.left(x)out += self.shortcut(x)out = F.relu(out)return outclass ResNet(nn.Module):def __init__(self, ResidualBlock, num_classes=10):super(ResNet, self).__init__()self.inchannel = 64self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(),)self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)self.fc = nn.Linear(512, num_classes)def make_layer(self, block, channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)   #strides=[1,1]layers = []for stride in strides:layers.append(block(self.inchannel, channels, stride))self.inchannel = channelsreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = F.avg_pool2d(out, 4)out = out.view(out.size(0), -1)out = self.fc(out)return outdef ResNet18():return ResNet(ResidualBlock)

Pytorch上训练:

所选数据集为Cifar-10,该数据集共有60000张带标签的彩色图像,这些图像尺寸32*32,分为10个类,每类6000张图。这里面有50000张用于训练,每个类5000张,另外10000用于测试,每个类1000张。训练时人为修改学习率,当epoch:[1-135] ,lr=0.1;epoch:[136-185], lr=0.01;epoch:[186-240] ,lr=0.001。训练代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
from resnet import ResNet18
import os# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 参数设置,使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #输出结果保存路径
args = parser.parse_args()# 超参数设置
EPOCH = 135   #遍历数据集次数
pre_epoch = 0  # 定义已经遍历数据集的次数
BATCH_SIZE = 128      #批处理尺寸(batch_size)
LR = 0.01        #学习率# 准备数据集并预处理
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),  #先四周填充0,在吧图像随机裁剪成32*32transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), #R,G,B每层的归一化用到的均值和方差
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) #训练数据集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)   #生成一个个batch进行批训练,组成batch的时候顺序打乱取testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
# Cifar-10的标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 模型定义-ResNet
net = ResNet18().to(device)# 定义损失函数和优化方式
criterion = nn.CrossEntropyLoss()  #损失函数为交叉熵,多用于多分类问题
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4) #优化方式为mini-batch momentum-SGD,并采用L2正则化(权重衰减)# 训练
if __name__ == "__main__":if not os.path.exists(args.outf):os.makedirs(args.outf)best_acc = 85  #2 初始化best test accuracyprint("Start Training, Resnet-18!")  # 定义遍历数据集的次数with open("acc.txt", "w") as f:with open("log.txt", "w")as f2:for epoch in range(pre_epoch, EPOCH):print('\nEpoch: %d' % (epoch + 1))net.train()sum_loss = 0.0correct = 0.0total = 0.0for i, data in enumerate(trainloader, 0):# 准备数据length = len(trainloader)inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# forward + backwardoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 每训练1个batch打印一次loss和准确率sum_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += predicted.eq(labels.data).cpu().sum()print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% '% (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))f2.write('%03d  %05d |Loss: %.03f | Acc: %.3f%% '% (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))f2.write('\n')f2.flush()# 每训练完一个epoch测试一下准确率print("Waiting Test!")with torch.no_grad():correct = 0total = 0for data in testloader:net.eval()images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)# 取得分最高的那个类 (outputs.data的索引号)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum()print('测试分类准确率为:%.3f%%' % (100 * correct / total))acc = 100. * correct / total# 将每次测试结果实时写入acc.txt文件中print('Saving model......')torch.save(net.state_dict(), '%s/net_%03d.pth' % (args.outf, epoch + 1))f.write("EPOCH=%03d,Accuracy= %.3f%%" % (epoch + 1, acc))f.write('\n')f.flush()# 记录最佳测试分类准确率并写入best_acc.txt文件中if acc > best_acc:f3 = open("best_acc.txt", "w")f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, acc))f3.close()best_acc = accprint("Training Finished, TotalEPOCH=%d" % EPOCH)

实验结果:best_acc= 95.170%

这里写图片描述
(损失图是matlab画的,用保存下来的txt日志)

这篇关于Pytorch实战2:ResNet-18实现Cifar-10图像分类(测试集分类准确率95.170%)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/725115

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

性能测试介绍

性能测试是一种测试方法,旨在评估系统、应用程序或组件在现实场景中的性能表现和可靠性。它通常用于衡量系统在不同负载条件下的响应时间、吞吐量、资源利用率、稳定性和可扩展性等关键指标。 为什么要进行性能测试 通过性能测试,可以确定系统是否能够满足预期的性能要求,找出性能瓶颈和潜在的问题,并进行优化和调整。 发现性能瓶颈:性能测试可以帮助发现系统的性能瓶颈,即系统在高负载或高并发情况下可能出现的问题

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

字节面试 | 如何测试RocketMQ、RocketMQ?

字节面试:RocketMQ是怎么测试的呢? 答: 首先保证消息的消费正确、设计逆向用例,在验证消息内容为空等情况时的消费正确性; 推送大批量MQ,通过Admin控制台查看MQ消费的情况,是否出现消费假死、TPS是否正常等等问题。(上述都是临场发挥,但是RocketMQ真正的测试点,还真的需要探讨) 01 先了解RocketMQ 作为测试也是要简单了解RocketMQ。简单来说,就是一个分

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time