【深度学习】详解 SimCLR

2023-10-20 03:50
文章标签 学习 详解 深度 simclr

本文主要是介绍【深度学习】详解 SimCLR,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


目录

摘要

一、引言

二、方法

2.1 The Contrastive Learning Framework

2.2. Training with Large Batch Size 

2.3. Evaluation Protocol

三、用于对比表示学习的数据增广 

3.1 Composition of data augmentation operations is crucial for learning good representations 

3.2 Contrastive learning needs stronger data augmentation than supervised learning 

四、编码器和头部的架构 

4.1 Unsupervised contrastive learning benefits (more) from bigger models

4.2 A nonlinear projection head improves the representation quality of the layer before it 

五、损失函数和 Batch Size

5.1 Normalized cross entropy loss with adjustable temperature works better than alternatives

5.2 Contrastive learning benefits (more) from larger batch sizes and longer training

六、与 SOTA 方法的比较


  • Title:A Simple Framework for Contrastive Learning of Visual Representations
  • Paper:https://arxiv.org/pdf/2002.05709.pdf
  • Github:GitHub - google-research/simclr: SimCLRv2 - Big Self-Supervised Models are Strong Semi-Supervised Learners

摘要

        本文提出了 SimCLR:一个简单的视觉表示对比学习的框架。我们简化了近期被提出的对比自监督学习算法,而无需专门的架构或内存库 (memory bank)。为理解是什么使对比预测任务能够学习有用的表示,本文系统地研究了我们的框架的主要组件,表明:

  1. 数据增广组合 在定义有效的预测任务中扮演关键角色。
  2. 在 表示和对比损失之间 引入一个 可学习的非线性转换,大大提高了已学习到的表示的质量。
  3. 相比于有监督学习,对比学习受益于 更大的 batch sizes 和更多的训练 steps

        通过结合这些发现,我们能够大大超过以前在 ImageNet 上的自监督和半监督学习方法。由 SimCLR 学习到的自监督表示 训练的线性分类器 达到了 76.5% 的 top-1 acc,比先前 SOTA 提高了 7%,与有监督 ResNet-50 的性能相匹配。当只对 1% 的标签微调时,达到了 85.8% 的 top-5 acc,在少 100× 标签的条件下优于 AlexNet。


一、引言

        在没有人类监督的情况下学习有效的视觉表征是一个长期存在的问题 (long-standing problem)。大多数主流的方法可分为 (fall into) 两类:生成式 (generative) 判别式 (discriminative)生成式方法 学习在输入空间中生成或建模像素。然而,像素级生成的计算成本很昂贵 (computationally expensive),并且可能不是表示学习所必需的判别式方法 使用类似于有监督学习的 objective 函数来学习表示,但训练网络执行前置任务 (pretext task),前置任务的输入和标签都来自于一个未经标记的数据集。许多此类方法都依赖于启发式 (heuristics) 来设计前置任务,这可能会限制学习表征的通用性 (generality)。基于潜在 (latent) 空间对比学习的判别方法最近显示出了巨大的前景,取得了 SOTA 的结果。

        在本工作中,我们引入了一个简单的视觉表示的对比学习框架 SimCLR,它不仅比以前的工作表现更好 (图 1),而且也更简单,既不需要专门的架构,也不需要内存库。

        为理解是什么使对比表示学习变得良好,本文系统地研究了我们的框架的主要组成部分,并表明:

  • 多数据增强操作的组合在定义 产生有效表示的对比预测任务时至关重要。此外,相比有监督学习,无监督对比学习受益于更强的数据增广。
  • 在表示和对比损失之间引入一个可学习的非线性变换,大大提高了学习到的表示的质量。
  • 具有对比交叉熵损失的表示学习受益于经归一化嵌入和经适当调整的温度参数。
  • 相比有监督学习的竞争方法,对比学习受益于更大的 batch sizes 和更长的训练。与监督学习一样,对比学习也受益于更深和更宽的网络。

二、方法

