有监督对比学习的一个简单的例子

2024-06-21 08:08

本文主要是介绍有监督对比学习的一个简单的例子,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”

因公众号更改了推送规则,记得读完点“在看”~下次AI公园的新文章就能及时出现在您的订阅列表中


作者:Dimitre Oliveira

编译:ronghuaiyang

导读

使用有监督对比学习来进行木薯叶病害识别。

论文链接:https://arxiv.org/abs/2004.11362

监督对比学习(Prannay Khosla等人)是一种训练方法,它在分类任务上优于使用交叉熵的监督训练。

这个想法是,使用监督对比学习(SCL)的训练模型可以使模型编码器从样本学习更好的类表示,这应该导致更好的泛化,并对于图像和标签的错误更具鲁棒性。

在本文中,你将了解什么是监督对比学习,以及监督对比学习是如何工作的,你会看到代码实现、一个应用程序的例子,最后将看到SCL和常规交叉熵之间的比较。

简而言之,SCL就是这样工作的:

在嵌入空间中将属于同一类的聚类点聚在一起,同时将来自不同类的样本簇分离。

有许多对比学习方法,如" 监督对比学习"," 自监督对比学习"," SimCLR "等,它们的比对部分都是共同的,它们学习来自一个域的样本和来自另一个域的样本的差别,但SCL以监督的方式利用标签信息完成这项任务。

不同的训练方法的结构

本质上,用监督对比学习对分类模型进行训练分为两个阶段:

  1. 训练编码器,学习生成输入图像的向量表示,这样,同类别图像的表示将比不同类别图像的表示更加相似。

  2. 在参数冻结的编码器上训练一个分类器。

例子

我们将把监督比较学习应用于Kaggle竞赛的数据集(Cassava Leaf Disease Classification),目的是将木薯叶的图像分类为5类:

0: Cassava Bacterial Blight (CBB)
1: Cassava Brown Streak Disease (CBSD)
2: Cassava Green Mottle (CGM)
3: Cassava Mosaic Disease (CMD)
4: Healthy

我们有四种疾病和一种健康的叶子,下面是一些图像样本:

来自比赛的木薯叶图像样本

数据有21397图像用于训练,大约有15000图像用于测试集。

实验设置

  • 数据:图像分辨率512 × 512像素。

  • 模型(编码器):EfficientNet B3。

你可以在这里查看:https://www.kaggle.com/dimitreoliveira/cassava-leaf-supervised-contrastive-learning

通常,对比学习方法能更好地工作,如果每个训练一个batch都有每个类的样本,这将有助于编码器学会对比不同域之间的差别,这意味着需要使用一个大的batch size,在这种情况下,我已经对每个类进行了过采样,所以每个batch的样本中每个类样本的概率大致相同。

数据集中的类别分布,过采样之后

数据增强通常有助于计算机视觉任务,在我的实验中,我也看到了数据增强的改进,这里我使用剪切,旋转,翻转,作物,剪切,饱和度,对比度和亮度的变化,它可能看起来很多,但图像没有和原始图像有太大不同。

增强后的数据样本

现在我们可以看看代码了

编码器

我们的编码器将是一个“EfficientNet B3”,但是在编码器的顶部有一个平均池化层,这个池化层将输出一个2048大小的向量,稍后它将用于检查编码器学习到的表示。

def encoder_fn(input_shape):inputs = L.Input(shape=input_shape, name=’inputs’)base_model = efn.EfficientNetB3(input_tensor=inputs, include_top=False,weights=’noisy-student’, pooling=’avg’)model = Model(inputs=inputs, outputs=base_model.outputs)return model

投影头

投影头位于编码器的顶部,负责将编码器嵌入层的输出投影到更小的尺寸中,在我们的例子中,它将2048维的编码器投影到128维的向量中。

def add_projection_head(input_shape, encoder):inputs = L.Input(shape=input_shape, name='inputs')features = encoder(inputs)outputs = L.Dense(128, activation='relu', name='projection_head', dtype='float32')(features)model = Model(inputs=inputs, outputs=outputs)return model

分类头

分类器头用于的可选的第二阶段训练,在SCL 训练阶段之后,我们可以去掉投影头,把这个分类器头加到编码器上,并使用常规的交叉熵损失来finetune模型,这样做的时候,需要冻结编码器层。

def classifier_fn(input_shape, N_CLASSES, encoder, trainable=False):for layer in encoder.layers:layer.trainable = trainableinputs = L.Input(shape=input_shape, name='inputs')features = encoder(inputs)features = L.Dropout(.5)(features)features = L.Dense(1000, activation='relu')(features)features = L.Dropout(.5)(features)outputs = L.Dense(N_CLASSES, activation='softmax', name='outputs', dtype='float32')(features)model = Model(inputs=inputs, outputs=outputs)return model

监督对比学习损失

这是SCL损失的代码实现,这里唯一的参数是temperature,“0.1”是默认值,但它可以调整,较大的temperatures可以导致类更分离,但较小的temperatures 有益于较长的训练。

class SupervisedContrastiveLoss(losses.Loss):def __init__(self, temperature=0.1, name=None):super(SupervisedContrastiveLoss, self).__init__(name=name)self.temperature = temperaturedef __call__(self, labels, ft_vectors, sample_weight=None):# Normalize feature vectorsft_vec_normalized = tf.math.l2_normalize(ft_vectors, axis=1)# Compute logitslogits = tf.divide(tf.matmul(ft_vec_normalized, tf.transpose(ft_vec_normalized)), temperature)return tfa.losses.npairs_loss(tf.squeeze(labels), logits)

训练

我将跳过Tensorflow样板训练代码,因为它非常标准,但是你可以在这里:https://www.kaggle.com/dimitreoliveira/cassava-leaf-supervised-contrastive-learning/notebook#Training-(supervised-contrastive-learning查看完整的代码。

第一个训练步骤 (编码器 + 投影头)

第一阶段的训练是用编码器+投影头,使用有监督对比学习损失。

构建模型

with strategy.scope(): # Inside a strategy because I am using a TPUencoder = encoder_fn((None, None, CHANNELS)) # Get the encoderencoder_proj = add_projection_head((None, None, CHANNELS),encoder)# Add the projection head to the encoderencoder_proj.compile(optimizer=optimizers.Adam(lr=3e-4), loss=SupervisedContrastiveLoss(temperature=0.1))

训练

model.fit(x=get_dataset(TRAIN_FILENAMES, repeated=True, augment=True), validation_data=get_dataset(VALID_FILENAMES, ordered=True), steps_per_epoch=100, epochs=10)

第二个训练步骤 (编码器 + 分类头)

对于训练的第二阶段,我们删除投影头,并在编码器的顶部添加分类器头,现在该编码器已经训练了权值。对于这一步,我们可以使用常规的交叉熵损失,像往常一样训练模型。

构建模型

with strategy.scope():model = classifier_fn((None, None, CHANNELS), N_CLASSES, encoder, # trained encodertrainable=False) # with frozen weights    model.compile(optimizer=optimizers.Adam(lr=3e-4),loss=losses.SparseCategoricalCrossentropy(), metrics=[metrics.SparseCategoricalAccuracy()])

训练

和之前几乎一样

model.fit(x=get_dataset(TRAIN_FILENAMES, repeated=True, augment=True), validation_data=get_dataset(VALID_FILENAMES, ordered=True), steps_per_epoch=100, epochs=10)

可视化输出向量

评估编码器的学习表示的一种有趣的方法是可视化特征嵌入的输出,在我们的例子中,它是编码器的最后一层,即平均池化层。在这里,我们将比较用SCL训练的模型和另一个用常规交叉熵训练的模型,你可以在:https://www.kaggle.com/dimitreoliveira/cassava-leaf-supervised-contrastive-learning中看到完整的训练。可视化是通过在验证数据集的嵌入输出上应用t-SNE生成的。

交叉熵的嵌入

对使用交叉熵训练的模型嵌入进行可视化

监督对比学习的嵌入

使用SCL训练出的模型的嵌入的可视化。

我们可以看到,两种模型在对每个类进行样本聚类的时候似乎都可以做的不错,但看下SCL模型训练出来的嵌入,每个类的簇相互之间的距离要更远,这就是对比学习的效果。我们也可以认为,这种行为将导致更好的泛化,因为类的判别边界会更清晰、如果去尝试画一下类别之间的边界,就可以得到一个很直观的理解。

总结

我们看到,使用监督对比学习方法的训练既容易实现又有效,它可以带来更好的准确性和更好的类表示,这反过来也可以产生更健壮的模型,能够更好地泛化。

—END—

英文原文:https://pub.towardsai.net/supervised-contrastive-learning-for-cassava-leaf-disease-classification-9dd47779a966

请长按或扫描二维码关注本公众号

喜欢的话,请给我个在看吧

这篇关于有监督对比学习的一个简单的例子的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1080662

相关文章

51单片机学习记录———定时器

文章目录 前言一、定时器介绍二、STC89C52定时器资源三、定时器框图四、定时器模式五、定时器相关寄存器六、定时器练习 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、定时器介绍 定时器介绍:51单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 定时器作用: 1.用于计数系统,可

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

AssetBundle学习笔记

AssetBundle是unity自定义的资源格式,通过调用引擎的资源打包接口对资源进行打包成.assetbundle格式的资源包。本文介绍了AssetBundle的生成,使用,加载,卸载以及Unity资源更新的一个基本步骤。 目录 1.定义: 2.AssetBundle的生成: 1)设置AssetBundle包的属性——通过编辑器界面 补充:分组策略 2)调用引擎接口API

Javascript高级程序设计(第四版)--学习记录之变量、内存

原始值与引用值 原始值:简单的数据即基础数据类型,按值访问。 引用值:由多个值构成的对象即复杂数据类型,按引用访问。 动态属性 对于引用值而言,可以随时添加、修改和删除其属性和方法。 let person = new Object();person.name = 'Jason';person.age = 42;console.log(person.name,person.age);//'J

一份LLM资源清单围观技术大佬的日常;手把手教你在美国搭建「百万卡」AI数据中心;为啥大模型做不好简单的数学计算? | ShowMeAI日报

👀日报&周刊合集 | 🎡ShowMeAI官网 | 🧡 点赞关注评论拜托啦! 1. 为啥大模型做不好简单的数学计算?从大模型高考数学成绩不及格说起 司南评测体系 OpenCompass 选取 7 个大模型 (6 个开源模型+ GPT-4o),组织参与了 2024 年高考「新课标I卷」的语文、数学、英语考试,然后由经验丰富的判卷老师评判得分。 结果如上图所

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

《offer来了》第二章学习笔记

1.集合 Java四种集合:List、Queue、Set和Map 1.1.List:可重复 有序的Collection ArrayList: 基于数组实现,增删慢,查询快,线程不安全 Vector: 基于数组实现,增删慢,查询快,线程安全 LinkedList: 基于双向链实现,增删快,查询慢,线程不安全 1.2.Queue:队列 ArrayBlockingQueue:

十五.各设计模式总结与对比

1.各设计模式总结与对比 1.1.课程目标 1、 简要分析GoF 23种设计模式和设计原则,做整体认知。 2、 剖析Spirng的编程思想,启发思维,为之后深入学习Spring做铺垫。 3、 了解各设计模式之间的关联,解决设计模式混淆的问题。 1.2.内容定位 1、 掌握设计模式的"道" ,而不只是"术" 2、 道可道非常道,滴水石穿非一日之功,做好长期修炼的准备。 3、 不要为了

硬件基础知识——自学习梳理

计算机存储分为闪存和永久性存储。 硬盘(永久存储)主要分为机械磁盘和固态硬盘。 机械磁盘主要靠磁颗粒的正负极方向来存储0或1,且机械磁盘没有使用寿命。 固态硬盘就有使用寿命了,大概支持30w次的读写操作。 闪存使用的是电容进行存储,断电数据就没了。 器件之间传输bit数据在总线上是一个一个传输的,因为通过电压传输(电流不稳定),但是电压属于电势能,所以可以叠加互相干扰,这也就是硬盘,U盘