量化训练之补偿STE:DSQ和QuantNoise

2023-11-23 18:50

本文主要是介绍量化训练之补偿STE:DSQ和QuantNoise,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

(本文首发于公众号,没事来逛逛)

今天讲一点量化训练中关于 STE (Straight Through Estimator) 的问题,同时介绍两种应对问题的方法:DSQ 和 QuantNoise。分别对应两篇论文:Differentiable Soft Quantization: Bridging Full-Precision and Low-Bit Neural Networks 和 Training with Quantization Noise for Extreme Model Compression。

阅读本文需要对量化训练的过程有基本了解,可以参考我之前的这篇文章。

STE的问题

在量化训练中,由于 round 函数的存在,我们无法正常求导,因此退而求其次,在反向传播的时候用 STE 跳过了这个函数。这个「跳过」,就是把 STE 的导数默认为 1。

但这种做法有个副作用,由于它无法反应真实的量化误差,所以,不管量化位数有多少 (8 比特、4 比特等等),导数都是一样的。

看下面这个例子:

class QuantConv(nn.Module):def __init__(self, conv_module, bits=8):super(QuantConv, self).__init__()self.conv_module = conv_moduleself.bits = bitsdef forward(self, x):scale, zero_point = calcScaleZeroPoint(self.conv_module.weight.data.min(), \self.conv_module.weight.data.max(), num_bits=self.bits)weight, bias = self.conv_module.weight, self.conv_module.bias# 对weight做伪量化,模拟量化误差quant_weight = dequantize_tensor(quantize_tensor(weight, scale, zero_point, self.bits), scale, zero_point)# detach这一步就是STEreturn F.conv2d(x, weight + (quant_weight - weight).detach(), bias, 3, 1)

我定义了一个量化的卷积 QuantConv,对 weight 做了伪量化,其中 calcScaleZeroPointquantize_tensordequantize_tensor 这几个函数的定义可以在之前的文章中找到。

然后,我们用不同的比特数来量化,看看在 BP 的时候,梯度有什么差别:

conv = nn.Conv2d(1, 1, 3, 1)
x = torch.randn((1, 1, 4, 4))  # 使用同一个输入
quantconv = QuantConv(conv)a = quantconv(x).sum().backward()   # BP计算梯度
print("use 8 bit")
print(quantconv.conv_module.weight.grad)quantconv.zero_grad()
quantconv.bits = 2
a = quantconv(x).sum().backward()    # BP计算梯度
print("use 2 bit")
print(quantconv.conv_module.weight.grad)

输出结果如下:

use 8 bit
tensor([[[[ 0.6101, -2.7252, -0.2428],[ 2.2399,  0.5673,  1.7511],[-0.5968,  1.2209,  0.6866]]]])
use 2 bit
tensor([[[[ 0.6101, -2.7252, -0.2428],[ 2.2399,  0.5673,  1.7511],[-0.5968,  1.2209,  0.6866]]]])

可以发现,对同一个输入,用同样的损失函数计算梯度,不同比特数量化得到的梯度是一样的!但不同比特数带来的量化误差明显有很大差异,This is unreasonable!

当然,这个例子的 loss 比较取巧,如果用其他 loss (比如交叉熵函数),可能梯度就不会一样了。但不管是哪种 loss,到 STE 这一步就仿佛一套组合拳打在棉花上,最重要的梯度信息都扔掉了。这里面的原因就在于 STE 根本无法体现量化的损失。在低比特量化的时候,这种副作用尤其明显 (所以 QAT 在低比特训练中尤其困难,模型权重根本训不动)。

DSQ

基本思想

为了解决这个问题,一个很直接的想法是用某个可导的函数来近似 round,从而避免使用 STE。

比如说,我们知道傅立叶级数可以近似任何周期函数:
请添加图片描述
(图片摘自:https://www.zhihu.com/search?q=%E5%82%85%E7%AB%8B%E5%8F%B6%E5%8F%98%E6%8D%A2%E4%B9%8B%E6%8E%90%E6%AD%BB%E6%95%99%E7%A8%8B&utm_content=search_suggestion&type=content)

如果把 round 当成一个周期函数,那我们就可以用傅立叶级数来逼近 round 了,而傅立叶级数是可以求导的。

或者,我们也可以对 round 函数进行泰勒展开,用多项式来近似。

又或者,我们知道神经网络本身可以模拟任何函数,因此甚至可以用一个神经网络来近似 round。

不过,以上这些想法都过于复杂,计算量巨大,操作起来比较困难。

而 DSQ 做的就是引入一个相对简单的函数来模拟 round,做到计算简单,同时尽可能逼近 round 函数。

这个函数是这样定义的:
ϕ ( x ) = s tanh ⁡ ( k ( x − m i ) ) , i f x ∈ P i (1) \phi(x)=s\tanh(k(x-m_i)), \quad if\quad x\in P_i \tag{1} ϕ(x)=stanh(k(xmi)),ifxPi(1)
其中 m i = l + ( i + 0.5 ) Δ m_i=l+(i+0.5)\Delta mi=l+(i+0.5)Δ s = 1 tanh ⁡ ( 0.5 k Δ ) s=\frac{1}{\tanh(0.5k\Delta)} s=tanh(0.5kΔ)1

这里面 l l l 是实数域上的最小值, Δ \Delta Δ 是量化对应的每段间隔长度, P i P_i Pi 对应第 i i i 个间隔。

这个函数比较复杂,大家不用过于纠结这里面的细节,你只需要知道它长这个样子就可以了:
请添加图片描述
这是用 1 比特分别量化 [0, 1] 和 [-1, 1] 这两个区间时得到的函数图像。由于只使用了 1 个 bit,所以 Δ \Delta Δ 就是整个区间的长度。

当然,也可以用更多的比特进行量化 (比如 2 bit):
请添加图片描述
有读者可能发现,这个 ϕ ( x ) \phi(x) ϕ(x) 函数和 tanh 有点像,它会把每个间隔内的数值映射到 [-1, 1] 的范围内。我们可以再看看不同 k 对 ϕ ( x ) \phi(x) ϕ(x) 的影响:

请添加图片描述
请添加图片描述

结论就是:k 越大,这个函数和 round 越接近。

有了这个函数后,论文提出了一个 soft 的伪量化方式:
KaTeX parse error: Undefined control sequence: \notag at position 37: … l & x < l \\ \̲n̲o̲t̲a̲g̲ ̲u & x >u \\ \ta…
同样地,我们看看这个函数和普通伪量化的差别:
请添加图片描述
请添加图片描述
上面分别是用 1 比特和 2 比特量化的结果,绿线是普通 round 函数的伪量化,红线则是 DSQ 的伪量化。

随着 k 增大,DSQ 和一般的伪量化越来越接近,而 DSQ 由于可导,还能近似模拟 round 的梯度。因此,在量化训练的时候,我们可以直接把伪量化换成 DSQ 函数。

不过,虽然 DSQ 能近似 round 的伪量化,但没法百分百一样,因此,需要用一些措施让网络在训练的时候可以感知到这部分误差,并最终让这部分误差尽可能小,这样,DSQ 才能成为真正可导的伪量化函数。

为此,论文的做法是引入一个 α \alpha α 来衡量 DSQ 和 round 函数之间的误差:
α = 1 − tanh ⁡ ( 0.5 k Δ ) = 1 − 1 s (3) \alpha=1-\tanh(0.5k\Delta)=1-\frac{1}{s} \tag{3} α=1tanh(0.5kΔ)=1s1(3)
并由此推出:
k = 1 Δ l o g ( 2 α − 1 ) (4) k=\frac{1}{\Delta}log(\frac{2}{\alpha}-1) \tag{4} k=Δ1log(α21)(4)
公式 (3)(4) 我冥思苦想了很久,始终想不透是怎么推出来的。后来询问了论文作者昊哥,结果昊哥说,DSQ 比较水,让我不要花太多时间,量化训练直接用 LSQ 算法就可以 (意思就是 DSQ 效果可能并没有那么优秀。。。囧)。因此这部分我就没再花精力去推导了,有看懂的小伙伴还请不吝赐教。

总之,有了 α \alpha α 后,我们就可以度量 DSQ 和真正量化引起的误差了,只要让 α \alpha α 越小,DSQ 就越准确。因此,论文干脆用 α \alpha α 来重新表示 DSQ,把 (1)(3) 结合一下就可以得到:
ϕ ( x ) = 1 1 − α tanh ⁡ ( k ( x − m i ) ) , i f x ∈ P i (5) \phi(x)=\frac{1}{1-\alpha}\tanh(k(x-m_i)), \quad if\quad x\in P_i \tag{5} ϕ(x)=1α1tanh(k(xmi)),ifxPi(5)
再把 (5) 代入 (2) 后得到另一种形式的 DSQ,此时的 DSQ 中就只有 α \alpha α 这个变量了。

然后,在训练网络的时候,除了原本的损失函数,我们还需要对 α \alpha α 施加约束,让 α \alpha α 越小越好:
min ⁡ α L ( α ; x , y ) s . t . ∣ ∣ α ∣ ∣ 2 < λ (6) \underset{\alpha} {\operatorname {min}} L(\alpha;x, y) \tag{6} \\ s.t. ||\alpha||_2<\lambda αminL(α;x,y)s.t.α2<λ(6)
整个训练过程中,DSQ 随着 α \alpha α 越来越小,和真正的量化函数会越来越接近,同时,由于 DSQ 本身可以求导,我们也可以近似地求出量化函数的梯度,并作用到网络参数上。

此外,论文还对 clip 的边界也进行学习,不过这里只介绍 DSQ 的核心思想,就不详细讲解了。

训练完成后,我们可以按照普通量化的方式做量化推理即可。

代码实现

论文作者也开源了这个算法,具体链接请参考https://github.com/TheGreatCold/MQBench/blob/master/mqbench/fake_quantize/dsq.py。

这里多提一句,作者昊哥曾经开发了公司内部第一代量化训练框架 dirichlet,我之前也是通过阅读他们的代码学习到网络量化是怎么回事。后来由于 pytorch 发布了 FX,现在他们基于此开发了第二套工具,所幸的是这套工具是开源的,因此之后想针对这套新工具,讲一下一个量化训练框架该如何搭建。美中不足的是,这套工具需要 pytorch1.8 才能使用,而据我所知,很多小伙伴因为设备的原因,很难更新到这一版本。。。

说回原文,我们来看 DSQ 的核心实现:

def dsq_function_per_tensor(x, scale, zero_point, quant_min, quant_max, alpha):tanh_scale = 1 / (1 - alpha)tanh_k = math.log((tanh_scale + 1) / (tanh_scale - 1))x = x / scale + zero_pointx = torch.clamp(x, quant_min, quant_max)x = x.floor() + (tanh_scale * torch.tanh(tanh_k * (x - x.floor() - 0.5))) * 0.5 + 0.5x = (x.round() - x).detach() + x  # detach模拟STEx = (x - zero_point) * scalereturn x

核心代码只有寥寥几句,对应的是上文 DSQ 的公式 (2)。

其中,最核心的是这句代码:

x = x.floor() + (tanh_scale * torch.tanh(tanh_k * (x - x.floor() - 0.5))) * 0.5 + 0.5

这一步就是在计算 ϕ ( x ) \phi(x) ϕ(x) D S ( x ) D_S(x) DS(x)。不过论文里的公式是在浮点域上处理的,而代码里是先转换到整型域再处理。大家也不要太纠结为什么代码可以这样处理,这篇论文在公式表达上并不是很清晰,我们学习它的思想就可以,至于具体方法和细节,昊哥说了,不建议花太多时间 (此处甩锅。。)

另外,代码有一个 detach 来模仿 STE 的操作,有人可能要说了,不是说好 DSQ 可以求导吗,怎么又要用 STE?

我们来看下普通伪量化和 DSQ 的差别:
请添加图片描述
在普通的伪量化中,我们先经过线性量化后,再 round 变换到绿线:

x=round(clip(x/S-Z, q_min, q_max))

而在 DSQ 中,我们则是先线性量化加 DSQ 后,再 round 变换到绿线 (由于 DSQ 不可能和实际的量化函数一致,因此我们还是需要加上 round 操作保证最终结果是一致的):

x=round(clip(DSQ(x/S-Z), q_min, q_max))

前者在求导时,round 引起的巨大误差直接被跳过了,而后者由于有 DSQ 的存在,我们在 round 前就已经非常接近量化函数 (绿线) 的位置了,而 DSQ 是可导的,因此求出的导数更接近 round 的误差,这样网络学习起来就更准确了。

QuantNoise

前面说了一大堆,可以看出 DSQ 为了近似 round 函数,还是引入了很多复杂的操作。而 QuantNoise 就简单了,它用了另一种取巧的方式来弥补 STE 带来的损失。

既然 STE 无法反应量化函数的导数,那我们就在量化的时候,不要把所有参数都量化,而是随机量化一部分,另一部分还是保持全精度,这样在做伪量化的时候,信息损失不至于太大,反向传播的时候,部分权重也能以正常求导的方式适应其他量化权重引起的损失。

这个过程和 Dropout,以及之前的论文 Incremental Network 非常相似:
请添加图片描述
代码实现上也比 DSQ 简单得多:

noise = (quantize(w) - w) * mask
w_q = w + noise.detach()

和普通 QAT 相比,这里只是多了一个 mask。

论文给出的实验证明,在低比特量化的时候,这种方式效果很明显。不过遗憾的是,我自己在去噪、GAN 等任务中并没有发现这种方法有什么提升~囧~

总结

这篇文章主要介绍了量化训练中,STE 对训练本身带来的问题,并介绍了两种解决问题的思路:DSQ 和 QuantNoise。

其中 DSQ 从问题本质 (round 函数不可导) 出发,引入一个可导的函数来近似 round。而 QuantNoise 虽然没有直面 STE 的问题,但用一种取巧的方式在 round 函数上「钻了个洞」,让一部分权重可以无损地通过,用这部分权重弥补 STE 带来的梯度损失。

从论文标题也可以看出,这两种方法主要针对低比特网络。因为 STE 在低比特训练时,副作用尤其明显,毕竟比特数越低,round 带来的量化损失越明显。

当然,大家不必把这两篇论文奉为圭臬,介绍它们只是给大家提供一些思路,在某些任务中,论文本身的方法未必奏效。但这些思路可以打开我们的视野,兴许哪天你受启发就找到更优秀的方法了。

参考

  • Differentiable Soft Quantization: Bridging Full-Precision and Low-Bit Neural Networks
  • Training with Quantization Noise for Extreme Model Compression
  • INCREMENTAL NETWORK QUANTIZATION: TOWARDS LOSSLESS CNNS WITH LOW-PRECISION WEIGHTS
  • https://www.yuque.com/yahei/hey-yahei/quantization-retrain_improved_qat
  • https://github.com/TheGreatCold/MQBench/tree/master/mqbench

欢迎关注我的公众号:大白话AI,立志用大白话讲懂AI。

这篇关于量化训练之补偿STE:DSQ和QuantNoise的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/420024

相关文章

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering)

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering) Power Iteration Clustering (PIC) 是一种基于图的聚类算法,用于在大规模数据集上进行高效的社区检测。PIC 算法的核心思想是通过迭代图的幂运算来发现数据中的潜在簇。该算法适用于处理大规模图数据,特别是在社交网络分析、推荐系统和生物信息学等领域具有广泛应用。Spa

SigLIP——采用sigmoid损失的图文预训练方式

SigLIP——采用sigmoid损失的图文预训练方式 FesianXu 20240825 at Wechat Search Team 前言 CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录 在深度学习项目中,目标检测是一项重要的任务。本文将详细介绍如何使用Detectron2进行目标检测模型的复现训练,涵盖训练数据准备、训练命令、训练日志分析、训练指标以及训练输出目录的各个文件及其作用。特别地,我们将演示在训练过程中出现中断后,如何使用 resume 功能继续训练,并将我们复现的模型与Model Zoo中的

多云架构下大模型训练的存储稳定性探索

一、多云架构与大模型训练的融合 (一)多云架构的优势与挑战 多云架构为大模型训练带来了诸多优势。首先,资源灵活性显著提高,不同的云平台可以提供不同类型的计算资源和存储服务,满足大模型训练在不同阶段的需求。例如,某些云平台可能在 GPU 计算资源上具有优势,而另一些则在存储成本或性能上表现出色,企业可以根据实际情况进行选择和组合。其次,扩展性得以增强,当大模型的规模不断扩大时,单一云平

神经网络训练不起来怎么办(零)| General Guidance

摘要:模型性能不理想时,如何判断 Model Bias, Optimization, Overfitting 等问题,并以此着手优化模型。在这个分析过程中,我们可以对Function Set,模型弹性有直观的理解。关键词:模型性能,Model Bias, Optimization, Overfitting。 零,领域背景 如果我们的模型表现较差,那么我们往往需要根据 Training l

如何创建训练数据集

在 HuggingFace 上创建数据集非常方便,创建完成之后,通过 API 可以方便的下载并使用数据集,在 Google Colab 上进行模型调优,下载数据集速度非常快,本文通过 Dataset 库创建一个简单的训练数据集。 首先安装数据集依赖 HuggingFace datasetshuggingface_hub 创建数据集 替换为自己的 HuggingFace API key

【YOLO 系列】基于YOLOV8的智能花卉分类检测系统【python源码+Pyqt5界面+数据集+训练代码】

前言: 花朵作为自然界中的重要组成部分,不仅在生态学上具有重要意义,也在园艺、农业以及艺术领域中占有一席之地。随着图像识别技术的发展,自动化的花朵分类对于植物研究、生物多样性保护以及园艺爱好者来说变得越发重要。为了提高花朵分类的效率和准确性,我们启动了基于YOLO V8的花朵分类智能识别系统项目。该项目利用深度学习技术,通过分析花朵图像,自动识别并分类不同种类的花朵,为用户提供一个高效的花朵识别

深度学习与大模型第3课:线性回归模型的构建与训练

文章目录 使用Python实现线性回归:从基础到scikit-learn1. 环境准备2. 数据准备和可视化3. 使用numpy实现线性回归4. 使用模型进行预测5. 可视化预测结果6. 使用scikit-learn实现线性回归7. 梯度下降法8. 随机梯度下降和小批量梯度下降9. 比较不同的梯度下降方法总结 使用Python实现线性回归:从基础到scikit-learn 线性

使用openpose caffe源码框架训练车辆模型常见错误及解决办法

错误1:what():  Error: mSources.size() != mProbabilities.size() at 51, OPDataLayer, src/caffe/openpose/layers/oPDataLayer.cpp 原因:这是因为在网络模型中数据源sources和probabilities设置的参数个数不一样导致的,一个数据源对应一个概率 解决方法:只需要将网络文