Hybrid-PSC:基于对比学习的混合网络,解决长尾图片分类 | CVPR 2021

本文主要是介绍Hybrid-PSC:基于对比学习的混合网络,解决长尾图片分类 | CVPR 2021,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 论文提出新颖的混合网络用于解决长尾图片分类问题,该网络由用于图像特征学习的对比学习分支和用于分类器学习的交叉熵分支组成,在训练过程逐步将训练权重调整至分类器学习,达到更好的特征得出更好的分类器的思想。另外,为了节省内存消耗,论文提出原型有监督对比学习。从实验结果来看,论文提出的方法效果还是很不错的,值得一看

来源:晓飞的算法工程笔记 公众号

论文: Contrastive Learning based Hybrid Networks for Long-Tailed Image Classification

  • 论文地址:https://arxiv.org/abs/2103.14267
  • 论文代码:https://www.kaihan.org/HybridLT/

Introduction


 在实际场景中,图片类别通常都会呈现长尾分布,不常见的类别通常由于数据不足而无法被充分学习,给分类器的学习带来巨大的挑战。当前大多研究都通过减轻尾部类别的数据短缺来应对数据不平衡的问题,防止模型被头部类别控制,如数据重采样和数据增强等。
 最近,有新的研究提出将长尾数据分类问题分解为特征学习和分类器学习两个阶段,认为这两个阶段适用不同的数据采样策略进行学习,比如随机采样更适合特征学习,而类别平衡采样更适合分类器学习。

 但有一点需要注意的是,上述两类研究都没有考虑到,在数据不平衡场景下,交叉熵损失是否仍为特征学习的理想损失函数。交叉熵损失学习到的特征分布可能会高度倾斜,如上图所示,导致分类器存在偏向性,会影响长尾分类。
 为此,论文研究了高效的对比学习策略,将其适配到不平衡数据中学习特征表达,提高长尾图片分类场景的性能。论文采用了新颖的混合网络结构,由用于特征表达学习的对比损失和用于分类器学习的交叉熵损失组成。两个损失联合训练,在训练过程中逐渐调整两个损失的权重,从特征学习逐步转移为分类器学习,遵循更好的特征产生更好的分类器的思想。

 论文一开始采用从无监督对比(UC)中延伸出来的有监督对比(SC)损失用于特征学习,该损失使用batch内的样本进行相互对比,通过区分负样本来优化正样本间的一致性,如图左所示。如果想要保证优化效果,需要确保对比的正样本够多以及负样本覆盖足够多的类别,通常需要使用较大的batch,导致内存消耗过多。为了解决这个问题,论文提出了原型有监督对比(PSC)学习策略,从batch内的样本间对比改为batch内的样本与额外维护的原型进行对比,如图右所示。在保持原本有监督对比的特性的情况下,原型有监督对比避免了过多的内存消耗,还能使数据采样更灵活和高效。
 论文的主要贡献如下:

  • 提出用于长尾数据分类的混合网络结构,由用于特征表达学习的对比损失和用于分类器学习的交叉熵损失组成。在训练过程中逐渐调整两个损失的权重,从特征学习逐步转移为分类器学习,遵循更好的特征产生更好的分类器的思想。
  • 研究高效的有监督对比学习策略用于更优的特征学习,提高长尾分类性能。另外,论文提出原型有监督对比来解决标准有监督对比的内存问题。
  • 验证在长尾分类场景中,有监督对比学习能更好地替代交叉熵损失进行特征学习。得益于学习到更好的特征,论文提出的混合网络能够极大地超越基于交叉熵的网络。

Contrastive learning


Unsupervised contrastive

 无监督对比学习在无标签的场景下,通过同源图片与非同源图片之间的特征对比来进行特征表达的学习。比如先随机选取n张原图片,经过数据增强后变成2n张图片组成batch,将同源副本相互认为正样本、非同源副本认为负样本进行距离学习。

