PyTorch如何修改模型(魔改)

2024-04-30 13:36
文章标签 模型 pytorch 修改 魔改

本文主要是介绍PyTorch如何修改模型(魔改),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • PyTorch如何修改模型(魔改)
    • 1.修改模型层(模型框架⭐)
      • 1.1通过继承修改模型
      • 1.2通过组合修改模型(重点学👀)
      • 1.3通过猴子补丁修改模型
    • 2.添加外部输入
    • 3.添加额外输出
    • 参考

PyTorch如何修改模型(魔改)

对模型缝缝补补、修修改改,是我们必须要掌握的技能,本文详细介绍了如何修改PyTorch模型?也就是我们经常说的如何魔改。👍

PyTorch 的模型是一个 torch.nn.Module 的某个子类的对象,修改模型实际就等价于修改某个类,对面向对象熟悉的同学应该知道,对类做修改有两个经典的方法:组合继承

1.修改模型层(模型框架⭐)

1.1通过继承修改模型

首先创建自己需要的模型类,然后其父类指向需要被修改的模型,这时自己的模型则具有完备的父类行为,最后在子类中实现魔改的逻辑。其大致的框架代码如下所示:

from torchvision.models import ResNetclass CustomizedResNet(ResNet):def __init__(self):super().__init__()...def forward(self, x):...

下面这个例子,将对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet
import torch# 定义一个自定义的ResNet类,继承自torchvision的ResNet类
class CustomizedResNet(ResNet):def __init__(self, block, layers, num_classes=2):"""初始化函数block: ResNet中的基本块类型,可以是BasicBlock或Bottlenecklayers: 每个层级的基本块数量,是一个列表num_classes: 输出的类别数量,默认为2"""# 调用父类的初始化方法super().__init__(block, layers, num_classes)# 重新定义全连接层,改变输出的特征数量self.fc = torch.nn.Linear(int(512 * block.expansion * 1.875), num_classes)def forward(self, x):# 以下是ResNet的前向传播过程x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)# 通过四个残差层x1 = self.layer1(x)x2 = self.layer2(x1)x3 = self.layer3(x2)x4 = self.layer4(x3)# 将四个残差层的输出进行拼接x = torch.cat([self.avgpool(x1),self.avgpool(x2),self.avgpool(x3),self.avgpool(x4),], dim=1)# 将拼接后的张量展平x = torch.flatten(x, 1)# 通过全连接层,得到最终的输出x = self.fc(x)return x# 创建不同版本的ResNet模型
new_resnet34 = CustomizedResNet(BasicBlock, [3, 4, 6, 3], num_classes=1)
new_resnet50 = CustomizedResNet(Bottleneck, [3, 4, 6, 3], num_classes=1)
new_resnet101 = CustomizedResNet(Bottleneck, [3, 4, 23, 3], num_classes=1)
new_resnet200 = CustomizedResNet(Bottleneck, [3, 24, 36, 3], num_classes=1)

1.2通过组合修改模型(重点学👀)

在面向对象编程中,可能听说过「组合优于继承」,在模型修改的场景中其实也是这样,大多数情况下我们可能都适用组合而非继承。

首先依然需要创建模型的类,但这个类不再继承自魔改的类,而是直接继承 PyTorch 的模型基类 torch.nn.Module,然后将需要魔改的类作为类变量融入到模型中,下面是大致的框架代码:

from torchvision.models import resnet18
import torch.nn as nnclass CustomizedResNet(nn.Module):def __init__(self, backbone):super().__init__()self.backbone = backbone...def forward(self, x):...my_resnet18 = CustomizedResNet(resnet18)

同样,实现对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models import resnet50class CustomizedResNet(torch.nn.Module):def __init__(self, backbone, num_classes=2):super().__init__()self.backbone = backboneself.fc = torch.nn.Linear(3840, num_classes)def forward(self, x):x = self.backbone.conv1(x)x = self.backbone.bn1(x)x = self.backbone.relu(x)x = self.backbone.maxpool(x)x1 = self.backbone.layer1(x)x2 = self.backbone.layer2(x1)x3 = self.backbone.layer3(x2)x4 = self.backbone.layer4(x3)x = torch.cat([self.backbone.avgpool(x1),self.backbone.avgpool(x2),self.backbone.avgpool(x3),self.backbone.avgpool(x4),],dim=1,)x = torch.flatten(x, 1)x = self.fc(x)return xnew_resnet50 = CustomizedResNet(resnet50())

1.3通过猴子补丁修改模型

最简单粗暴的方法:猴子补丁(Monkey Patch)。之所以叫猴子补丁,是因为这种方法从程序设计的角度上来说,是具有破坏性的。而且这种方法仅能实现一些简单的修改需求,所以还是推荐使用继承或组合去修改我们的模型。😉

猴子补丁修改模型非常简单粗暴,直接使用需要修改的模型创建对象,然后直接对对象的属性做出修改。下面是把 ResNet34 的输出从 1000 改为 1 的简单例子:

from torchvision.models import resnet50
import torch.nn as nnmodel = resnet50()
model.fc = nn.Linear(2048, 1)

还有一个例子,以 PyTorch 官方视觉库 torchvision 预定义好的模型 ResNet50 为例,修改模型的某一层或者某几层。先观察一下它的网络结构:

import torch
import torch.nn as nn
from collections import OrderedDict
import torchvision.models as models
net = models.resnet50()
print(net)

假设要用这个模型去做一个10分类的问题,就应该修改模型的 fc 层,将其输出节点数替换为10。另外,想再加一层全连接层。可以做如下修改:

classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),('relu1', nn.ReLU()), ('dropout1',nn.Dropout(0.5)),('fc2', nn.Linear(128, 10)),('output', nn.Softmax(dim=1))]))net.fc = classifier

这里的操作相当于将模型(net)最后名称为“fc”的层替换成了名称为“classifier”的结构。

2.添加外部输入

有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。比如在CNN网络中,我们除了输入图像,还需要同时输入图像对应的其他信息,这时候就需要在已有的CNN网络中添加额外的输入变量。基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加输入和后续层之间的连接关系,从而完成模型的修改。

以 torchvision 的 resnet50 模型为基础,任务还是10分类任务。不同点在于,我们希望利用已有的模型结构,在倒数第二层增加一个额外的输入变量 add_variable 来辅助预测。具体实现如下:

class Model(nn.Module):def __init__(self, net):super().__init__()self.net = netself.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)self.fc_add = nn.Linear(1001, 10, bias=True)self.output = nn.Softmax(dim=1)def forward(self, x, add_variable):x = self.net(x)x = torch.cat((self.dropout(self.relu(x)),add_variable.unsqueeze(1)),1)x = self.fc_add(x)x = self.output(x)return x

这里的实现要点是通过torch.cat实现了tensor的拼接。torchvision 中的 resnet50 输出是一个1000维的 tensor,通过修改 forward 函数,先将 1000 维的 tensor 通过激活函数层和dropout层,再和外部输入变量"add_variable"拼接,最后通过全连接层映射到指定的输出维度 10。

另外这里对外部输入变量"add_variable"进行 unsqueeze 操作是为了和 net 输出的 tensor 保持维度一致,常用于 add_variable 是单一数值 (scalar) 的情况,此时 add_variable 的维度是 (batch_size, ),需要在第二维补充维数1,从而可以和 tensor 进行torch.cat操作。
unsqueeze与sequeeze语法说明

最后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()

另外别忘了,训练中在输入数据的时候要给两个inputs:

outputs = model(inputs, add_var)

3.添加额外输出

有时候在模型训练中,除了模型最后的输出外,我们需要输出模型某一中间层的结果,以施加额外的监督,获得更好的中间层结果。基本的思路是修改模型定义中 forward 函数的 return 变量。

依然以 resnet50 做 10 分类任务为例,在已经定义好的模型结构上,同时输出 1000 维的倒数第二层和 10 维的最后一层结果。具体实现如下:

class Model(nn.Module):def __init__(self, net):super().__init__()self.net = netself.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)self.fc1 = nn.Linear(1000, 10, bias=True)self.output = nn.Softmax(dim=1)def forward(self, x, add_variable):x1000 = self.net(x)x10 = self.dropout(self.relu(x1000))x10 = self.fc1(x10)x10 = self.output(x10)return x10, x1000

之后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()out10, out1000 = model(inputs, add_var)

参考

  • Chenglu’s Log

  • Pytorch修改预训练模型的方法汇总

😃😃😃

这篇关于PyTorch如何修改模型(魔改)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/949011

相关文章

Golang的CSP模型简介(最新推荐)

《Golang的CSP模型简介(最新推荐)》Golang采用了CSP(CommunicatingSequentialProcesses,通信顺序进程)并发模型,通过goroutine和channe... 目录前言一、介绍1. 什么是 CSP 模型2. Goroutine3. Channel4. Channe

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

python修改字符串值的三种方法

《python修改字符串值的三种方法》本文主要介绍了python修改字符串值的三种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录第一种方法:第二种方法:第三种方法:在python中,字符串对象是不可变类型,所以我们没办法直接

Mysql8.0修改配置文件my.ini的坑及解决

《Mysql8.0修改配置文件my.ini的坑及解决》使用记事本直接编辑my.ini文件保存后,可能会导致MySQL无法启动,因为MySQL会以ANSI编码读取该文件,解决方法是使用Notepad++... 目录Myhttp://www.chinasem.cnsql8.0修改配置文件my.ini的坑出现的问题

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

Retrieval-based-Voice-Conversion-WebUI模型构建指南

一、模型介绍 Retrieval-based-Voice-Conversion-WebUI(简称 RVC)模型是一个基于 VITS(Variational Inference with adversarial learning for end-to-end Text-to-Speech)的简单易用的语音转换框架。 具有以下特点 简单易用:RVC 模型通过简单易用的网页界面,使得用户无需深入了

透彻!驯服大型语言模型(LLMs)的五种方法,及具体方法选择思路

引言 随着时间的发展,大型语言模型不再停留在演示阶段而是逐步面向生产系统的应用,随着人们期望的不断增加,目标也发生了巨大的变化。在短短的几个月的时间里,人们对大模型的认识已经从对其zero-shot能力感到惊讶,转变为考虑改进模型质量、提高模型可用性。 「大语言模型(LLMs)其实就是利用高容量的模型架构(例如Transformer)对海量的、多种多样的数据分布进行建模得到,它包含了大量的先验

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}