2.1 The Contrastive Learning Framework

        受最近的对比学习算法的启发 (见第 7 节的概述),SimCLR 通过潜在空间 (latent space) 中的对比损失,以最大化同一数据示例的不同增广视图 (views) 之间的一致性来学习表示。如图 2 所示,该框架包括以下 4 个主要组件。

  • 一个随机数据增广模块,转换任意给定的数据示例,随机得到同一示例的 2 个相关视图 \widetilde{x}_i 和 \widetilde{x}_j,二者被视为一个 正对 (positive pair)。本工作中,依次应用 3 个简单的增广:随机 crop 后 resize 回原尺寸、随机颜色失真 (distortions)、随机高斯模糊。如第 3 节所示,随机 crop 和颜色失真的组合是实现良好性能的关键。
  • 一种基于神经网络的编码器 f( \cdot ),从经增广的数据示例中提取表示向量。本框架允许选择各种网络架构而没有任何约束。我们选择了 (opt for) 简单性并使用常用的 ResNet 获取 h_i = f( \widetilde{x}_i ) = \textrm{ResNet} ( \widetilde{x}_i ),其中 h_i \in \mathbb{R}^d 为 平均池化层后输出的表示向量
  • 一个小型神经网络投影头部 g ( \cdot ),将表示向量映射到对比损失空间。我们使用一个带单隐藏层的 MLP 获取 z_i = g(h_i) = W^{(2)}\sigma( W^{(1)} h_i),其中 \sigma 是 ReLU。如第 4 节所示,我们发现定义 z_i 而非 h_i 的对比损失是有益的 (find it beneficial to)。
  • 一个对比损失函数 —— 为一个对比预测任务而定义的。给定一个集合 \{ \widetilde{x}_k \},其中包含一个正对样例 \widetilde{x}_i 和 \widetilde{x}_j。对比预测任务旨在针对给定的 \widetilde{x}_i 识别 (identify)  \{ \widetilde{x}_k \}_{k \neq i} 中的 \widetilde{x}_j

        我们随机采样一个包含 N 个样本的 minibatch,并基于其中增广后的样本对 定义对比预测任务,得到 2N 个数据点 (每个样本增广一次)。我们没有显式地采样负样本 (negative exmples)。相反,给定一个正对,类似于 (On sampling strategies for neural network-based collaborative filtering),我们 将一个 minibatch 中的其他 2(N - 1) 个增广后的样本视为负样本。设 \textrm{sim} (u, v) = u^{\textrm{T}} v / \left \| u \right \| \left \| v \right \| 表示经 L_2 归一化的 u 和 v 之间的点积 (即 余弦相似度)。然后将正对样本 (i, j) 的损失函数定义为:

        其中 1_{[k \neq i]} \in \{ 0, 1 \} 是一个指示函数,当 k \neq i 值为 1,否则为 0,而 \tau 表示一个温度参数。最终的损失计算所有的正对,包括 (i, j) 和 (j, i)。这种损失已用于之前的工作;为方便起见,称其为 NT-Xent (the normalized temperature-scaled cross entropy loss)


2.2. Training with Large Batch Size 

        为保持简单 (to keep it simple),不使用内存库 (memory bank) 训练模型。相反,我们将使用 N 从 256 到 8192 的 batch size 训练。batch size = 8192 时每个正对有 16382 个负样本 (2(N - 1))。当使用具有线性学习率 scaling 的标准 SGD / Momentum 时,大 batch size 的训练可能不稳定。为稳定训练,我们对所有 batch sizes 都使用了 LARS 优化器。云 TPU 被用于训练模型,根据 batch size 使用 32 到 128 个核 (使用 128 个 TPU v3 核时,训练 100 个 epochs 的 batch size = 4096 的 ResNet-50 需要 ∼1.5 小时)。

        Global BN。标准 ResNets 使用 BN。在具有数据并行性 (parallelism) 的分布式训练中,BN 均值和方差通常局部地聚合在每个设备。在我们的对比学习中,由于 正对是在同一设备中计算的,模型由此可以利用 局部信息泄漏,在不改善表示形式的情况下 (错误地) 提高预测 acc。我们通过 在训练期间聚合所有设备的 BN 均值和方差 来解决该问题 (address this issue by)。其他方法包括 shuffling 跨设备的数据示例,或 LN 替换 BN


