pytorch 参数冻结 parameter-efficient fine-tuning

2024-08-27 08:12

本文主要是介绍pytorch 参数冻结 parameter-efficient fine-tuning,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目标:在网络中冻结部分参数进行高效训练

框架:pytorch (version 1.11.0)

基本实现

  1. 需要学习的参数requires_grad设置为True,冻结的设置为False
  2. 需要学习的参数要加到 optimizer的List中;对于冻结的参数,可以直接不加进去,(应该也可以加进去,但是requires_grad=False)

注意事项
3. 如果不传递参数的层,记得前向操作是要设置 with torch.no_grad,否则即便没有需要更新的参数,其layer的梯度也回传,效率低。

  1. 要保证所有参与前向的操作,都被用于计算loss。例如,a=self.layer(b),只要前向里出现了这个操作,就要保证a(或a的后续输出)要参与loss的计算。如果a算完了不用,是不可以的。(不论self.layer里是否有需要更新的参数)。ps:这点和不冻结设置下的要求不一样,如果所有参数都学,即便中间有一些变量操作是冗余的,也不会报错,只是增加计算代价而已。(比如,在clip框架里,如果不用text prompt, 就不要提取该特征)
  2. 要保证,所有需要更新的参数,都用于前向计算了。如何比较二者的参数,见下:

a. 记录需要梯度回传的参数:

grad_params = set()
for name, param in model.named_parameters():if param.requires_grad:grad_params.add(name)

b. 记录前向中使用的参数:

used_params = set()
def forward(self, x):for name, param in self.named_parameters():if param.requires_grad:param.register_hook(lambda grad, name=name: used_params.add(name))return self.model(x)

c. 比较二者差异

unused_params = grad_params - used_params
if unused_params:print("以下参数未在 forward 函数中使用:", unused_params)
else:print("所有需要计算梯度的参数都在 forward 函数中使用了。")

ps. 好像也可以通过在nn.parallel.DistributedDataParallel中设置find_unused_parameters=True来找到未使用的变量。(不过我没试过

这篇关于pytorch 参数冻结 parameter-efficient fine-tuning的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1111115

相关文章

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

SpringMVC获取请求参数的方法

《SpringMVC获取请求参数的方法》:本文主要介绍SpringMVC获取请求参数的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下... 目录1、通过ServletAPI获取2、通过控制器方法的形参获取请求参数3、@RequestParam4、@

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

Spring Boot项目部署命令java -jar的各种参数及作用详解

《SpringBoot项目部署命令java-jar的各种参数及作用详解》:本文主要介绍SpringBoot项目部署命令java-jar的各种参数及作用的相关资料,包括设置内存大小、垃圾回收... 目录前言一、基础命令结构二、常见的 Java 命令参数1. 设置内存大小2. 配置垃圾回收器3. 配置线程栈大小

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

SpringBoot利用@Validated注解优雅实现参数校验

《SpringBoot利用@Validated注解优雅实现参数校验》在开发Web应用时,用户输入的合法性校验是保障系统稳定性的基础,​SpringBoot的@Validated注解提供了一种更优雅的解... 目录​一、为什么需要参数校验二、Validated 的核心用法​1. 基础校验2. php分组校验3

一文带你了解SpringBoot中启动参数的各种用法

《一文带你了解SpringBoot中启动参数的各种用法》在使用SpringBoot开发应用时,我们通常需要根据不同的环境或特定需求调整启动参数,那么,SpringBoot提供了哪些方式来配置这些启动参... 目录一、启动参数的常见传递方式二、通过命令行参数传递启动参数三、使用 application.pro

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

基于@RequestParam注解之Spring MVC参数绑定的利器

《基于@RequestParam注解之SpringMVC参数绑定的利器》:本文主要介绍基于@RequestParam注解之SpringMVC参数绑定的利器,具有很好的参考价值,希望对大家有所帮助... 目录@RequestParam注解:Spring MVC参数绑定的利器什么是@RequestParam?@