Supervised contrastive

 有监督对比学习主在有标签的场景下,通过同类别图片与非同类别图片之间的特征对比来进行特征表达的学习。有监督对比学习也是需要进行数据增强生成同源副本的,所以正样本包含同源副本和同类别副本。比如选取n张原图片,经过数据增强后变成2n张图片组成batch,将同类图片相互认为正样本、非同类图片认为负样本进行距离学习。这里的n张图片选取不能随机选,为达到有监督的目的,同类别图片要大于1张。

Main Approach


A Hybrid Framework for Long-tailed Classification

 论文提出的用于长尾图像分类的混合框架如上图所示,包含两个分支:

  • 用于图像特征学习的对比学习分支,构造同类内聚、异类分离的特征空间。
  • 用于分类器学习的交叉熵分支,基于对比学习分支得到的显著特征学习类别偏向较少的分类器。

 为了达到用更好的特征帮助分类器进行学习,从而得到更通用的分类器的目的。论文参考了BBN的双分支联合训练方法,在训练阶段逐步调整这两个分支的权重。在训练初期以特征学习作为主导,随着训练的进行,分类器学习逐级主导训练。
 主干网络在分支间共享,共同帮助主干网络学习每个图片的特征 r ∈ R D E r\in\mathcal{R}^{D_E} rRDE。两个分支分别进行不同的操作:

  • 对比学习分支先通过MLP层 f e ( ⋅ ) f_e(\cdot) fe()将图片特征 r r r映射成向量表达 z ∈ R D S z\in\mathcal{R}^{D_S} zRDS,适配后续对比损失函数的计算。另外,这样的特征向量化转换也有助于提升前一层的特征质量。随后,对特征 z z z进行 l 2 \mathcal{l}_2 l2归一化,使其能够用于距离计算。最后,使用输出的归一化特征计算有监督对比损失 L S C L \mathcal{L}_{SCL} LSCL
  • 分类器学习分支先通过单个线性层从图像特征 r r r预测类别结果 s ∈ R D C s\in\mathcal{R}^{D_C} sRDC,随后直接计算交叉熵损失 L C E \mathcal{L}_{CE} LCE

 需要注意的是,为了适应其损失函数的特性,两个分支的数据采样方式是不同的。特征学习分支需要附带样本 x i x_i xi的同类正样本 { x i + } = { x j ∣ y i = y j , i ≠ j } \{x^{+}_i\}=\{x_j|y_i=y_j,i\ne j\} {xi+}={xjyi=yj,i=j}和异类负样本 { x i − } = { x j ∣ y i = y j , i ≠ j } \{x^{-}_i\}=\{x_j|y_i=y_j,i\ne j\} {xi}={xjyi=yj,i=j},组成单个batch输入 B S C = { x i , { x i + } , { x i − } } \mathcal{B}_{SC}=\{x_i, \{x^{+}_i\}, \{x^{-}_i\}\} BSC={xi,{xi+},{xi}},而分类器学习分支则直接输入图片和标签 B C E = { { x i , y i } } \mathcal{B}_{CE}=\{\{x_i, y_i\}\} BCE={{xi,yi}}即可。
 混合网络的最终损失函数为:

α \alpha α是权重因子,与周期数成反相关。

Supervised contrastive loss and its memory issue

 有监督对比损失(supervised contrastive loss, SC loss)是对无监督对比损失(unsupervised contrastive loss, UC loss)的扩展,区别在于单batch内的正负样本构成。假设目标图片的正负样本的向量特征为 { z i + } \{z^{+}_i\} {zi+} { z i − } \{z^{-}_i\} {zi},对于大小为N的minibatch,SC loss的计算为:

 相对于UC loss,SC loss可采用任意数量的正样本。由于对比损失是通过区分负样本来优化正样本间的一致性,所以负样本数量十分重要的,而SC损失加入同类图片作为正样本,为保证负样本数量而不得不成倍地增加batch大小,导致内存消耗成倍地增加,导致内存消耗的成倍地增加,限制了SC loss的使用场景。
 一个解决内存消耗的做法就是缩小负样本数量,但这样在类别数多的场景下会有问题。负样本数小意味着只能采样到少量负样本类别,肯定会影响学到的特征质量。

