卷积神经网络(CNN)使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类(代码➕注释)

本文主要是介绍卷积神经网络(CNN)使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类(代码➕注释),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

一、CNN概述

二、CNN网络结构

三、CNN常见名词

四、使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类


一、CNN概述

        卷积神经网络 ( Convolutional Neural NetworkCNN) 作为人工神经网络中一种常见的深度学习架构,该网络是受到生物自然视觉认知机制启发而来,是一种特殊的多层前馈神经网络, CNN 是由简单的神经网络改进而来,使用卷积层和池化层替代全连接层结构,卷积层能够有效地将图像中的各种特征提取出并生成特征图。广泛应用于图像识别图像分类等领域 ,具有良好的扩展性和鲁棒性,截至目前,CNN 的深度呈不断增加的趋势

        CNN在图像分类识别中要做的事情是:给定一张图片,图片中是牛还是马不知道,是什么牛也不知道,现在需要模型判断这张图片里具体是一个什么东西,总之输出一个结果:如果是牛的话,那是什么牛?

【1】鲁棒性也称作健壮性(英语:Robustness一个系统或组织有抵御或克服不利条件的能力。鲁棒性则常被用来描述可以面对复杂适应系统的能力,需要更全面的对系统进行考虑。

二、CNN网络结构

1)输入层(Input layer),众多神经元(Neuron)接受大量非线形输入讯息。输入的讯息称为输入向量。

2)卷积层:是一块一块地来进行比对。它拿来比对的这个“小块”我们称之为Features,每一个feature就像是一个小图,对图像和滤波矩阵做内积(逐个元素相乘再求和)的操作就是所谓的卷积”操作,也是卷积神经网络的名字来源。

【1】卷积:滤波器filter与数据窗口做内积(在CNN中,滤波器filter带着一组固定权重的神经元)对局部输入数据进行卷积计算。每计算完一个数据窗口内的局部数据后,数据窗口不断平移滑动,直到计算完所有数据

3)池化pool层:保留主要的特征进一步删减冗余参数,提高特征提取效率。池化,简言之,即取区域平均或最大。

5)全连接层:就是把特征整合到一起(高度提纯特征),方便交给最后的分类器或者回归。

三、CNN常见名词

1感受野:某一个输出层的一个元素对应输入层的区域大小,被称为感受野,即输出层的一个元素在输入层上的映射区域。

2激活函数:常用的非线性激活函数有sigmoidtanhrelu等等,前两者sigmoid/tanh比较常见于全连接层,后者relu常见于卷积层。

四、使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类

主要步骤是:

1. 加载和预处理CIFAR-10数据集
2. 定义卷积神经网络 ConvNet 模型
3. 定义交叉熵损失函数和SGD优化器
4. 训练模型50个epoch
5. 打印训练损失并完成训练

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
import torchvision.transforms as transforms
import matplotlib.pyplot as plt# 训练数据
transform = transforms.Compose([transforms.ToTensor(),     # 转为tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])   # 归一化trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)    # 测试数据    
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)  classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 卷积神经网络定义
class ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))    # 2层卷积池化x = self.pool(F.relu(self.conv2(x)))    # 2层卷积池化x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = ConvNet()
criterion = nn.CrossEntropyLoss()       # 损失函数定义
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)   # 优化器定义# 训练网络
for epoch in range(50):   # 50个epochrunning_loss = 0.0for i, data in enumerate(trainloader, 0):   # 遍历训练集inputs, labels = dataoptimizer.zero_grad()    # 梯度清零outputs = model(inputs)  # 神经网络前向传播loss = criterion(outputs, labels)    # 计算损失loss.backward()         # 反向传播optimizer.step()        # 更新参数running_loss += loss.item() # 累加损失loss = running_loss/len(trainset) # 打印Lossprint(f'Epoch {epoch+1}, Loss: {loss}') print('Finished Training')

这篇关于卷积神经网络(CNN)使用PyTorch实现卷积神经网络对CIFAR-10数据集进行图片分类(代码➕注释)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/679583

相关文章

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四

java实现docker镜像上传到harbor仓库的方式

《java实现docker镜像上传到harbor仓库的方式》:本文主要介绍java实现docker镜像上传到harbor仓库的方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 前 言2. 编写工具类2.1 引入依赖包2.2 使用当前服务器的docker环境推送镜像2.2

C++20管道运算符的实现示例

《C++20管道运算符的实现示例》本文简要介绍C++20管道运算符的使用与实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录标准库的管道运算符使用自己实现类似的管道运算符我们不打算介绍太多,因为它实际属于c++20最为重要的

Java easyExcel实现导入多sheet的Excel

《JavaeasyExcel实现导入多sheet的Excel》这篇文章主要为大家详细介绍了如何使用JavaeasyExcel实现导入多sheet的Excel,文中的示例代码讲解详细,感兴趣的小伙伴可... 目录1.官网2.Excel样式3.代码1.官网easyExcel官网2.Excel样式3.代码

Java中调用数据库存储过程的示例代码

《Java中调用数据库存储过程的示例代码》本文介绍Java通过JDBC调用数据库存储过程的方法,涵盖参数类型、执行步骤及数据库差异,需注意异常处理与资源管理,以优化性能并实现复杂业务逻辑,感兴趣的朋友... 目录一、存储过程概述二、Java调用存储过程的基本javascript步骤三、Java调用存储过程示

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

Go语言数据库编程GORM 的基本使用详解

《Go语言数据库编程GORM的基本使用详解》GORM是Go语言流行的ORM框架,封装database/sql,支持自动迁移、关联、事务等,提供CRUD、条件查询、钩子函数、日志等功能,简化数据库操作... 目录一、安装与初始化1. 安装 GORM 及数据库驱动2. 建立数据库连接二、定义模型结构体三、自动迁

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查