Pytorch中的Exponential Moving Average(EMA)

2023-10-14 16:20

本文主要是介绍Pytorch中的Exponential Moving Average(EMA),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

EMA介绍

EMA,指数移动平均,常用于更新模型参数、梯度等。

EMA的优点是能提升模型的鲁棒性(融合了之前的模型权重信息)

代码示例

下面以yolov7/utils/torch_utils.py代码为例:

class ModelEMA:""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-modelsKeep a moving average of everything in the model state_dict (parameters and buffers).This is intended to allow functionality likehttps://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverageA smoothed version of the weights is necessary for some training schemes to perform well.This class is sensitive where it is initialized in the sequence of model init,GPU assignment and distributed training wrappers."""def __init__(self, model, decay=0.9999, updates=0):# Create EMAself.ema = deepcopy(model.module if is_parallel(model) else model).eval()self.updates = updates  # number of EMA updatesself.decay = lambda x: decay * (1 - math.exp(-x / 2000))for p in self.ema.parameters():p.requires_grad_(False)def update(self, model):# Update EMA parameterswith torch.no_grad():self.updates += 1d = self.decay(self.updates)msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  for k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:v *= dv += (1. - d) * msd[k].detach()def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):# Update EMA attributescopy_attr(self.ema, model, include, exclude)

ModelEMA类的__init__ 函数介绍

__init__ 函数的输入参数介绍

  • model:需要使用EMA策略更新参数的模型
  • decay:加权权重,默认为0.9999
  • updates:模型参数更新/迭代次数

__init__ 函数的初始化介绍

首先深拷贝一份模型

"""
创建EMA模型model.eval()的作用:
1. 保证BN层使用的是训练数据的均值(running_mean)和方差(running_val), 否则一旦test的batch_size过小, 很容易就会被BN层影响结果
2. 保证Dropout不随机舍弃神经元
3. 模型不会计算梯度,从而减少内存消耗和计算时间is_parallel()的作用:
如果模型是并行训练(DP/DDP)的, 则深拷贝model.module,否则就深拷贝model"""
self.ema = deepcopy(model.module if is_parallel(model) else model).eval()

接着,初始化updates次数,若是从头开始训练,则该参数为0

self.updates = updates

最后,定义加权权重decay的计算公式(这里呈指数型变化),

self.decay = lambda x: decay * (1 - math.exp(-x / 2000))

ModelEMA类的update()函数介绍

如果调用该函数,则更新updates以及decay,

self.updates += 1
## d随着updates的增加而逐渐增大, 意味着随着模型迭代次数的增加, EMA模型的权重会越来越偏向于之前的权重
d = self.decay(self.updates)

取出当前模型的参数,为更新EMA模型的参数做准备,

msd = model.module.state_dict() if is_parallel(model) else model.state_dict()

对EMA模型参数以及当前模型参数进行加权求和,作为EMA模型的新参数,

for k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:v *= dv += (1. - d) * msd[k].detach()

参考文章

【代码解读】在pytorch中使用EMA - 知乎

【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎

以史为鉴!EMA在机器学习中的应用 - 知乎

这篇关于Pytorch中的Exponential Moving Average(EMA)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/211675

相关文章

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用

【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

3、cnn基础 卷积神经网络 输入层 —输入图片矩阵 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片 卷积层 —特征提取 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆 卷积操作 激活层 —加强特征 池化层 —压缩数据 全连接层 —进行分类 输出层 —输出分类概率 4、基于LeNet

pytorch torch.nn.functional.one_hot函数介绍

torch.nn.functional.one_hot 是 PyTorch 中用于生成独热编码(one-hot encoding)张量的函数。独热编码是一种常用的编码方式,特别适用于分类任务或对离散的类别标签进行处理。该函数将整数张量的每个元素转换为一个独热向量。 函数签名 torch.nn.functional.one_hot(tensor, num_classes=-1) 参数 t

pytorch计算网络参数量和Flops

from torchsummary import summarysummary(net, input_size=(3, 256, 256), batch_size=-1) 输出的参数是除以一百万(/1000000)M, from fvcore.nn import FlopCountAnalysisinputs = torch.randn(1, 3, 256, 256).cuda()fl

HumanNeRF:Free-viewpoint Rendering of Moving People from Monocular Video 翻译

HumanNeRF:单目视频中运动人物的自由视点绘制 引言。我们介绍了一种自由视点渲染方法- HumanNeRF -它适用于一个给定的单眼视频ofa人类执行复杂的身体运动,例如,从YouTube的视频。我们的方法可以在任何帧暂停视频,并从任意新的摄像机视点或甚至针对该特定帧和身体姿势的完整360度摄像机路径渲染主体。这项任务特别具有挑战性,因为它需要合成身体的照片级真实感细节,如从输入视频中可能

Python(TensorFlow和PyTorch)两种显微镜成像重建算法模型(显微镜学)

🎯要点 🎯受激发射损耗显微镜算法模型:🖊恢复嘈杂二维和三维图像 | 🖊模型架构:恢复上下文信息和超分辨率图像 | 🖊使用嘈杂和高信噪比的图像训练模型 | 🖊准备半合成训练集 | 🖊优化沙邦尼尔损失和边缘损失 | 🖊使用峰值信噪比、归一化均方误差和多尺度结构相似性指数量化结果 | 🎯训练荧光显微镜模型和对抗网络图形转换模型 🍪语言内容分比 🍇Python图像归一化

Pytorch环境搭建时的各种问题

1 问题 1.一直soving environment,跳不出去。网络解决方案有:配置清华源,更新conda等,没起作用。2.下载完后,有3个要done的东西,最后那个exe开头的(可能吧),总是报错。网络解决方案有:用管理员权限打开prompt等,没起作用。3.有时候配置完源,安装包的时候显示什么https之类的东西,去c盘的用户那个文件夹里找到".condarc"文件把里面的网址都改成htt

【PyTorch】使用容器(Containers)进行网络层管理(Module)

文章目录 前言一、Sequential二、ModuleList三、ModuleDict四、ParameterList & ParameterDict总结 前言 当深度学习模型逐渐变得复杂,在编写代码时便会遇到诸多麻烦,此时便需要Containers的帮助。Containers的作用是将一部分网络层模块化,从而更方便地管理和调用。本文介绍PyTorch库常用的nn.Sequen

【python pytorch】Pytorch实现逻辑回归

pytorch 逻辑回归学习demo: import torchimport torch.nn as nnimport torchvision.datasets as dsetsimport torchvision.transforms as transformsfrom torch.autograd import Variable# Hyper Parameters input_si

【python pytorch】Pytorch 基础知识

包含知识点: 张量数学操作数理统计比较操作 #-*-coding:utf-8-*-import numpy as npnp.set_printoptions(suppress=True)import torch# 构造一个4*5 的矩阵z=torch.Tensor(4,5)print(z)# 两个矩阵进行加法操作y=torch.rand(4,5)print(z+y)# 另一种表示