对比自监督学习浪潮迅猛来袭,你准备好了吗?

2024-04-13 21:48

本文主要是介绍对比自监督学习浪潮迅猛来袭,你准备好了吗?,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

来源:AI科技评论

本文约5800字,建议阅读10分钟

PyTorch Lightning 创始人的对比学习综述,对自监督学习、对比学习等进行了简要回顾。


近年来,自监督学习逐渐成为了备受人们关注的应对标注缺乏问题的热门解决方案,科研人员在基于对比学习的自监督方法方面进行了大量研究。本文是 PyTorch Lightning 创始人 William Falcon 的对比学习综述,对自监督学习、对比学习等基础概念进行了简要的回顾,介绍了 CPC、AMDIM、BYOL、SimCLR、Swav等最近较为著名的对比学习方法,并提出了一种名为 YADIM 的新型对比学习算法。

图 1:对比学习

本文介绍了论文"A Framework For Contrastive Self-SupervisedLearning And Designing A New Approach "中的主要内容。

论文地址:

https://arxiv.org/abs/2009.00104

在过去的一年中,一类“新颖”的自监督学习(AMDIM、CPC、SimCLR、BYOL、Swav等)算法在人工智能研究领域取得了多项目前最优的结果。

在我们近期发表的论文“A Framework For ContrastiveSelf-Supervised Learning And Designing A New Approach”中,我们为描述对比自监督学方法形式化定义了一个概念框架。并使用该框架分析了三种对比学习的示例:SimCLR、CPC、AMDIM,表明:尽管这些方法似乎在表面上看起来各不相同,但事实上它们都只是对彼此做出了细微的调整。

在本文中,我们将:

  • 回顾自监督学习; 

  • 回顾对比学习;

  • 提出一种比较近期各种对比学习方法的框架;

  • 使用我们的框架比较CPC、AMDIM、MOCO、SimCLR、BYOL;

  • 使用我们的框架形式化定义了一种新的方法“YADIM”;

  • 描述了一些我们的实验结果;

  • 描述了取得这些实验结果的计算要求。

1、实现

读者可以通过以下链接获取使用 PyTorch LIghtning框架编写的所有文中介绍的方法,从而在任意的硬件设备上训练这些算法,并且更加容易进行对比。

  • AMDIM:

    https://pytorchlightningbolts.readthedocs.io/en/latest/self_supervised_models.html#amdim

  • BYOL:

    https://pytorchlightningbolts.readthedocs.io/en/latest/self_supervised_models.html#byol

  • CPC V2:

    https://pytorchlightningbolts.readthedocs.io/en/latest/self_supervised_models.html#cpc-v2

  • Moco V2:

    https://pytorchlightningbolts.readthedocs.io/en/latest/self_supervised_models.html#moco-v2

  • SimCLR:

    https://pytorchlightningbolts.readthedocs.io/en/latest/self_supervised_models.html#simclr

2、自监督学习

回想一下,在监督学习任务中,我们会为系统给定一个输入 x 和一个相应的标签 y。

图 2:监督式学习——图左侧为输入图像,右侧为标签。

在自监督学习任务中,我们仅仅为系统给定输入 x,而不给定标签 y,系统需要“学会根据输入的某部分来预测输入中的其它部分”。

图 3:在自监督学习中,输入数据被同时用作源和目标。

事实上,这种形式化定义是非常通用的,你可以创造性地对输入进行“分割”。这种策略被称为“前置任务”(又称“代理任务”),研究者们已经尝试了各种各样的方法。在这里,我们给出三种示例:

  • 预测两个图块之间的相对位置

    (https://arxiv.org/abs/1505.05192)

  • 解决一个拼图问题

    (https://arxiv.org/abs/1603.09246)

  • 对某张图像进行着色

    (https://richzhang.github.io/colorization/)

图 4:前置任务示例

尽管上述方法极具创造性,但是它们实际上效果并不理想。然而,最近一系列使用“对比学习”的方法已经开始显著地缩小在 ImageNet 数据集上与监督式学习之间的性能差距。

 

图 5:最新的方法(Swav)正在缩小与在 ImageNet 上训练的监督式方法的差距。

3对比学习

大多数机器学习算法背后的基本思想是,相似的样本应该被划分到一起,而与其它相关示例的聚类簇相距较远。

Chopra 等人于 2014 年发表的最早的有关对比学习的工作“Learning a SimilarityMetric Discriminatively, with Application to Face Verification”正是基于这一思想构建的,主要思想的示意图如下:

图 6:对比学习示例

对比学习通过使用三个关键的元素(正样本、anchor、负样本的表征)来实现上述思想。为了创建一个正样本对,我们需要两个相似的样本,而当我们创建一个负样本对时,我们将使用第三个与两个正样本不相似的样本。

 

图 7:正负样本与 Anchor

然而,在自监督学习任务中,我们并不知道每个样本的标签。因此,我们也无从知晓两张图像是否相似。

尽管如此,如果我们假设每张图片都从属于它自身的一个独有的类别,那么我们就可以提出各种构造这类三元组的方法(正负样本对)。这意味着,在一个包含 N 个样本的数据集中,我们现在拥有了 N 个标签!

图 8:为每一个样本赋予一个独特的类别

当我们知道了每一张图像的标签(类别)后,就可以使用数据增强技术来生成这些三元组。

4特性1:数据增强过程

首先,我们可以通过定义一个数据增强过程来描述一种对比式自监督学习方法。

一个数据增强过程 A(x) 对同一个输入应用一系列随机变换。

图 9:应用于某一输入的随机数据增强过程

在深度学习场景下,数据增强旨在构建对于原始输入中的噪声具有不变性的表征。例如,即使图 9 中的猪被旋转了、或者颜色消失了、甚至是像素被“抖动”了,网络还是能将其识别出来。

在对比学习场景下,数据增强还有第二个目标:生成 anchor、正样本、负样本,将它们输入给编码器,并将其用于提取表征。

CPC

CPC 引入了应用色彩抖动、随机灰度、随机翻转等变换的处理流程,但是它也引入了一种特殊的变换:将一张图像划分为一些重叠的子图块。

 

图 10:CPC 中的关键变换

通过使用这一过程,CPC 可以生成多组正负样本。实际上,该过程可以被应用于一批示例上,此时我们可以将批中其它的示例用作负样本。

图 11:根据一批图像生成正样本、anchor、负样本对。

AMDIM

与 CPC 相比,AMDIM 使用了一种稍微有些不同的方法。在进行了一些标准变换(抖动、翻转等)后,对于每一张图像,它都会通过将数据增强过程在该图像上应用两次得到两个版本的变换图像。

图 12:数据增强过程

实际上,Dosovitski 等人于 2014 年就在论文“Discriminative UnsupervisedFeature Learning with Convolutional Neural Networks”中提出了这一思想。这一思想旨在使用一个”种子“图像生成相同图像的许多变换版本。

SimCLR、Moco、Swav、BYOL

AMDIM 的工作流程取得了非常好的效果,以至于跟进该方法的所有工作都采用了相同的工作流程,它们对之前使用的变换方法采取了一些轻微的调整(例如,有的加入了抖动,有的加入了高斯模糊,等等)。然而,与 AMDIM 的主要思想相比,大多数这些变换方法都是不合逻辑的。

在本文中,我们针对这些变化的影响展开了消融实验。我们发现,对变换的选取对于方法最终的性能是十分关键的。事实上,我们相信,这些方法的成功大部分都是由特定的对变换的选择驱使的。

这些发现与 SimCLR 和 BYOL 论文中展示的相似的实验结果是相符的。

下面的视频详细介绍了 SimCLR 的工作流程:

5特性2:编码器

第二种描述这些方法的方式是:对于编码器的选择。上述的大多数方法使用了具有各种各样深度和宽度的ResNet 类网络。

图 13:ResNet 网络架构

当这类方法开始出现时,CPC 和 AMDIM 实际上设计了自定义的编码器。我们通过消融实验发现了,AMDIM 的泛化性能欠佳,而 CPC 则受编码器改变的影响较小。

图 14:在 CIFAR-10 数据集上测试编码器的鲁棒性

自从 CPC 之后,所有的方法都选用了 ResNet-50网络架构。尽管可能还有更优的架构有待发现,但是以 ResNet-50 作为标准架构说明我们可以重点关注如何提升其它的特性,从而通过更好的训练方法(而不是更好的架构)获得性能的提升。

然而,对于消融实验中的每一种情况而言,有一个发现始终成立:更宽的编码器在对比学习任务中性能要好得多。

6特性3:表征提取

第三种描述这类方法的方式是:它们采取的提取表征的策略。可以说,这也许正是所有这些方法产生“魔力”的秘诀,也是它们差别最大之处。

为了理解提取表征的策略如此重要的原因,让我们首先定义一下何为“表征”。“表征”是独特特性的集合,它使一个系统(以及人类)可以理解某物与其它物体的区别。

Quora的这篇名为“What is representation learning in deep learning?”的博文使用了一个实例来说明如何试图对形状进行分类。要想成功地对形状进行分类,在该形状中找到的“角”的个数可能是一种很好的表征。

图 15:各种拥有不同“角”数的形状

在这些对比学习方法中,它们通过各种各样的方式提取出表征。

CPC

CPC 引入了通过预测潜在空间中的“未来”情况来学习表征的思想。实际上,这意味着:(1)将一张图像看做延时间轴自左上向右下展开,其左上角为“过去”,而右下角为“未来”。

图 16:CPC 的“未来”预测任务

(2)预测结果并不是发生在像素级别上,而是编码器的输出(即潜在空间)。

图 17:从像素空间到潜在空间的变换

最终,CPC 通过将编码器的输出(H)作为投影头(作者称其为一个上下文编码器)生成的上下文向量的目标,定义了一个预测任务,从而进行表征提取。

图 18:CPC 的表征提取

在我们的论文“A Framework ForContrastive Self-Supervised Learning And Designing A New Approach”中,我们发现:只要数据增强过程足够强,那么这种预测任务并不是必需的。并且,尽管有许多关于数据增强过程的假设,我们认为一个强大的数据增强过程会创建共享某种相似的全局结构、但拥有不同局部结构的正样本对。

AMDIM

另一方面,AMDIM 采用的思想是:利用提取自卷积神经网络(CNN)中间层的各个特征图,对比不同视图的表征。我们可以从两个方面来分析这一过程:(1)图像的多视图(2)CNN 的中间层。

首先,我们不妨回想一下 AMDIM 为同一张图像生成两个版本的数据增强变体的过程。

 

图 19:AMDIM 数据增强过程 

每个版本的数据增强结果都会被传入给相同的编码器,从而为每张图像提取特征图。AMDIM 并不会丢弃由编码器生成的中间特征图,而是会将这些特征图用来进行跨空间尺度的比较。回想一下,当一张输入图像经过 CNN 的各个层时,感受野会在不同尺度上对输入进行信息编码。

图 20:不同尺度的特征图

AMDIM 通过对 CNN 的中间输出进行比较来实践了这一思想。图 21 说明了这一比较过程是如何在三张由编码器生成的特征图之间进行的。

图 21:AMDIM 的表征提取——AMDIM 使用相同的编码器提取 3 组特征图,并对它们进行比较。

其余的此类方法都针对 AMDIM 提出的思想进行了一些微调。

SimCLR

SimCLR 使用了与 AMDIM 相同的思想,但是做出了 2 处微调:

(1)仅仅使用最后的特征图;

(2)利用一个投影头处理该特征图,并比较投影前后的两个向量(与 CPC 中的上下文投影相似)。

Moco

正如前文所提到的,对比学习需要用到负样本。通常而言,这是通过将 batch 中的某张图像与其它图像进行比较而实现的。

Moco 进行了与 AMDIM 相同的处理过程(仅仅用到了最后的特征图),但它保留了处理过的所有batch 的历史记录,并以此增加负样本的数量。这样做的效果是:用于提供对比信号的负样本数增加,它超过了单个batch 所得到的负样本数。

图 22:基于动量编码器的对比学习

BYOL

BYOL 采用了与 AMDIM 相同的思想(但只用到了最后的特征图),但是进行了两处改变。

图 23:BYOL 架构示意图

(1)BYOL 用到了两个编码器。第二个编码器实际上完全是第一个编码器的副本,但是它不会在每一轮更新权重,而是使用一种滚动均值(rolling average)更新它们。

(2)BYOL 并没有用到负样本,而是依靠滚动权值更新作为一种为训练提供对比信号的方式。然而,近期的一项消融实验发现,这种做法可能并不是必需的,而且事实上加入批量归一化可以确保系统不生成平凡解。

Swav

Caron 等人在论文“UnsupervisedLearning of Visual Features by Contrasting Cluster Assignments”中将他们的表征提取任务构建为一种“在线聚类”,其中他们迫使“同一张图像不同的增强结果编码之间的一致性”得以满足。因此,Swav 采用了与 AMDIM 相同的方法(仅仅使用最后的特征图),但是它没有直接比较向量,而是通过一组 K 个预先计算出的编码计算相似度。

图 24:Swav 工作流程示意图

实际上,这意味着 Swav 会生成 K 个聚类,对于每个编码的向量而言,它会对这些聚类进行比较,从而学习出新的表征,这份工作可以看做将 AMDIM 和论文“Unsupervised Learning byPredicting Noise”的思想进行了融合。

7关于特性3的思考

表征的提取策略正是这些方法的不同之处。然而,它们之间的变化非常微妙,并没有进行严格的消融实验。我们很难确定是什么真正导致了相应的结果发生。

根据我们的实验结果,我们发现,CPC 和 AMDIM 策略对于结果的影响可以忽略不计,反而增加了计算复杂度。使这些方法奏效的主要驱动力是数据增强过程。

8特性4:相似度度量

第四个我们可以用来比较这类方法的特性是:它们使用的相似度度量方法。上述所有的方法都使用了一个点积或余弦相似度。尽管我们的论文并没有列举出这些消融实验情况,但我们的实验结果说明:对于相似度的选择在很大程度上是无关紧要的。

9特性5:损失函数

我们用来对比这些方法的第五种特性是:对损失函数的选择。所有这类方法(除了 BYOL)都选择使用了一种噪声对比估计(NCE)损失。NCE 损失函数包含两个部分:一个分子、一个分母。分子鼓励相似的向量靠近,分母推动所有其它的向量远离。

图 25:NCE 损失 

如果没有分母,那么损失就会变成一个常数,因此学到的表征就会不再适用。

然而,BYOL 并不需要该分母,而是依赖于第二个编码器的权重更新机制来提供对比信号。然而,正如前文所述,近期的一些消融实验说明,实际上这可能并不是驱动对比信号的因素。

在下面的视频中,我使用 SimCLR 作为一个示例,对NCE 损失进行了完整的解释。

10另一种DIM(YADIM)

我们希望通过生成一种新的方法来说明我们框架的有效性,该方法可以进行不需要前置动机或不涉及表征提取策略的自监督学习。我们将这种新的方法称为 YADIM。

YADIM 的特性如下:

特性 1:数据增强过程

在 YADIM 中,我们融合了 CPC 和 AMDIM 的数据增强过程。

特性 2:编码器

我们使用了 AMDIM 中的编码器,尽管使用任意其它编码器(如 ResNet-50)也有效。

特性 3:表征提取

YADIM 策略是简单的:对某张图像的多个版本进行编码,并使用最后的特征图进行对比。在该方法中,我们没有使用投影头或者其它复杂的对比策略。

特性 4:相似度度量

我们在 YADIM 中坚持使用点积。

特性 5:损失函数

我们也使用了 NCE 损失函数。

YADIM 的实验结果

尽管我们唯一有意义的选择是:融合 AMDIM 和 CPC 的数据增强过程,但是相较于其它的方法,YADIM 仍然成功地取得了优秀的性能。

图 26:对比实验结果

与所有相关的工作不同,我们通过真正亲自实现每一种方法生成了上述结果。实际上,据我们所知,我们实现的 CPC V2 版本是第一个 DeepMind 之外的公开实现版本。

更重要的是,我们使用 PyTorchLightning 标准化了所有的实现,因此我们可以客观地提取出所有上述结果背后的主要驱动因素。

计算效率

上述方法是使用大量计算资源训练出来的。高昂的计算开销意味着我们并没有进行严格的超参数搜索,而只是简单地使用了 STL-10 的超参数在 ImageNet 上进行训练。

使用 PyTorch Lightning 进行高效的分布式计算,我们可以将在 ImageNet 上使用 16 位精度每epoch的训练时间下降至 3 分钟。

基于 23dn.24xlarge 示例的训练每小时需要花费 31.212 美元,下面是我们用于每种方法的计算资源:

图 27:训练 AMDIM 的资源耗费情况

图 28:训练 CPC 的资源耗费情况

图 28:训练 SimCLR 的资源耗费情况

图 28:训练 YADIM 的资源耗费情况

11、要点回顾

  • 为了比较各种对比学习方法并且更容易地对其进行设计,我们引入了一种概念框架。

  • AMDIM、CPC、SimCLR、Moco、BYOL 以及  Swav 之间的差异非常微小。主要的不同之处在于它们提取表征的方式。

  • AMDIM 和 CPC 提出了被其它方法采用的关键思想。SimCLR、Moco、BYOL 以及 Swav 可以看做 AMDIM 的变体。

  • 只要编码器够宽,对于编码器的选择并没有影响。

  • 只要数据增强过程生成良好的正负样本输入,表征提取策略并没有太大的影响。

  • 通过使用我们的框架,我们可以形式化定义一种新的对比自监督学习方法“YADIM”,它与其它的竞争方法性能相当。

  • 训练这类方法的巨大计算开销意味着:在世界上,只有有限的研究组可以在该领域持续取得进展。尽管如此,我们以一种标准化的方式表述了所有这些算法,这至少可以减轻实现这些算法并对实现进行验证的困难。

  • 由于大多数实验结果都是由更宽的网络和特定的数据增强过程驱动的,我们猜想当前的研究方向的提升空间可能较为有限。

原文链接:

https://towardsdatascience.com/a-framework-for-contrastive-self-supervised-learning-and-designing-a-new-approach-3caab5d29619

编辑:于腾凯

校对:林亦霖

这篇关于对比自监督学习浪潮迅猛来袭,你准备好了吗?的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/901306

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

线性代数|机器学习-P36在图中找聚类

文章目录 1. 常见图结构2. 谱聚类 感觉后面几节课的内容跨越太大,需要补充太多的知识点,教授讲得内容跨越较大,一般一节课的内容是书本上的一章节内容,所以看视频比较吃力,需要先预习课本内容后才能够很好的理解教授讲解的知识点。 1. 常见图结构 假设我们有如下图结构: Adjacency Matrix:行和列表示的是节点的位置,A[i,j]表示的第 i 个节点和第 j 个

Node.js学习记录(二)

目录 一、express 1、初识express 2、安装express 3、创建并启动web服务器 4、监听 GET&POST 请求、响应内容给客户端 5、获取URL中携带的查询参数 6、获取URL中动态参数 7、静态资源托管 二、工具nodemon 三、express路由 1、express中路由 2、路由的匹配 3、路由模块化 4、路由模块添加前缀 四、中间件