代码与原理:混合精度训练详解

2024-08-27 12:44

本文主要是介绍代码与原理:混合精度训练详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在这里插入图片描述

浮点数的表示

计算机是二进制的世界,所以浮点数也是用二进制来表示的,与整型不同的是,浮点数通过3个区间来表示,分别是:

  • sign 表示正负,1表示正数,0表示负数
  • exponent 用来确定数字的范围,这一部分有 k 个bit来表示二进制,所以 k 越大,浮点数能表示的范围就越大
  • fraction 部分用来确定精度,也是位数越多,能表示的精度就越高

比如:

  • BF16 一共 16bit,sign 占 1 bit,exponent 占8 bit,fraction占7bit
  • FP16 一共16bit,sign 占 1 bit,exponent 占5bit, fraction占10bit

BF16能表示的数字范围更大,但是表示的精度更低。FP16 表示的数字范围更小,但是表示的精度更高深度学习中长期使用的标准格式是FP32,因为它能平衡数值范围和精度,同时也有较好的硬件支持。

  • FP32一共32bit,sign 占 1 bit,exponent 占8 bit,fraction占23 bit

FP16存在的问题

float16和float32相比内存占用更少**,**通用的模型 fp16 占用的内存只需原来的一半,就意味着训练的时候可以用更大的batchsize,且在多卡训练时数据通信量大幅减少等待时间,还能加快计算节省模型的训练时间。但在模型的训练过程中,训练的稳定性很重要,如果用 FP16会出现如下问题:

  • 数据溢出(范围):在反向传播中,需要计算网络模型中权重的梯度(一阶导数),因此在加权后值会更小。由上图可知FP16相比FP32的有效范围要窄很多,使用FP16替换FP32会出现上溢(Overflow)和下溢(Underflow)的情况,实际中更容易出现下溢情况
  • 舍入误差(精度):是指当网络模型的反向梯度很小,一般FP32能够表示,但是转换到FP16会小于当前区间内的最小间隔,会导致数据溢出。如0.00006666666在FP32中能正常表示,转换到FP16后会表示成为0.000067,不满足FP16最小间隔的数会强制舍入产生误差

混合精度训练原理

为了想让深度学习训练可以使用FP16的好处,又要避免精度溢出和舍入误差。于是可以通过FP16和FP32的混合精度训练(Mixed-Precision),混合精度训练过程中可以引入权重备份(Weight Backup)、损失放大(Loss Scaling)、精度累加(Precision Accumulated)三种相关的技术。

权重备份(Weight Backup)

权重备份主要用于解决舍入误差的问题。其主要思路是把神经网络训练过程中产生的激活activations、梯度 gradients、中间变量等数据,在训练中都利用FP16来存储,同时复制一份FP32的权重参数weights,用于训练时候的更新。

权重用FP32格式备份一次,那岂不是使得内存占用反而更高了呢?是的,额外拷贝一份权重的确增加了训练时候内存的占用。但是实际上,在训练过程中内存中分为动态内存和静态内容,其中动态内存是静态内存的3-4倍,主要是中间变量值和激活activations的值。而这里备份的权重增加的主要是静态内存。只要动态内存的值基本都是使用FP16来进行存储,则最终模型与整网使用FP32进行训练相比起来, 内存占用也基本能够减半。

损失缩放(Loss Scaling)

因为梯度值太小,使用FP16表示有时会造成数据下溢出的问题,导致模型不收敛。为了解决梯度过小数据下溢的问题,对前向计算出来的Loss值进行放大操作,也就是把FP32的参数乘以某一个因子系数后,把可能溢出的小数位数据往前移,平移到FP16能表示的数据范围内。根据链式求导法则,放大Loss后会作用在反向传播的每一层梯度,这样比在每一层梯度上进行放大更加高效。损失放大是需要结合混合精度实现的,其主要的主要思路是:

  • Scale up阶段:网络模型前向计算后在反响传播前,将得到的损失变化值Loss增大2^K倍
  • Scale down阶段:反向传播后,将权重梯度缩2^K倍,恢复FP32值进行存储

精度累加(Precision Accumulated)