2.3. Evaluation Protocol

        在这里,我们制定了实证研究方案,旨在理解框架中不同的设计选择。

        Dataset and Metrics。我们对无监督预训练 (学习编码器网络 f 而无需标签) 的大部分研究都是使用 ImageNet ILSVRC-2012 数据集 完成的。一些关于 CIFAR-10 的额外预训练实验可以在附录 B.9 中找到。我们还在广泛的迁移学习数据集上测试了预训练的结果。为评估学习到的表示,我们遵循广泛使用的 线性评估协议 (linear evaluation protocol) —— 在已冻结的基础网络上训练一个线性分类器,并将 测试 acc 用作表示质量的代理 (proxy)。除了线性评估,我们还比较了 SOTA 的半监督和迁移学习。

        Default setting。除非另有说明 (unless otherwise specified),对于数据增广,我们使用随机 crop 和 resize (带随机 flip)、颜色 distortions 和高斯模糊 (详见附录 A)。我们使用 ResNet-50 作为基础编码器网络,并使用一个 2 层 MLP 投影头部来将表示投影到一个 128-d 的潜在空间。我们使用 NT-Xent 损失,使用 LARS 优化,学习率为 4.8 (= 0.3 × BatchSize / 256),权重衰减为 1e-6。我们以 batch size = 4096 训练了 100个 epochs (虽然在 100 个 epochs 没有达到最大性能,但取得了合理的结果,允许公平和有效的消融)。此外,我们在前 10 个 epochs 使用线性warmup,并通过余弦衰减策略来衰减学习率而无需重启 (restarts)。


三、用于对比表示学习的数据增广 

        Data augmentation defines predictive tasks。当数据增广已被广泛应用于有监督和无监督的表示学习时,它尚未被认为是 一种定义对比预测任务的 系统的方法。许多现有方法 通过改变架构 来定义对比预测任务。例如,有的通过约束网络架构中的感受野 实现 全局到局部的视图预测,而有的通过固定图像 splitting 过程和上下文聚合网络 实现 邻近视图预测。我们证明,这种复杂性可以通过执行简单的目标图像的随机 crop (带 resize) 来避免,这将创建一个包含上述两个任务的预测任务 family,如图 3 所示。这种简单的设计选择 可以方便地将预测任务与其他组件,如神经网络架构解耦。更广泛的对比预测任务 可以通过扩展增广的 family 及其随机组合 来定义。


3.1 Composition of data augmentation operations is crucial for learning good representations 

        为系统地研究数据增广的影响,我们考虑了几个常见的增广。一种类型的增广涉及数据的 空间/几何变换,如 crop 和 resize (带水平 flip)、rotation 和 cutout。另一种类型的增广涉及 外观变换,如颜色 distortion (包括颜色 dropping、亮度、对比度、饱和度、色调)、高斯模糊 和 Sobel 滤波。图 4 可视化了我们在这项工作中研究的增广。

        为理解 单个数据增广的影响增广组合的重要性,我们研究了本框架在单独或成对应用增广时的性能。由于 ImageNet 图像的大小不同,我们总是应用 crop 和 resize 的图像,这使得在没有 cropping 的情况下研究其他增广变得困难。为消除这种混淆 (to eliminate this confound),我们考虑了针对这种消融的非对称数据转换设置。具体来说,我们总是首先随机 crop 图像并将其 resize 为相同的分辨率,然后 只将目标转换应用于图 2 中框架的一个分支,而保留另一个分支不变作为 identity (即 t(x_i) = x_i)。注意,这种不对称 —— 图 5 展示了单个和组合转换下的线性评估结果。我们观察到,没有单个转换足以学习到好的表示,即使模型几乎可以完美地识别对比任务中的正对。当组合增广时,对比预测任务变得更加困难,但表示的质量显著提高。附录 B.2 提供了关于更广泛的增广集组合的进一步研究。

        增广的一个组成部分很突出 (stands out):随机 cropping 和随机的颜色 distortion。我们推测,当只使用随机 cropping 作为数据增广时,一个 严重的问题 是,图像中的大多数 patches 具有相似的颜色分布。图 6 显示,仅彩色直方图就足以 (suffice to) 区分图像。神经网络可以利用这一捷径来解决预测任务。因此,为学习可泛化的特征,用颜色 distortion 来组合 cropping 是至关重要的

