学习基于pytorch的VGG图像分类 day2

2024-04-10 14:36

本文主要是介绍学习基于pytorch的VGG图像分类 day2,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主.

目录

VGG网络搭建(模型文件)

        1.字典文件配置

         2.提取特征网络结构

        3. VGG类的定义

         4.VGG网络实例化


VGG网络搭建(模型文件)

        1.字典文件配置

#字典文件,对应各个配置,数字对应卷积核的个数,'M'对应最大液化(即maxpool)
cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

         2.提取特征网络结构

#提取特征网络结构
def make_features(cfg: list): #传入对应的列表layers = [] #定义一个空列表,存放每层的结果in_channels = 3 #输入为RGB彩色图片,输入通道为3for v in cfg: #通过for循环遍历列表if v == "M":                                                    #maxpool size = 2,stride = 2layers += [nn.MaxPool2d(kernel_size=2, stride=2)] #创建最大池化下载量程,池化核为2,布局也为2else:                                                           #conv padding = 1,stride = 1conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) #创建卷积操作(输入特征矩阵深度,输出特征矩阵深度(卷积核个数),卷积核为3,填充为1,stride默认为1(不用写))layers += [conv2d, nn.ReLU(True)] #使用ReLU激活函数in_channels = v #输出深度改变成vreturn nn.Sequential(*layers) #通过Sequential函数将列表以非关键字参数的形式传入(*代表非关键字传入)

        3. VGG类的定义

class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False): #(通过make_features生成的提取特征网络结构,分类的类别个数,是否对网络权重初始化)super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential( #生成分类网络nn.Linear(512*7*7, 4096), #全连接层上下的节点个数nn.ReLU(True),  #ReLU函数激活nn.Dropout(p=0.5), #Dropout函数减少过拟合,以50%的比例随机失活神经元nn.Linear(4096, 4096), #第一层和第二层nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes) #第二层和第三层,总计3层全连接层,最后连接到输出层,输出num_classes的所需个数)if init_weights: #初始化权重函数self._initialize_weights()def forward(self, x): #正向传播 x就是输入的图像数据 # N x 3 x 224 x 224x = self.features(x) #用features提取特征网络结构# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1) #对输出进行一个展平处理,(start_dim定义从哪个维度开始展平处理)# N x 512*7*7x = self.classifier(x) #输入到分类网络结构return xdef _initialize_weights(self):for m in self.modules(): #遍历网络的每一个子模块if isinstance(m, nn.Conv2d): #遍历到卷积层# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight) #使用xavier函数初始化,初始化卷积核的权重if m.bias is not None: #卷积核采用偏置nn.init.constant_(m.bias, 0) #将偏执初始化为0elif isinstance(m, nn.Linear): #遍历到全连接层,下面同理nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

         4.VGG网络实例化

#实例化VGG网络结构
def vgg(model_name="vgg16", **kwargs):assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)cfg = cfgs[model_name]model = VGG(make_features(cfg), **kwargs) #通过VGG这个类实现实例化网络,(**可变长度的字典变量)return model

 内容参考来源:

 ​​​​​​使用pytorch搭建VGG网络_哔哩哔哩_bilibili

这篇关于学习基于pytorch的VGG图像分类 day2的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/891335

相关文章

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

C#使用DeepSeek API实现自然语言处理,文本分类和情感分析

《C#使用DeepSeekAPI实现自然语言处理,文本分类和情感分析》在C#中使用DeepSeekAPI可以实现多种功能,例如自然语言处理、文本分类、情感分析等,本文主要为大家介绍了具体实现步骤,... 目录准备工作文本生成文本分类问答系统代码生成翻译功能文本摘要文本校对图像描述生成总结在C#中使用Deep

基于WinForm+Halcon实现图像缩放与交互功能

《基于WinForm+Halcon实现图像缩放与交互功能》本文主要讲述在WinForm中结合Halcon实现图像缩放、平移及实时显示灰度值等交互功能,包括初始化窗口的不同方式,以及通过特定事件添加相应... 目录前言初始化窗口添加图像缩放功能添加图像平移功能添加实时显示灰度值功能示例代码总结最后前言本文将

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。