在混合精度的模型训练过程中,使用FP16进行矩阵乘法运算,利用FP32来进行矩阵乘法中间的累加(accumulated),然后再将FP32的值转化为FP16进行存储。简单而言,就是利用FP16进行矩阵相乘,利用FP32来进行加法计算弥补丢失的精度。这样可以有效减少计算过程中的舍入误差,尽量减缓精度损失的问题。

混合精度训练代码

下面是一个使用PyTorch进行混合精度训练的例子:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScalerclass SimpleMLP(nn.Module):def __init__(self):super(SimpleMLP, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 2)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x

启用混合精度:

model = SimpleMLP().cuda()
model.train()
scaler = GradScaler()for epoch in range(num_epochs):for batch in data_loader:x, y = batchx, y = x.cuda(), y.cuda()with autocast():outputs = model(x)loss = criterion(outputs, y)# 反向传播和权重更新# 放大梯度scaler.scale(loss).backward() # 应用缩放后的梯度进行权重更新scaler.step(optimizer)# 更新缩放因子scaler.update()

在这个例子中,autocast()将模型的前向传播和损失计算转换为FP16格式。然而,反向传播仍然是在FP32精度下进行的,这是为了保持数值稳定性。

由于FP16的数值范围较小,可能会导致梯度下溢出,所以GradScaler()在反向传播之前将梯度的值放大,然后在权重更新之后将放大的梯度缩放回来,在计算梯度后,使用scaler.step(optimizer)来应用缩放后的梯度,从而避免了数值下溢的问题。

torch.save(model.state_dict(), 'model.pth')
model.load_state_dict(torch.load('model.pth'))

在混合精度训练中,虽然模型的权重在训练过程中可能会被转换为 FP16 格式以节省内存和加速计算,但在保存模型时,我们通常会将权重转换回 FP32 格式。这是因为 FP32 提供了更高的数值精度和更广泛的硬件支持(FP16需要有Tensor Core的GPU),这使得模型在不同环境中的兼容性和可靠性更好。

混合精度训练有很多有意思的地方,目前使用动态混合精度的方法来充分利用GPU,以达到计算和内存的高效运行比是一个较为前沿的研究方向。

在这里插入图片描述

这篇关于代码与原理:混合精度训练详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1111706

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

hdu4407(容斥原理)

题意:给一串数字1,2,......n,两个操作:1、修改第k个数字,2、查询区间[l,r]中与n互质的数之和。 解题思路:咱一看,像线段树,但是如果用线段树做,那么每个区间一定要记录所有的素因子,这样会超内存。然后我就做不来了。后来看了题解,原来是用容斥原理来做的。还记得这道题目吗?求区间[1,r]中与p互质的数的个数,如果不会的话就先去做那题吧。现在这题是求区间[l,r]中与n互质的数的和

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

6.1.数据结构-c/c++堆详解下篇(堆排序,TopK问题)

上篇:6.1.数据结构-c/c++模拟实现堆上篇(向下,上调整算法,建堆,增删数据)-CSDN博客 本章重点 1.使用堆来完成堆排序 2.使用堆解决TopK问题 目录 一.堆排序 1.1 思路 1.2 代码 1.3 简单测试 二.TopK问题 2.1 思路(求最小): 2.2 C语言代码(手写堆) 2.3 C++代码(使用优先级队列 priority_queue)

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

K8S(Kubernetes)开源的容器编排平台安装步骤详解

K8S(Kubernetes)是一个开源的容器编排平台,用于自动化部署、扩展和管理容器化应用程序。以下是K8S容器编排平台的安装步骤、使用方式及特点的概述: 安装步骤: 安装Docker:K8S需要基于Docker来运行容器化应用程序。首先要在所有节点上安装Docker引擎。 安装Kubernetes Master:在集群中选择一台主机作为Master节点,安装K8S的控制平面组件,如AP

代码随想录冲冲冲 Day39 动态规划Part7

198. 打家劫舍 dp数组的意义是在第i位的时候偷的最大钱数是多少 如果nums的size为0 总价值当然就是0 如果nums的size为1 总价值是nums[0] 遍历顺序就是从小到大遍历 之后是递推公式 对于dp[i]的最大价值来说有两种可能 1.偷第i个 那么最大价值就是dp[i-2]+nums[i] 2.不偷第i个 那么价值就是dp[i-1] 之后取这两个的最大值就是d