Prototypical supervised contrastive loss

 为了同时兼顾内存消耗和特征质量,论文提出了原型有监督对比损失(prototypical supervised contrastive loss, PSC loss),为每个类别学习一个原型,强迫每个图片的数据增强副本尽量靠近其所属类别的原型以及远离其他类别的原型。使用原型有两个好处:1)允许更灵活的数据采样方式,不再需要显示地控制正负样本,可使用随机采样或类别平衡采样。2)数据采样更高效,假设有 C \mathcal{C} C类别,则每次采样保证都有 C − 1 \mathcal{C}-1 C1个负样本,这对于类别多的数据集特别重要。
 PSC loss的计算如下:

p i j p_{ij} pij是类别 y i y_i yi的原型特征,归一化为 R D S \mathcal{R}^{D_S} RDS下的单位超球面,即满足L2归一化。这里没有提到原型是如何初始化和学习的,需要等源码放出来再看看。
 PSC loss也可以延伸为每个类别多个原型,主要为了迎合单类别可能存在有多种数据分布的情况。多原型有监督对比损失(multiple prototype supervised contrastive loss, MPSC loss)的计算为:

M M M为每个类别的原型数, p j i p^i_j pji为类别j的第 i i i个原型, w i , k ( w i , k ≥ 0 , ∑ k = 1 M ) w i , k = 1 w_{i,k}(w_{i,k}\ge 0,{\sum}^M_{k=1})w_{i,k}=1 wi,k(wi,k0,k=1M)wi,k=1 z i z_i zi与第 k k k个原型之间的关系值,用于更细粒度地控制每个样本,这将会在未来的工作中进行进一步地验证。

Experiment


Datasets

 论文主要在三个长尾图片分类数据集进行实验:

  • Long-tailed CIFAR-10和CIFAR-100:原版的CIFAR数据集是平衡的,通过减少每个类别的图片数来生成长尾版本,注意验证集不变。用一个不平衡比例 β = N m a x / N m i n \beta=N_{max}/N_{min} β=Nmax/Nmin来表示生成的长尾数据集的不平衡程度。
  • iNaturalist 2018:iNaturalist 2018是一个大型的生物品种数据集,包含8142个品种、437513张训练图片以及24424张验证图片。

Implementation details

 对于长尾CIFAR数据集和iNaturalist数据集,论文使用了不同的实验配置:

  • Implementation details for long-tailed CIFAR:混合网络使用ResNet-32作为主干,两个分支共享的数据增强方法有: 32 × 32 32\times 32 32×32的随机裁剪、水平翻转以及概率为0.2的随机灰度。另外,PSC loss也跟随SC loss使用额外的数据增强方法。在实验中,论文简单地使用有颜色扰动和无颜色扰动的图片作为数据增强副本对,batch size为512,使用momentum=0.9、weight decay= 1 × 1 0 − 4 1\times 10^{-4} 1×104的SGD优化器。网络共训练200个周期,学习率初始为0.5并在第120周期和160周期下降10倍。权重因子 α = 1 − ( T / T m a x ) 2 \alpha=1-(T/T_{max})^2 α=1(T/Tmax)2与周期数成抛物线衰减。对于SC loss,公式3的 τ \tau τ固定为0.1,而对于PSC loss,在CIFAR-10和CIFAR-100上分别设置为1和0.1。
  • Implementation details for iNaturalist 2018:混合网络使用ResNet-50作为主干网络,数据增强跟长尾CIFAR一样,只是随机裁剪的图片大小为 224 × 224 224\times 224 224×224,batch size为100。网络共训练100轮,使用momentum=0.9、weight decay= 1 × 1 0 − 4 1\times 10^{-4} 1×104的SGD优化器,学习率初始为0.05并在第60周期和第80周期下降10倍。考虑这个数据集的类别多,学习器训练较难,权重因子 α = 1 − T / T m a x \alpha=1-T/T_{max} α=1T/Tmax设置为线性下降,公式3的 τ \tau τ固定为0.1。对于SC loss,正样本数固定为2。