# color distortion - TensorFlow import tensorflow as tfdef color_distortion(image, s=1.0):# image is a tensor with value range in [0, 1].# s is the strength of color distortion.def color_jitter(x):# one can also shuffle the order of following augmentations# each time they are applied.x = tf.image.random_brightness(x, max_delta=0.8*s)x = tf.image.random_contrast(x, lower=1-0.8*s, upper=1+0.8*s)x = tf.image.random_saturation(x, lower=1-0.8*s, upper=1+0.8*s)x = tf.image.random_hue(x, max_delta=0.2*s)x = tf.clip_by_value(x, 0, 1)return xdef color_drop(x):image = tf.image.rgb_to_grayscale(image)image = tf.tile(image, [1, 1, 3])# randomly apply transformation with probability p.image = random_apply(color_jitter, image, p=0.8)image = random_apply(color_drop, image, p=0.2)return image

3.2 Contrastive learning needs stronger data augmentation than supervised learning 

        为了进一步证明颜色增广的重要性,我们调整了颜色增广的强度,如表 1 所示。更强的颜色增广大大提高了 (substantially improves) 学习到的无监督模型的线性评估。在这种情况下,AutoAugment,一种使用监督学习发现的复杂增广策略,并不比简单的 cropping + (更强的) 颜色 distortion 好当使用相同的增广集合训练有监督模型时,我们观察到更强的颜色增广不会改善、甚至还会损害它们的表现。因此,我们的实验表明,相比于有监督学习,无监督对比学习受益于更强的 (颜色) 数据增广。尽管之前的研究报告称,数据增广对自监督学习很有用,我们表明,不能为有监督学习产生 acc 收益的数据增广仍可极大地帮助 (help considerably) 对比学习


四、编码器和头部的架构 

4.1 Unsupervised contrastive learning benefits (more) from bigger models

        图 7 显示了增加深度和宽度都能不出意外地提高性能。虽然类似的发现也适用于有监督学习,但我们发现 有监督模型和无监督模型训练的线性分类器之间的差距 随着模型尺寸的增加而缩小,这表明 无监督学习在模型尺寸扩大时比有监督学习收益更多


4.2 A nonlinear projection head improves the representation quality of the layer before it 

        然后,我们研究了包括一个投影头部的重要性,即 g(h)。图 8 显示了使用 3 种不同架构的头部线性评估结果:(1) 恒等映射;(2) 线性投影;(3) 带有一个隐藏层的默认非线性投影 (和 ReLU 激活)。我们观察到,非线性投影 比 线性投影 (+3%) 和没有投影 (>10%) 更好。当使用投影头部时,无论输出尺寸如何,都会观察到类似的结果。此外,即使使用非线性投影,投影头部前的层 h 仍然比投影头部后的层 z = g (h) 好得多 (>10%),这表明 投影头部前的隐藏层 比 投影头部后的层 有更好的表示

        我们推测,在非线性投影之前使用表示的重要性 是由于对比损失引起的信息损失。特别地,z = g (h) 被训练为 对数据变换具有不变性。因此,g 可以删除可能对下游任务有用的信息,例如 objects 的颜色或方向。通过利用非线性变换 g ( \cdot ),可以在 h 中构成并保持更多的信息。为验证该假设,我们进行了实验,使用 h 或 g (h) 来学习预测 在预训练过程中应用的变换。此处设置 g(h) = W^{(2)} \sigma ( W^{(1)} h),输入和输出维数一致(即 2048)。表 3 显示了 h 包含了更多关于所应用的转换的信息,而 g (h) 丢失了这些信息。更深入的分析详见附录 B.4。


五、损失函数和 Batch Size

5.1 Normalized cross entropy loss with adjustable temperature works better than alternatives

        我们将 NT-Xent 损失与其他常用的对比损失函数进行了比较,如 logistic 损失 和 margin 损失。表 2 显示了 objective 函数以及损失函数对于输入的梯度。从梯度中我们观察到:

  1. L_2 归一化 (即余弦相似度) 及温度有效地加权不同的样本,适当的温度可以帮助模型从难负样本中学习;
  2. 与交叉熵不同,其他 objective 函数不通过相对困难度来衡量负样本。

        因此,必须对这些损失函数应用半难负挖掘 (semi-hard negative mining, SHNM):而不是在所有损失项计算梯度,可以使用半难负项 (semi-hard negative terms) (即那些在损失 margin 内距离最近,但比正样本更远的) 计算梯度 。

        为了使比较公平,我们对所有损失函数使用相同的 L_2 归一化,并调整超参数,报告它们的最佳结果。表 4 显示 (详情见附录 B.10。为简单起见,我们只考虑从一个增广视图中得到负样本),虽然 (半难) 负挖掘有帮助,但最好的结果仍然比我们默认的 NT-Xent 损失要差得多。

        接下来,我们测试了 L_2 归一化 (即余弦相似度与点积) 和温度 \tau 在我们默认的 NT-Xent 损失中的重要性。从表 5 可以看出,如果没有归一化和适当的温度尺度,性能明显较差。在没有 L_2 归一化的情况下,对比任务 acc 更高,但线性评价下的表示更差


5.2 Contrastive learning benefits (more) from larger batch sizes and longer training

        图 9 显示了当模型针对不同 epoch 数训练时,batch size 的影响。我们发现,当训练 epochs 数较小时 (如 100),较大的 batch size 比较小的 batch size 具有更显著的优势。随着训练 steps/epochs 的增加,不同 batch size 大小之间的差距会减小或消失 (只要 batches 是随机重采样的)。相比于有监督学习,在对比学习中,更大的 batch size 提供了更多的负样本,促进了收敛 (即在给定 acc 下耗费更少的 steps/epochs)。更长时间的训练也提供了更多负样本,改善了结果。在附录 B.1 中,提供了具有更长的训练 steps 的结果。


六、与 SOTA 方法的比较

这篇关于【深度学习】详解 SimCLR的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/244470

相关文章

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

SQL注入漏洞扫描之sqlmap详解

《SQL注入漏洞扫描之sqlmap详解》SQLMap是一款自动执行SQL注入的审计工具,支持多种SQL注入技术,包括布尔型盲注、时间型盲注、报错型注入、联合查询注入和堆叠查询注入... 目录what支持类型how---less-1为例1.检测网站是否存在sql注入漏洞的注入点2.列举可用数据库3.列举数据库

Linux之软件包管理器yum详解

《Linux之软件包管理器yum详解》文章介绍了现代类Unix操作系统中软件包管理和包存储库的工作原理,以及如何使用包管理器如yum来安装、更新和卸载软件,文章还介绍了如何配置yum源,更新系统软件包... 目录软件包yumyum语法yum常用命令yum源配置文件介绍更新yum源查看已经安装软件的方法总结软

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

Java访问修饰符public、private、protected及默认访问权限详解

《Java访问修饰符public、private、protected及默认访问权限详解》:本文主要介绍Java访问修饰符public、private、protected及默认访问权限的相关资料,每... 目录前言1. public 访问修饰符特点:示例:适用场景:2. private 访问修饰符特点:示例:

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

详解Java如何向http/https接口发出请求

《详解Java如何向http/https接口发出请求》这篇文章主要为大家详细介绍了Java如何实现向http/https接口发出请求,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 用Java发送web请求所用到的包都在java.net下,在具体使用时可以用如下代码,你可以把它封装成一

JAVA系统中Spring Boot应用程序的配置文件application.yml使用详解

《JAVA系统中SpringBoot应用程序的配置文件application.yml使用详解》:本文主要介绍JAVA系统中SpringBoot应用程序的配置文件application.yml的... 目录文件路径文件内容解释1. Server 配置2. Spring 配置3. Logging 配置4. Ma

mac中资源库在哪? macOS资源库文件夹详解

《mac中资源库在哪?macOS资源库文件夹详解》经常使用Mac电脑的用户会发现,找不到Mac电脑的资源库,我们怎么打开资源库并使用呢?下面我们就来看看macOS资源库文件夹详解... 在 MACOS 系统中,「资源库」文件夹是用来存放操作系统和 App 设置的核心位置。虽然平时我们很少直接跟它打交道,但了

关于Maven中pom.xml文件配置详解

《关于Maven中pom.xml文件配置详解》pom.xml是Maven项目的核心配置文件,它描述了项目的结构、依赖关系、构建配置等信息,通过合理配置pom.xml,可以提高项目的可维护性和构建效率... 目录1. POM文件的基本结构1.1 项目基本信息2. 项目属性2.1 引用属性3. 项目依赖4. 构