Pytorch入门(7)—— 梯度累加(Gradient Accumulation)

2024-05-24 20:52

本文主要是介绍Pytorch入门(7)—— 梯度累加(Gradient Accumulation),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. 梯度累加

  • 在训练大模型时,batch_size 最大值往往受限于显存容量上限,当模型非常大时,这个上限可能小到不可接受。梯度累加(Gradient Accumulation)是一个解决该问题的 trick
  • 梯度累加的思想很简单,就是时间换空间。具体而言,我们不在每个 batch data 梯度计算后直接更新模型,而是多算几个 batch 后,使用这些 batch 的平均梯度更新模型,从而放大等效 batch_size。如下图所示
    在这里插入图片描述
  • 用公式表示:设 batch size 为 n n n,模型参数为 w \pmb{w} w,样本 i i i 的损失为 l i l_i li,则正常情况下 sgd 参数更新为
    w ← w + α ∑ i = 1 n 1 n ∂ l i ∂ w \pmb{w} \leftarrow \pmb{w} + \alpha \sum_{i=1}^n\frac{1}{n}\frac{\partial l_i}{\partial \pmb{w}} ww+αi=1nn1wli 使用梯度累加时,设累加步长为 m m m(即计算 m m m 个 batch 梯度后用梯度均值更新一次),sgd 更新如下
    w ← w + α 1 m ∑ b = 1 m ∑ i = 1 n 1 n ∂ l b i ∂ w = w + α ∑ i = 1 m n 1 m n ∂ l i ∂ w \begin{aligned} \pmb{w} &\leftarrow \pmb{w} + \alpha \frac{1}{m} \sum_{b=1}^m \sum_{i=1}^n\frac{1}{n}\frac{\partial l_{bi}}{\partial \pmb{w}} \\ &= \pmb{w} + \alpha \sum_{i=1}^{mn}\frac{1}{mn} \frac{\partial l_i}{\partial \pmb{w}} \end{aligned} ww+αm1b=1mi=1nn1wlbi=w+αi=1mnmn1wli 可见这等价于使用 batch_size = m n mn mn 进行训练

2. 在 pytorch 中实现梯度累加

2.1 伪代码

  • pytorch 使用和 tensor 绑定的自动微分机制。每个 tensor 对象都有 .grad 属性存储其中每个元素的梯度值,通过 .requires_grad 属性控制其是否参与梯度计算。训练模型时,一般通过对标量 loss 执行 loss.backward() 自动进行反向传播,以得到计算图中所有 tensor 的梯度。详见 PyTorch入门(2)—— 自动求梯度
  • pytorch 中梯度 tensor.grad 不会自动清零,而会在每次反向传播过程中自动累加,所以一般在反向传播前把梯度清零
    for inputs, labels in data_loader:# forward pass preds = model(inputs)loss  = criterion(preds, labels)# clear grad of last batch	optimizer.zero_grad()# backward pass, calculate grad of batch dataloss.backward()# update modeloptimizer.step()
    
    这种设计对于实现梯度累加 trick 是很方便的,我们可以在 batch 计算过程中进行计数,仅在达到计数达到更新步长时进行一次参数更新并清零梯度,即
    # batch accumulation parameter
    accum_iter = 4  # loop through enumaretad batches
    for batch_idx, (inputs, labels) in enumerate(data_loader):# forward pass preds = model(inputs)loss  = criterion(preds, labels)# scale the loss to the mean of the accumulated batch sizeloss = loss / accum_iter # backward passloss.backward()# weights updateif ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_loader)):optimizer.step()optimizer.zero_grad()
    

2.2 线性回归案例

  • 下面使用来自 经典机器学习方法(1)—— 线性回归 的简单线性回归任务说明梯度累加的具体实现方法

    本节代码直接从 jupyter notebook 复制而来,可能无法直接运行!

  • 首先生成随机数据构造 dataset
    import torch
    from IPython import display
    from matplotlib import pyplot as plt
    import numpy as np
    import random
    import torch.utils.data as Data
    import torch.nn as nn
    import torch.optim as optim# 生成样本
    num_inputs = 2
    num_examples = 1000
    true_w = torch.Tensor([-2,3.4]).view(2,1)
    true_b = 4.2
    batch_size = 10# 1000 个2特征样本,每个特征都服从 N(0,1)
    features = torch.randn(num_examples, num_inputs, dtype=torch.float32) # 生成真实标记
    labels = torch.mm(features,true_w) + true_b
    labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float32)# 包装数据集,将训练数据的特征和标签组合
    dataset = Data.TensorDataset(features, labels)
    
    1. 不使用梯度累加技巧,batch size 设置为 40
      # 构造 DataLoader
      batch_size = 40
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失函数
      criterion = nn.MSELoss()# SGD优化器
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for X, y in data_iter:# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))# 梯度清零optimizer.zero_grad()            # 计算各参数梯度loss.backward()#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')'''
      epoch 1, loss: 0.5434057731628418
      epoch 2, loss: 0.1914414196014404
      epoch 3, loss: 0.06752514398097992
      '''
      
    2. 使用梯度累加,batch size 设置为 10,步长设为 4,等效 batch size 为 40
      # 构造 DataLoader
      batch_size = 10
      accum_iter = 4
      data_iter = Data.DataLoader(dataset, batch_size, shuffle=False)	# shuffle=False 保证实验可比# 定义模型
      net = nn.Sequential(nn.Linear(num_inputs, 1))# 初始化模型参数
      nn.init.normal_(net[0].weight, mean=0, std=0)
      nn.init.constant_(net[0].bias, val=0)# 均方差损失
      criterion = nn.MSELoss()# SGD优化器对象
      optimizer = optim.SGD(net.parameters(), lr=0.01)# 模型训练
      num_epochs = 3
      for epoch in range(1, num_epochs + 1):epoch_loss = []for batch_idx, (X, y) in enumerate(data_iter):# 正向传播,计算损失output = net(X) loss = criterion(output, y.view(-1, 1))  loss = loss / accum_iter	# 取各个累计batch的平均损失,从而在.backward()时得到平均梯度# 反向传播,梯度累计loss.backward()if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(data_iter)):#print('backward: ', net[0].weight.grad)# 更新模型optimizer.step()              # 梯度清零optimizer.zero_grad()epoch_loss.append(loss.item()/batch_size)print(f'epoch {epoch}, loss: {np.mean(epoch_loss)}')
      '''
      epoch 1, loss: 0.5434057596921921
      epoch 2, loss: 0.19144139245152472
      epoch 3, loss: 0.06752512042224407
      '''
      
  • 可以观察到无论 epoch loss 还是 net[0].weight.grad 都完全相同,说明梯度累加不影响计算结果

这篇关于Pytorch入门(7)—— 梯度累加(Gradient Accumulation)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/999503

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

数论入门整理(updating)

一、gcd lcm 基础中的基础,一般用来处理计算第一步什么的,分数化简之类。 LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; } <pre name="code" class="cpp">LL lcm(LL a, LL b){LL c = gcd(a, b);return a / c * b;} 例题:

Java 创建图形用户界面(GUI)入门指南(Swing库 JFrame 类)概述

概述 基本概念 Java Swing 的架构 Java Swing 是一个为 Java 设计的 GUI 工具包,是 JAVA 基础类的一部分,基于 Java AWT 构建,提供了一系列轻量级、可定制的图形用户界面(GUI)组件。 与 AWT 相比,Swing 提供了许多比 AWT 更好的屏幕显示元素,更加灵活和可定制,具有更好的跨平台性能。 组件和容器 Java Swing 提供了许多

【IPV6从入门到起飞】5-1 IPV6+Home Assistant(搭建基本环境)

【IPV6从入门到起飞】5-1 IPV6+Home Assistant #搭建基本环境 1 背景2 docker下载 hass3 创建容器4 浏览器访问 hass5 手机APP远程访问hass6 更多玩法 1 背景 既然电脑可以IPV6入站,手机流量可以访问IPV6网络的服务,为什么不在电脑搭建Home Assistant(hass),来控制你的设备呢?@智能家居 @万物互联

poj 2104 and hdu 2665 划分树模板入门题

题意: 给一个数组n(1e5)个数,给一个范围(fr, to, k),求这个范围中第k大的数。 解析: 划分树入门。 bing神的模板。 坑爹的地方是把-l 看成了-1........ 一直re。 代码: poj 2104: #include <iostream>#include <cstdio>#include <cstdlib>#include <al

MySQL-CRUD入门1

文章目录 认识配置文件client节点mysql节点mysqld节点 数据的添加(Create)添加一行数据添加多行数据两种添加数据的效率对比 数据的查询(Retrieve)全列查询指定列查询查询中带有表达式关于字面量关于as重命名 临时表引入distinct去重order by 排序关于NULL 认识配置文件 在我们的MySQL服务安装好了之后, 会有一个配置文件, 也就

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显

C语言指针入门 《C语言非常道》

C语言指针入门 《C语言非常道》 作为一个程序员,我接触 C 语言有十年了。有的朋友让我推荐 C 语言的参考书,我不敢乱推荐,尤其是国内作者写的书,往往七拼八凑,漏洞百出。 但是,李忠老师的《C语言非常道》值得一读。对了,李老师有个官网,网址是: 李忠老师官网 最棒的是,有配套的教学视频,可以试看。 试看点这里 接下来言归正传,讲解指针。以下内容很多都参考了李忠老师的《C语言非

MySQL入门到精通

一、创建数据库 CREATE DATABASE 数据库名称; 如果数据库存在,则会提示报错。 二、选择数据库 USE 数据库名称; 三、创建数据表 CREATE TABLE 数据表名称; 四、MySQL数据类型 MySQL支持多种类型,大致可以分为三类:数值、日期/时间和字符串类型 4.1 数值类型 数值类型 类型大小用途INT4Bytes整数值FLOAT4By

【QT】基础入门学习

文章目录 浅析Qt应用程序的主函数使用qDebug()函数常用快捷键Qt 编码风格信号槽连接模型实现方案 信号和槽的工作机制Qt对象树机制 浅析Qt应用程序的主函数 #include "mywindow.h"#include <QApplication>// 程序的入口int main(int argc, char *argv[]){// argc是命令行参数个数,argv是