AlexNet-pytorch实现

2024-04-19 07:48
文章标签 实现 pytorch alexnet

本文主要是介绍AlexNet-pytorch实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

AlexNet

1.网络架构

如图所示可见其结构为:

img

AlexNet网络共八层,五层卷积层和三层全连接层。这是一个非常经典的设计,为后续神经网络的发展提供了极大的贡献。

2.pytorch网络设计

网络设计部分做了一些小的修改,目的是为了适配minist的3x28x28的输入图片大小。

网络构造代码部分:

class AlexNet(nn.Module):def __init__(self):super(AlexNet, self).__init__()self.conv = nn.Sequential(nn.Conv2d(3, 96, 11, 1, 5),  # in_channels, out_channels, kernel_size, stride, paddingnn.ReLU(),nn.MaxPool2d(3, 1),  # kernel_size, stride 26x26# 减少卷积窗口,使用填充为2来使输入输出大小一致nn.Conv2d(96, 256, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(4, 2),  # 12x12# 下面接三个卷积层nn.Conv2d(256, 384, 3, 1, 1),nn.ReLU(),nn.Conv2d(384, 384, 3, 1, 1),nn.ReLU(),nn.Conv2d(384, 256, 3, 1, 1),nn.ReLU(),nn.MaxPool2d(4, 2)  # 5x5)self.fc = nn.Sequential(nn.Linear(256 * 5 * 5, 4096),nn.Dropout(0.5),nn.Linear(4096, 4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 10),)def forward(self, img):img.shape[0]# img.resize_(3,224,224)feature = self.conv(img)output = self.fc(feature.view(img.shape[0], -1))return output

3.网络测试

一些基础设置与上一篇文章一致,还是贴一下代码。

网络测试部分我使用的是minist数据集,为了贴近真实(主要是方便我自己懂),在下载了数据集之后将其转为了图片数据集,更为直观。数据集分为train 和test两部分,在测试中需要做如下配置:

1.依赖资源引入

draw_tool是一个自己编写的绘制loss,acc的画图库,device使用了我电脑的1050ti显卡。

import torch
from matplotlib import pyplot as plt
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from torch import nn
from torchsummary import summary
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import draw_toolroot = "F:/pycharm/dataset/mnist/MNIST/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
draw = draw_tool.draw_tool()

2.数据集的读取和分类

#加载图片
def default_loader(path):return Image.open(path).convert('RGB')#构造标注和图片相关
class MyDataset(Dataset):def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):fh = open(txt, 'r')imgs = []for line in fh:line = line.strip('\n')line = line.rstrip()words = line.split()imgs.append((words[0], int(words[1])))self.imgs = imgsself.transform = transformself.target_transform = target_transformself.loader = loaderdef __getitem__(self, index):fn, label = self.imgs[index]img = self.loader(fn)if self.transform is not None:img = self.transform(img)return img, labeldef __len__(self):return len(self.imgs)train_data = MyDataset(txt=root + 'rawtrain.txt', transform=transforms.ToTensor())
test_data = MyDataset(txt=root + 'rawtest.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=31, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=31, shuffle=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

3.模型训练设置

model = AlexNet()
#使用softmax分类
criterion = torch.nn.CrossEntropyLoss()
#设置随机梯度下降 学习率和L2正则
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
#使用GPU训练
model = model.to(device)

4.训练

每训练一个epoch 做一次平均loss train acc test acc的计算绘制

def train(epoch):running_loss = 0.0num_correct = 0.0total = 0correct = 0total = 0test_acc = 0.0# trainfor batch_idx, data in enumerate(train_loader, 0):inputs, target = datainputs = inputs.to(device)target = target.to(device)optimizer.zero_grad()# forward + backward + updateoutputs = model(inputs)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, dim=1)total += target.size(0)num_correct += (predicted == target).sum().item()# #test# with torch.no_grad():#     for data in test_loader:#         images, labels = data#         images = images.to(device)#         labels = labels.to(device)#         outputs = model(images)#         _, predicted = torch.max(outputs.data, dim=1)#         total += labels.size(0)##         correct += (predicted == labels).sum().item()print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / len(train_loader)))# print('Accuracy on test set: %d %%' % (100 * correct / total))# test_acc=100 * correct / totaltest_acc = test()acc = (num_correct / len(train_loader.dataset) * 100)print("num_correct=")print(acc)running_loss /= len(train_loader)draw.new_data(running_loss, acc, test_acc, 2)draw.draw()def test():correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataimages = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()test_acc = 100 * correct / totalprint('Accuracy on test set: ', test_acc, '%')return test_acc

5.结果统计

if __name__ == '__main__':for epoch in range(20):train(epoch)torch.save(model.state_dict(), "minist_last.pth")draw.show()

在这里插入图片描述

从图中效果可以看到随着训练次数的增加,loss在不断下降,train acc 和test acc 也在慢慢收敛,最终达到了train acc=97% test acc=96%的效果。但与之前上一文的训练有一样的问题所在,不知道为什么中途的test acc会突然下降,这里就不在往下继续训练了,网络变得更为复杂并不代表精度一定会上升,反而对于简单数据的预测来说,只会更差。

留下一个问题,就是为什么我的test acc 会突然下滑这么多,如果有朋友有自己的想法或者有大佬愿意回复我一下还请评论一下,谢谢。

这篇关于AlexNet-pytorch实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/916958

相关文章

C++对象布局及多态实现探索之内存布局(整理的很多链接)

本文通过观察对象的内存布局,跟踪函数调用的汇编代码。分析了C++对象内存的布局情况,虚函数的执行方式,以及虚继承,等等 文章链接:http://dev.yesky.com/254/2191254.shtml      论C/C++函数间动态内存的传递 (2005-07-30)   当你涉及到C/C++的核心编程的时候,你会无止境地与内存管理打交道。 文章链接:http://dev.yesky

通过SSH隧道实现通过远程服务器上外网

搭建隧道 autossh -M 0 -f -D 1080 -C -N user1@remotehost##验证隧道是否生效,查看1080端口是否启动netstat -tuln | grep 1080## 测试ssh 隧道是否生效curl -x socks5h://127.0.0.1:1080 -I http://www.github.com 将autossh 设置为服务,隧道开机启动

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测 目录 时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测基本介绍程序设计参考资料 基本介绍 MATLAB实现LSTM时间序列未来多步预测-递归预测。LSTM是一种含有LSTM区块(blocks)或其他的一种类神经网络,文献或其他资料中LSTM区块可能被描述成智能网络单元,因为

vue项目集成CanvasEditor实现Word在线编辑器

CanvasEditor实现Word在线编辑器 官网文档:https://hufe.club/canvas-editor-docs/guide/schema.html 源码地址:https://github.com/Hufe921/canvas-editor 前提声明: 由于CanvasEditor目前不支持vue、react 等框架开箱即用版,所以需要我们去Git下载源码,拿到其中两个主

android一键分享功能部分实现

为什么叫做部分实现呢,其实是我只实现一部分的分享。如新浪微博,那还有没去实现的是微信分享。还有一部分奇怪的问题:我QQ分享跟QQ空间的分享功能,我都没配置key那些都是原本集成就有的key也可以实现分享,谁清楚的麻烦详解下。 实现分享功能我们可以去www.mob.com这个网站集成。免费的,而且还有短信验证功能。等这分享研究完后就研究下短信验证功能。 开始实现步骤(新浪分享,以下是本人自己实现

基于Springboot + vue 的抗疫物质管理系统的设计与实现

目录 📚 前言 📑摘要 📑系统流程 📚 系统架构设计 📚 数据库设计 📚 系统功能的具体实现    💬 系统登录注册 系统登录 登录界面   用户添加  💬 抗疫列表展示模块     区域信息管理 添加物资详情 抗疫物资列表展示 抗疫物资申请 抗疫物资审核 ✒️ 源码实现 💖 源码获取 😁 联系方式 📚 前言 📑博客主页:

探索蓝牙协议的奥秘:用ESP32实现高质量蓝牙音频传输

蓝牙(Bluetooth)是一种短距离无线通信技术,广泛应用于各种电子设备之间的数据传输。自1994年由爱立信公司首次提出以来,蓝牙技术已经经历了多个版本的更新和改进。本文将详细介绍蓝牙协议,并通过一个具体的项目——使用ESP32实现蓝牙音频传输,来展示蓝牙协议的实际应用及其优点。 蓝牙协议概述 蓝牙协议栈 蓝牙协议栈是蓝牙技术的核心,定义了蓝牙设备之间如何进行通信。蓝牙协议

python实现最简单循环神经网络(RNNs)

Recurrent Neural Networks(RNNs) 的模型: 上图中红色部分是输入向量。文本、单词、数据都是输入,在网络里都以向量的形式进行表示。 绿色部分是隐藏向量。是加工处理过程。 蓝色部分是输出向量。 python代码表示如下: rnn = RNN()y = rnn.step(x) # x为输入向量,y为输出向量 RNNs神经网络由神经元组成, python

基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别

转发来源:https://swift.ctolib.com/ooooverflow-chinese-ocr.html chinese-ocr 基于CTPN(tensorflow)+CRNN(pytorch)+CTC的不定长文本检测和识别 环境部署 sh setup.sh 使用环境: python 3.6 + tensorflow 1.10 +pytorch 0.4.1 注:CPU环境

利用Frp实现内网穿透(docker实现)

文章目录 1、WSL子系统配置2、腾讯云服务器安装frps2.1、创建配置文件2.2 、创建frps容器 3、WSL2子系统Centos服务器安装frpc服务3.1、安装docker3.2、创建配置文件3.3 、创建frpc容器 4、WSL2子系统Centos服务器安装nginx服务 环境配置:一台公网服务器(腾讯云)、一台笔记本电脑、WSL子系统涉及知识:docker、Frp