Pytorch中的Exponential Moving Average(EMA)

2023-10-14 16:20

本文主要是介绍Pytorch中的Exponential Moving Average(EMA),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

EMA介绍

EMA,指数移动平均,常用于更新模型参数、梯度等。

EMA的优点是能提升模型的鲁棒性(融合了之前的模型权重信息)

代码示例

下面以yolov7/utils/torch_utils.py代码为例:

class ModelEMA:""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-modelsKeep a moving average of everything in the model state_dict (parameters and buffers).This is intended to allow functionality likehttps://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverageA smoothed version of the weights is necessary for some training schemes to perform well.This class is sensitive where it is initialized in the sequence of model init,GPU assignment and distributed training wrappers."""def __init__(self, model, decay=0.9999, updates=0):# Create EMAself.ema = deepcopy(model.module if is_parallel(model) else model).eval()self.updates = updates  # number of EMA updatesself.decay = lambda x: decay * (1 - math.exp(-x / 2000))for p in self.ema.parameters():p.requires_grad_(False)def update(self, model):# Update EMA parameterswith torch.no_grad():self.updates += 1d = self.decay(self.updates)msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  for k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:v *= dv += (1. - d) * msd[k].detach()def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):# Update EMA attributescopy_attr(self.ema, model, include, exclude)

ModelEMA类的__init__ 函数介绍

__init__ 函数的输入参数介绍

  • model:需要使用EMA策略更新参数的模型
  • decay:加权权重,默认为0.9999
  • updates:模型参数更新/迭代次数

__init__ 函数的初始化介绍

首先深拷贝一份模型

"""
创建EMA模型model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()

接着,初始化updates次数,若是从头开始训练,则该参数为0

self.updates = updates

最后,定义加权权重decay的计算公式(这里呈指数型变化),

self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

ModelEMA类的update()函数介绍

如果调用该函数,则更新updates以及decay,

self.updates += 1
## d随着updates的增加而逐渐增大, 意味着随着模型迭代次数的增加, EMA模型的权重会越来越偏向于之前的权重
d = self.decay(self.updates)

取出当前模型的参数,为更新EMA模型的参数做准备,

msd = model.module.state_dict() if is_parallel(model) else model.state_dict()

对EMA模型参数以及当前模型参数进行加权求和,作为EMA模型的新参数,

for k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:v *= dv += (1. - d) * msd[k].detach()

参考文章

【代码解读】在pytorch中使用EMA - 知乎

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎

以史为鉴!EMA在机器学习中的应用 - 知乎

这篇关于Pytorch中的Exponential Moving Average(EMA)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/211675

相关文章

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl