【深度学习】谷歌用「钞」能力放大招:扩展到220亿参数的巨大视觉 Transformer...

本文主要是介绍【深度学习】谷歌用「钞」能力放大招:扩展到220亿参数的巨大视觉 Transformer...,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

作者丨科技猛兽    编辑丨极市平台

导读

 

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT- 22B,具有220亿参数。并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。 

本文目录

52 扩展到220亿参数的巨大视觉 Transformer
(来自谷歌,含 ViT 作者)
52 ViT-22B 论文解读
52.1 背景和动机
52.2 三句话概括 ViT-22B 模型的架构
52.3 ViT-22B 的实现方法:异步并行线性操作计算
52.4 数据集和超参
52.5 图像分类迁移性能
52.6 密集预测性能
52.7 ViT-22B 与人类感知的一致性

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT- 22B,具有220亿参数。并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。

52 扩展到220亿参数的巨大视觉 Transformer

论文名称:Scaling Vision Transformers to 22 Billion Parameters

论文地址:

https://arxiv.org/pdf/2302.05442.pdf

52.1 背景和动机

与自然语言处理类似,视觉预训练大模型也提高了在各种视觉任务的性能。这不仅依赖于更大的数据集,更大的可扩展的视觉架构,也同样依赖于新的训练策略。

但尽管如此,视觉模型从规模和效果上而言,远远落后于语言模型。迄今为止最大的密集视觉模型只有 4B 参数 ViT[1],而入门级的语言模型通常包含超过 10B [2][3]个参数以上,最大的密集语言模型有 540B 个参数。稀疏模型也展示出同样的趋势,其中语言模型的参数超过了一万亿[4],但最大的稀疏视觉模型只有约 15B[5]。

本文提出了迄今为止最大的密集视觉 ViT 模型 ViT-22B,并发现超大 ViT 病态训练的不稳定性,这种不稳定性组织了模型尺度的进一步扩展。作者通过仔细设计模型,以较高的效率实现模型并行训练。ViT-22B 的质量通过分类和下游任务实验进行评估,在这些任务中它达到或提高了当前的最先进水平。

通过多模态训练一个 text tower 来匹配视觉特征,ViT- 22B 在 ImageNet 上实现了 85.9% 的 zero-shot 精度。此外,该模型是一个很好的老师——用作蒸馏目标,作者训练了一个 ViT- B 学生模型,其在 ImageNet 上达到了88.6% 的 SOTA 精度。

除了分布的改进、ViT-22 的可靠性、不确定性估计和公平性都取得了提升。更重要的是,ViT-22 的特征更好地与人类的感知保持一致,实现了之前未见过的 87% 的形状偏差 (shape bias)。

52.2 三句话概括 ViT-22B 模型的架构

ViT-22B 是一种基于 Transformer 的模型,其架构的设计类似于原始 ViT,但包含以下3个修改,以提高效率和大规模训练的稳定性。如下图1所示,用三句话概括分别是:

  • 并行设计

  • Query 和 Key 的归一化

  • 省略 bias

1fe4f97035af73a9ad9bb97fb2398297.jpeg
图1:ViT-22B 模型的架构

第1句话,把 ViT 中的 Self-Attention 层换成了一个 Self-Attention 层加一个 MLP 层的并行设计,公式如下:

bff602dead9232eacedb264a25e03f0a.png

注意这里作者不是简单地将两个模块相加,而是使用了一个并行化技巧,即:用于 Self-attention 中的 Query,Key,Value 计算的矩阵乘法和 MLP 的第1个线性层被融合到一个单独的操作中;用于 Self-attention 中的输出投影和 MLP 的第2个线性层也被融合到一个单独的操作中。这种方法最初是由 PaLM[6] 提出的,该技术在不降低性能的情况下将最大模型的训练速度提高了 15%。

第2句话,Self-Attention 中 Query 和 Key 的计算过程添加了归一化。在以往 ViT 扩展的工作中,作者观察到了在训练了几千步之后,training loss 发散了的情况,这毫无疑问使得我们无法训练大型 ViT 模型,尤其是在 8B 参数左右的模型中观察到这种不稳定性。

出现这种现象的原因是:注意力矩阵的值异常 (近似于 one-hot,有的地方注意力很大,其他位置几乎为零) 引起的,这导致注意力权重的熵接近于零。为了解决这个问题,作者采用[7]的方法,将 LayerNorm 应用于 Self-Attention 中 Query 和 Key 的计算过程。具体可以写成:

式中, 是 query/key 的维度, 是 layer normalization, 分别是 Query 和 Key 的权重矩阵。对 8B 参数模型的影响如图2所示,其中归一化防止了注意力矩阵的值不受控的异常而导致的训练发散。

cd11720e3929ecb94953d707f8d9b1da.jpeg
图2:Self-Attention 中 Query 和 Key 的计算过程添加归一化对 8B 参数模型的影响

第3句话,省略 bias。遵循 PaLM 的做法,从 QKV 投影中去除 bias,并使用没有 bias 项和 centering 的 Layer Normalization[8]。这在不影响训练质量的前提下提升了训练速度。但是,与 PaLM 不同的是,作者对所有 MLP 层使用了 bias 项,因为观察到质量得到了改善,且训练速度没有下降。

ViT-22B 使用 14×14 大小的 Patch,输入图像分辨率为 224×224。类似于原始 ViT,ViT-22B 也采用了可学习的位置编码。在对高分辨率图像 (不同数量的 Patch) 进行微调期间,也对预训练的位置编码执行二维插值。ViT-22B和 ViT-G,ViT-e 的超参数对比如下图3所示。

fe946626760f1004352d815e85133ab5.png
图3:ViT-22B和 ViT-G,ViT-e 的超参数对比

ViT-22B 的 model card 如下图所示。

a90fd229c32a6164241a16bb9215fc4f.jpeg a77f8182821462bc2d0b05d7ab6191f2.jpeg
图4:ViT-22B 的 model card

52.3 ViT-22B 的实现方法:异步并行线性操作计算

ViT-22B 基于 JAX 框架和 FLAX,Scenic 库,它同时利用了模型和数据的并行性。作者使用了 jax.xmap 这个 API,其对所有中间体的分片 (例如权重和激活) 以及芯片间通信提供了明确的控制。

这个高效的实现如下图5和6所示。该怎么理解这个过程呢?

f40245fea3eb817e0789598a9dc4b7ee.jpeg
图5:矩阵 A 在不同设备之间按行切分
f316f936dbc8e1f85bee6b17ff355694.jpeg
图6:矩阵 A 在不同设备之间按列切分

为了说明这个过程, 这一段我们介绍下它的原理。考虑计算 这个矩阵乘法运算。

比如说我们有 台设备, 设 分别是 的第 个 Block , 那么 分别放在第 台设备上, 如上图5和图6所示。设 。那么现在等于说我们有 个矩阵 的分块。要把这么多块均匀地放在 台 devices 上面, 就 有两种放法。为了方便说明, 这里我们假设 吧。那么现在矩阵 就被分成了 的分块。

  • 第1种: 的每一行的块放在同一台 device 上, 即: 放在同一台上面, 对应图5。

  • 第2种: 的每一列的块放在同一台 device 上, 即: 放在同一台上面, 对应图6。

对于第1种情况, 计算 需要 次 的通信, 因为 的维度是 , 所以就需要 float 的通信。

对于第2种情况, 计算 需要 次 中间值的通信, 因为 中间值的维度是 , 所以就 需要 float 的通信。

当 时, 采用哪一种计算方式都可以。当 时, 比如 MLP 的输出层有 , 那么这个时候采用第2种情况的计算方法。

注意, 以上通信的过程与计算过程是异步的。即, 当计算一层时, 设备可以开始通信下一层的权重, 从而最大限度地减少通信开销。

作者将芯片组织成大小为 的 逻辑网格,其中 是数据轴的大小, 是模型轴的大小。然后, 对于 组中的每个组, 个设备获得相同 Batch 的图像, 每个设备只保留 的激活值, 并负责计算输出 的 。

52.4 数据集和超参

ViT-22B 在 JFT的一个版本上训练,训练集包含大约 4B 图像,并采用 Sigmoid 交叉熵损失以多标签分类方式使用所有分配的标签。

ViT-22B 把图片分为 14×14 大小的块,使用 65k 的 Batch Size 训练 177k steps (3 Epochs),初始学习率,reciprocal square-root learning rate schedule,以及 10k 的线性学习率 warmup,和 30k 的的线性学习率 cooldown。上游预训练的权重衰减 head 设为3.0,body 设为 0.03。

52.5 图像分类迁移性能

Linear Probing 实验结果

作者在 ImageNet 上使用了10个周期的动量 SGD,分辨率为 224px,使用 mild random cropping 和 horizontal flipping 作为唯一的数据增强策略,而没有进一步的正则化措施。

如下图7所示是 ViT-22B 的 Linear Probing 实验结果,虽然收益不高,但在这个尺度上仍有显著的改善。而且图7还说明,,像 ViT-22B 这样的大模型的 Linear Probing 可以接近或超过具有小模型的高分辨率完全微调的性能,而 Linear Probing 通常成本更低。

4ebefc3da9472975e0f569437810ec10.jpeg
图7:ImageNet 上 ViT-22B 的 Linear Probing 实验结果

作者进一步在细粒度分类数据集 iNaturalist 2017 上测试线性可分性。iNaturalist 2017 有5,089个细粒度类别,属于13个大类。与 ImageNet 不同,不同类别的图像数量是不平衡的。概念的长尾分布对分类更具挑战性。作者将 ViT- 22B 与其他 ViT 变体进行比较,并测试了 224px 和 384px 的输入分辨率。结果如图8所示,可以观察到 ViT- 22B 明显优于其他变体,特别是在标准的 224px 输入分辨率下,这表明 ViT-22B 中大量的参数对于从图像中提取详细信息是有用的。

43009eedbf586cbf7463bee80efef8a7.png
图8:iNaturalist 2017 上 ViT-22B 的 Linear Probing 实验结果

Out-of-distribution 实验结果

定义: 在分类任务中,给定测试图片,若模型在训练阶段模型见过或类似的图片,则能正确分类;但如果与训练集完全不相关,也会被强制判定为训练集类别中的一种,这种情况是不合理的。OOD 算法希望能判断的分布状况是否与训练集一致,若一致,则称为 in-distribution (ID),否则称为 out-of-distribution (OOD)。
使用场景举例: 在 MNIST 上训练的一个分类模型,然后,输入一张“马”的图片,会被归类为数字 0~9,这是错误的。此时,MNIST 数据集就是 in-distribution,相对于 ID 而言,“马”是 out-of-distribution。

作者构建了从 JFT 到 ImageNet 的标签映射,以及从 ImageNet 到不同分布外数据集的标签映射,即 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A。ImageNet-R 和 ImageNet-A 使用相同的 ImageNet 的200个标签子空间,而 ObjectNet 有313个类别,其中只考虑与 ImageNet 标签空间重叠的113个类别。那么这样以后,JFT 上面训练出的模型的预测结果就可以转化到其他数据集上面了,这就为模型的 Out-of-distribution 能力提供了一种验证手段。

Out-of-distribution 的评估一般是首先在一个较大的数据集上面 (比如 ImageNet) 做预训练,然后在 ImageNet-R ,ImageNet-A 等数据集上直接评估其性能。实验结果如下图9所示。作者做了两种实验,其一 (图9上半部分) 是首先将 ViT-22B 模型在 JFT 数据集上做预训练,然后在 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A 数据集上评估性能。其二 (图9下半部分) 是首先将 ViT-22B 模型在 JFT 数据集上做预训练,然后在 ImageNet 数据集上微调,最后在 ObjectNet ,ImageNet-v2,ImageNet-R 和 ImageNet-A 数据集上评估性能。

933090e0d9dd9eef05ce572f3b79f74a.jpeg
图9:OOD 实验结果。标注 "ema" 的模型使用 Polyak 平均进行微调

把图9中 ObjectNet 的实验结果画出来就是图10.

ce28d63311a6fcc400d1b2e2ba54bbf0.jpeg
图10:OOD 中 ObjectNet 实验结果

从图9和10中可以得出结论:把模型做大可以增加 Out-of-distribution 的性能。这适用于只看过 JFT 图像的 ViT-22B 模型,以及在 ImageNet 上做了微调过之后的模型。在这两种情况下,ViT-22B 在更大的模型上都延续了 OOD 性能更好的趋势。即使 ImageNet 的性能饱和,但从图10中也可以看到 ObjectNet 上的精度从 ViT-e/14 到 ViT-22B 的显著提升。

52.6 密集预测性能

密集预测任务的迁移性能也是评价一个 Backbone 模型的关键因素。作者通过语义分割和单目估计任务评价 ViT-22B 捕获的几何和空间信息的质量。

语义分割任务使用 ViT-22B 作为骨干模型,UperNet 作为分类头,在 ADE20K, Pascal Context 和 Pascal VOC 三个数据集上进行评测。如下图11所示,作者将 ViT-22B 与 DeiT-III ,ViT-G 进行了比较,且只使用了一部分的训练数据。作者使用线性解码器和端到端微调。从图11中可以观察到,当只使用少量的数据时,ViT-22B 骨干更好。例如,当只对1200张图像 (即1/16) 的 ADE20K 训练数据进行微调时,ViT-22B 达到了 44.7 mIoU 的性能,比 DeiT-III Large 提高了8.6 mIoU,比 ViT-G 提高了 2.3 mIoU。当数据量较大时,ViT-G 和 ViT-22B 的性能趋于一致。

50dbbec8b5f02ebf709818659c202357.jpeg
图11:ADE20K Fewshot 语义分割实验结果

单目估计任务使用 ViT-22B 作为骨干模型,Dense Prediction Transformer (DPT) 作为单目估计头,或者仅仅使用一个简单的线性解码器作为估计头。实验在 Waymo Open real-world driving 数据集上进行评测。如下图12所示,上半部分 (DPT解码器) 的结果可以观察到,与不同的主干相比,使用 ViT-22B 骨干网络得到了最好的性能。通过将 ViT-22B 骨干与 ViT-e 进行比较,发现扩展架构可以提高性能。使用线性解码器,可再次观察到使用 ViT-22B 骨干模型可获得最佳性能。DPT 和线性解码器之间的差距表明,虽然在 ViT 特征中保留了足够的几何信息,但只有一部分可被普通的线性解码器利用到。

b78f8372ff8fd5394259be5e31b201f2.jpeg
图12:Waymo Open dataset 单目估计实验结果

在研究规模的影响时,除了下游任务性能之外,还有一些重要的方面需要考虑。比如一个主干网络的与人类感知的一致性。

52.7 ViT-22B 与人类感知的一致性

ViT-22B 分类决策与人类分类决策的一致性如何?通过这篇论文 "Partial success in closing the gap between human and machine vision" 的方法,作者评估了在 ImageNet 上以不同分辨率 (224,384,560) 微调的三个 ViT-22B 模型。在所有指标中,ViT-22B-224 具有最高的 OOD 稳健性 (图13 (a)), ViT-22B -384 与人类分类精度最接近(图13(b)), ViT-22B-560 具有最大的错误一致性 (即大多数类似人类的错误模式,图13(d))。ViT-22B 模型在视觉模型中有最高的 shape bias 记录:而大多数模型都有很强的 texture bias (20-30% 的 shape bias + 70-80% 的 texture bias),人类的 shape bias 为 96% + 4%,而 ViT-22B -384 达到了前所未见的 87% shape bias + 13% texture bias。总体而言,ViT-22B显著改善了人类视觉物体识别的对齐。

225f9b2e809076685fd5ee60e79a3514.jpeg
图13:ViT-22B 与人类感知的一致性

除此之外,作者还对 ViT-22B 的公平性,鲁棒性,可靠性和校准。进行了相关实验,详细细节读者可以参考原始论文。作者发现,随着模型尺寸的增加,会出现有利的特性。

总结

本文提出了 ViT-22B,目前最大的视觉 Transformer 模型,有220亿个参数。作者证明,通过对原始架构进行三点修改,可以实现出色的硬件利用率训练稳定性,从而产生一个在几个基准上实现 SOTA 的模型。作者的评估进一步表明,与现有模型相比,ViT-22B 在形状和纹理偏差方面更符合人类,并在公平性和稳健性方面具有优势。

参考:

基于深度模型Out of Distribution(OOD)基础技术路线研究:https://blog.csdn.net/Aqrose_666/article/details/124592372

参考

  1. ^PaLI: A Jointly-Scaled Multilingual Language-Image Model

  2. ^Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer

  3. ^Unifying language learning paradigms

  4. ^Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity

  5. ^Scaling Vision with Sparse Mixture of Experts

  6. ^PaLM: Scaling language modeling with pathways

  7. ^Intriguing Properties of Transformer Training Instabilities

  8. ^Root Mean Square Layer Normalization

 

648bb736de838276eebc19166f5cba6a.jpeg

 
 
 
 
往期精彩回顾适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码

这篇关于【深度学习】谷歌用「钞」能力放大招:扩展到220亿参数的巨大视觉 Transformer...的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/308771

相关文章

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

详解Spring Boot接收参数的19种方式

《详解SpringBoot接收参数的19种方式》SpringBoot提供了多种注解来接收不同类型的参数,本文给大家介绍SpringBoot接收参数的19种方式,感兴趣的朋友跟随小编一起看看吧... 目录SpringBoot接受参数相关@PathVariable注解@RequestHeader注解@Reque

Java向kettle8.0传递参数的方式总结

《Java向kettle8.0传递参数的方式总结》介绍了如何在Kettle中传递参数到转换和作业中,包括设置全局properties、使用TransMeta和JobMeta的parameterValu... 目录1.传递参数到转换中2.传递参数到作业中总结1.传递参数到转换中1.1. 通过设置Trans的

java如何调用kettle设置变量和参数

《java如何调用kettle设置变量和参数》文章简要介绍了如何在Java中调用Kettle,并重点讨论了变量和参数的区别,以及在Java代码中如何正确设置和使用这些变量,避免覆盖Kettle中已设置... 目录Java调用kettle设置变量和参数java代码中变量会覆盖kettle里面设置的变量总结ja

spring 参数校验Validation示例详解

《spring参数校验Validation示例详解》Spring提供了Validation工具类来实现对客户端传来的请求参数的有效校验,本文给大家介绍spring参数校验Validation示例详... 目录前言一、Validation常见的校验注解二、Validation的简单应用三、分组校验四、自定义校

SpringBoot中Get请求和POST请求接收参数示例详解

《SpringBoot中Get请求和POST请求接收参数示例详解》文章详细介绍了SpringBoot中Get请求和POST请求的参数接收方式,包括方法形参接收参数、实体类接收参数、HttpServle... 目录1、Get请求1.1 方法形参接收参数 这种方式一般适用参数比较少的情况,并且前后端参数名称必须

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;