Result

 长尾CIFAR上的结果对比。

 iNaturalist 2018上的结果对比。

Conclusion


 论文提出新颖的混合网络用于解决长尾图片分类问题,该网络由用于图像特征学习的对比学习分支和用于分类器学习的交叉熵分支组成,在训练过程逐步将训练权重从特征学习调整至分类器学习,遵循更好的特征可得出更好的分类器的思想。另外,为了节省内存消耗,论文提出原型有监督对比学习。从实验结果来看,论文提出的方法效果还是很不错的,值得一看。

参考内容

  • [Supervised Contrastive Learning
    -Prannay Khosla
    ](https://arxiv.org/abs/2004.11362)



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

这篇关于Hybrid-PSC:基于对比学习的混合网络,解决长尾图片分类 | CVPR 2021的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/858883

相关文章

51单片机学习记录———定时器

文章目录 前言一、定时器介绍二、STC89C52定时器资源三、定时器框图四、定时器模式五、定时器相关寄存器六、定时器练习 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、定时器介绍 定时器介绍:51单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 定时器作用: 1.用于计数系统,可

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

AssetBundle学习笔记

AssetBundle是unity自定义的资源格式,通过调用引擎的资源打包接口对资源进行打包成.assetbundle格式的资源包。本文介绍了AssetBundle的生成,使用,加载,卸载以及Unity资源更新的一个基本步骤。 目录 1.定义: 2.AssetBundle的生成: 1)设置AssetBundle包的属性——通过编辑器界面 补充:分组策略 2)调用引擎接口API

Javascript高级程序设计(第四版)--学习记录之变量、内存

原始值与引用值 原始值:简单的数据即基础数据类型,按值访问。 引用值:由多个值构成的对象即复杂数据类型,按引用访问。 动态属性 对于引用值而言,可以随时添加、修改和删除其属性和方法。 let person = new Object();person.name = 'Jason';person.age = 42;console.log(person.name,person.age);//'J

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

2024.6.24 IDEA中文乱码问题(服务器 控制台 TOMcat)实测已解决

1.问题产生原因: 1.文件编码不一致:如果文件的编码方式与IDEA设置的编码方式不一致,就会产生乱码。确保文件和IDEA使用相同的编码,通常是UTF-8。2.IDEA设置问题:检查IDEA的全局编码设置和项目编码设置是否正确。3.终端或控制台编码问题:如果你在终端或控制台看到乱码,可能是终端的编码设置问题。确保终端使用的是支持你的文件的编码方式。 2.解决方案: 1.File -> S

【Altium】查找PCB上未连接的网络

【更多软件使用问题请点击亿道电子官方网站】 1、文档目标: PCB设计后期检查中找出没有连接的网络 应用场景:PCB设计后期,需要检查是否所有网络都已连接布线。虽然未连接的网络会有飞线显示,但是由于布线后期整板布线密度较高,虚连,断连的网络用肉眼难以轻易发现。用DRC检查也可以找出未连接的网络,如果PCB中DRC问题较多,查找起来就不是很方便。使用PCB Filter面板来达成目的相比DRC

JAVA读取MongoDB中的二进制图片并显示在页面上

1:Jsp页面: <td><img src="${ctx}/mongoImg/show"></td> 2:xml配置: <?xml version="1.0" encoding="UTF-8"?><beans xmlns="http://www.springframework.org/schema/beans"xmlns:xsi="http://www.w3.org/2001

《offer来了》第二章学习笔记

1.集合 Java四种集合:List、Queue、Set和Map 1.1.List:可重复 有序的Collection ArrayList: 基于数组实现,增删慢,查询快,线程不安全 Vector: 基于数组实现,增删慢,查询快,线程安全 LinkedList: 基于双向链实现,增删快,查询慢,线程不安全 1.2.Queue:队列 ArrayBlockingQueue: