《Prototypical Networks for Few-shot Learning 》论文概述

2024-01-12 23:48

本文主要是介绍《Prototypical Networks for Few-shot Learning 》论文概述,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

摘要:我们为小样本分类问题提出了原型网络,其中分类器能够很好的泛化到其他没有在训练集中出现的新类别,对于每一种新出现的类别,我们只给出很少的样本。原型网络学习一个度量空间,在该空间中,可以通过计算到每个类的原型表示的距离来执行分类。与最近的小样本学习方法相比,它们反映了一种更简单的归纳偏差,有利于在这种有限的数据范围内使用,并取得优异的效果。我们提供了一个分析,表明一些简单的设计决策比最近涉及复杂体系结构选择和元学习的方法可以产生实质性的改进。我们进一步将原型网络扩展到零样本学习,并在CU Birds数据集上获得最新结果。

一、介绍

1、简单介绍了匹配网络和元学习

匹配网络:它应用注意机制的学习嵌入处理标记好的样本(支持集)来预测未标记的点(查询集)的类别。匹配网络可以被解释为应用于嵌入空间中的加权最邻近分类器。关键词:Episode

元学习的方法用于小样本学习。他们的方法包括在一个给定Eisode的情况下训练一个LSTM去去更新一个分类器,这样就会很好的推广到测试集。在这里,LSTM元学习不是在多个Episodes上训练单个模型,而是学习为每一个Episode训练一个合适的模型。

2、概述原型网络的基本思想:基于集群,找到类的原型,找到合适距离度量方式进行分类。

二、原型网络

2.1符号说明

1、对一些基本符号进行说明

2.2模型

1、介绍了原型的计算方法(平均值)

对测试类x的分类方法:

2、训练集的构成:训练集是通过从训练集中随机选择一个类的子集,然后在每个类中选择一个示例子集作为支持集,其余的子集作为查询点而形成的。

3、损失的计算方法

有以下概念:

N:训练集中样例的数量

K:训练集中类的数量

NC:每个Episode中类别的数量

NS:每个类中支持样例的数量

NQ:每个类中查询样例的数量

RandomSample(S,N):denotes a set of N elements chosen uniformly at  random from set S, without replacement.

计算过程:为Episode选择类别-》选择支持集-》选择训练集-》计算支持集的原型-》初始化损失-》更新损失

2.3 作为混合密度估计的原型网络

正则Bregman散度

未标记点z的分布赋值y的推断变为:

2.4重新解释为线性模型

将计算公式进行解释及变换转化为线性模型,使用欧式距离。

2.5与匹配网络比较

原型网络和匹配网络在 few-shot场景下是不同的,但是在one-shot场景下是等价的。匹配网络在给定支持集的情况下产生加权最近邻分类器,而原型网络在使用平方欧氏距离时产生线性分类器。在一次学习的情况下,ck=xk,因为每个类只有一个支持点,匹配网络和原型网络变得等价。

2.6设计的选择

距离度量选择:欧式距离

Episode composition:使用比测试时更高的Nc或“way”进行训练是非常有益的。在我们的实验中,我们调整训练Nc在一个保持有消极上。另一个考虑因素是在训练和测试时间是匹配Ns还是‘shot’。对于原型网络,我们发现最好的是使用相同的‘shot’进行训练和测试。

2.7 Zero-Shot Learning

将原型网络应用于zero-shot Learning,Zero-shot与few-shot的不同之处在于,在没有给出训练点的支持集的情况下,我们给出了每个类的类元数据vk。这些可以实现确定,也可以从原始文本中学习。修改原型网络去处理zero-shot问题是很简单那的,我们定义为元数据向量的单独嵌入。由于元数据向量和查询点来自不同的输入域,我们发现将原型嵌入g固定为单位长度是有帮助的,但是我们不限制查询嵌入f。

三、实验

3.1 Omniglot Few-shot Classifification

数据集介绍:Omniglot是一个从50个字母表中收集的1623个手写字符的数据集。每一个字符有20个样例,每一个样例都是由不同的人绘制的。

实验设置:我们将灰度图像调整为28×28并且通过旋转90度的倍数来增加character classes。我们用1200个字符加上其旋转作为训练,用剩下的类包括旋转的用作测试。我们的嵌入架构由四个卷积块组成。每一块包括64个3×3的卷积滤波器,批处理规范化层,ReLu非线性层和2×2的最大池化层,当应用于28×28的Omniglot图像时,这种结构产生64维的输出空间。我们对嵌入支持点和查询点使用相同的编码器。我们的模型都用Adam进行训练。我们初始化学习速率为10e-3,并且每2000Episode减少一半的学习率。除了批处理规范化,我们没有使用正则化,我们使用欧几里德距离在1-shot和5-shot场景中训练原型网络,训练集包含60个类和每个类5个查询点。

结论:我们发现将训练镜头和测试镜头进行匹配是有利的,并且每个训练片段使用更多的类(higher way)而不是使用更少。我们比较了各种基线(baselines),包括 neural statistician和匹配网络的微调和非微调的版本。我们计算了我们的模型的分类精度,平均超过1000个随机从测试集产生的Episode。

3.2 miniImageNet Few-shot Classifification

数据集介绍:miniImageNet数据集一开始由Vinyals等人提出,是从较大的ILSVRC-12数据集分离出来的。miniImageNet包括60000张大小为84×84的彩色图片,图片被分成100个类并且每个类有600个样例。

实验设置:在我们的实验中,我们使用由Ravi和Larochelle引入的分离,以便直接与小样本学习中最先进的算法进行比较。他们的分组使用了一组不同的100个类,分为64个训练类、16个验证类和20个测试类。我们遵循他们的程序,在64个训练类上进行训练,并使用16个验证类来检查泛化能力。

我们使用与Omniglot实验相同的四块嵌入架构,但由于图像尺寸的增加,这里的输出空间为1600维。我们还使用与Omniglot实验相同的学习率计划,并进行训练,直到验证损失停止改善。We train using 30-way episodes for 1-shot classification and 20-way episodes for 5-shot classification.我们匹配train shot和test shot,并且每个Episode每个类包括15个查询点。

结论:我们比较了Ravi和Larochelle报告的baselines,其中包括一个简单的最邻近方法,该方法基于64个训练类上分类网络所学习的特征。另一个基线是匹配网络(普通和FCE)和元学习者LSTM的两个非微调变体。如表2所示,典型网络在这方面达到了最先进的水平。

我们进行了进一步的分析,以确定距离度量和每Episode中训练classes的数量对原型网络和匹配网络的影响。为了使这些方法更具有可比性,我们使用我们自己的匹配网络实现,它使用与我们的原型网络相同的嵌入架构。在图2中,我们比较了余弦距离与欧式距离,5-way和20-way  training episodes在1-shot和5-shot场景中,每个Episode每个类中有15个查询点。 我们注意到20-way比5-way获得了更高的准确率,并且推测20-way分类难度的增加有助于网络更好的泛化,因为它迫使模型在嵌入空间中做出更细粒度的决策。此外,使用欧氏距离比预先距离大大提高了性能。这种效果对于原型网络更为明显,在这种网络中,将类原型作为嵌入支持点的平均值进行计算更适合于欧氏距离,因为余弦距离不是Bregman散度。

3.3 CUB Zero-shot Classifification

数据集介绍:为了评估我们的方法对zero-shot学习的适用性,我们在Caltech-UCSD Birds (CUB) 200-2011 数据集上进行了实验。CUB数据集包括11788张200种鸟类。

实验设置:我们将类划分为100个训练集,50个验证集,50个测试集。对于图像,我们使用通过对原始和水平翻转图像的中间、左上、右上、左下和右下裁剪应用GoogLeNet[28]提取的1024维特征。在测试时我们只使用原始图像的中间部分,对于类元数据,我们使用CUB数据集提供的312维连续属性向量。这些属性编码鸟类的各种特征,如颜色、形状和羽毛图案。我们在1024维图像特征和312维属性向量的基础上学习了一个简单的线性映射来生成1024维输出空间。对于这个数据集,我们发现将类原型(嵌入的属性向量)规范化为单位长度很有帮助,因为属性向量来自与图像不同的域。训练Episode由每个类的50个classes和10个query图片组成。

在固定学习速率为10e-4和weight decay10-5的情况下,通过与Adam的SGD优化embedding。

实验结论:Early stopping on validation loss was used to determine the optimal number of epochs for retraining on the training plus validation set.表3显示,与使用属性作为类元数据的方法相比,我们可以获得更大幅度的最新结果。我们将我们的方法与其他嵌入方法进行比较,例如ALE、SJE和DS-SJE/DA-SJE。我们还比较了最近的聚类方法,该方法在通过微调AlexNet获得的学习特征空间上训练支持向量机。这些zero-shot结果表明,即使数据点(图像)来自与类(属性)相关的不同域,我们的方法也足够通用。

四、相关工作

1、Neighborhood Components Analysis (NCA):学习Mahalanobis distance 以最大限度提高knn在变换空间种的leave-one-out accuracy 。

2、Large margin nearest neighbor 大边距最近邻(LMNN)分类:试图优化KNN的精度,但使用的hinge loss铰链损失鼓励一个点的局部邻域包含具有相同标签的其他点。

3、DNet KNN:是另一种基于边距的方法,它通过使用神经网络来执行嵌入而不是简单的线性变换来改进LMNN。

其中,我们的方法与NCA[27]的非线性扩展最为相似,因为我们使用神经网络来执行嵌入,并且我们基于变换空间中的欧氏距离来优化softmax,而不是margin loss。我们的方法和非线性NCA之间的一个关键区别是,我们直接在类上而不是单个点上形成一个softmax,它是根据到每个类的原型表示的距离来计算的。这使得每个类都有一个与数据点数量无关的简明表示,并且避免了存储整个支持集以进行预测的需要。

我们的方法也类似于nearest class mean approach(最近类平均方法),其中每个类都用其示例的平均值表示。这种方法是为了在不需要重新训练的情况下快速地将新类合并到分类器中而开发的,但是它依赖于线性嵌入,并且是为了处理新类附带大量示例的情况而设计的。相反,我们的方法是对非线性嵌入点使用神经网络,并且我们将其与Episode training结合起来去处理few-shot 场景。

Mensink等人尝试扩展它们的方法来执行非线性分类,但是通过允许类具有多个原型实现。他们通过在输入空间上使用k-均值在预处理步骤中找到这些原型,然后对其线性嵌入进行多模态变换。另一方面,原型网络以端到端的方式学习非线性嵌入,而不需要这样的预处理,生成的非线性分类器仍然只需要每个类一个原型。此外,我们的方法自然地推广到其他距离函数,特别是Bregman divergences。

另一种和few-shot学习相关的是Ravi和Larochelle提出的元学习方法。这个的关键是 LSTM dynamics and gradient descent can be written in effectively the same way. 。LSTM可以被训练为自己从给定的Episode中训练一个模型,其性能目标是在查询点上很好地泛化。匹配网络和原型网络也可以看作元学习的形式,因为它们从新的训练片段中动态地生成简单的分类器;然而,他们所依赖的 core embeddings核心嵌入是在训练后固定的。匹配网络的FCE扩展涉及依赖于支持集的二次嵌入。然而,在少数镜头场景中,数据量非常小,一个简单的归纳偏差似乎很有效,无需为每集学习自定义嵌入。

Prototypical networks are also related to the neural statistician from the generative modeling literature, which extends the variational autoencoder to learn generative models of datasets rather than individual points.  neural statistician的一个组成部分是“统计网络”,它将一组数据点归纳为一个统计向量。它通过对数据集中的每个点进行编码,取一个样本均值,

并应用后处理网络获得统计向量上的近似后验。Edwards and Storkey  test their model for one-shot classification on the Omniglot dataset by considering each character to be a separate dataset and making predictions based on the class whose approximate posterior over the statistic vector has minimal KL-divergence from the posterior inferred by the test point. 像neural statistician一样,我们也为每一个类产生一个汇总统计。然而,我们的模型是一个判别模型,适合于我们进行few-shot分类的判别任务。

关于zero-shot学习,在原型网络中使用嵌入元数据类似于之前的方法,因为两者都预测线性分类器的权重。DS-SJE和DA-SJE方法还学习了图像和类元数据的深度多模态嵌入函数。与我们不同,他们学习使用经验风险损失。[3]和[23]都没有使用阶段性训练,这使得我们能够帮助加快训练并使模型正规化。

五、总结

我们提出了一种简单的few-shot学习的方法称作原型网络,其基本思想是,在一个由神经网络学习的表示空间中用样例的平均值来表示每一类。我们通过使用episode训练使得神经网络在few-shot学习中表现的特别好。这种方法比元学习简单并且更有效,即便没有匹配网络进行复杂的拓展也能产生最新的结果(尽管这些方法也可以应用于原型网络)。我们展示了如何通过仔细考虑所选择的距离度量,并通过修改Episode学习过程来大大提高性能。我们进一步展示了如何将原型网络推广到zero-shot setting,并且在CUB-200数据集上实现了最新的结果。未来工作的一个自然方向是利用Bregman发散,而不是平方欧氏距离,对应于超越球面高斯的类条件分布。我们对此进行了初步的探索,包括为一个类学习每个维度的方差。这并没有导致任何经验收益,这表明嵌入网络本身具有足够的灵活性,而不需要每个类的附加拟合参数。总的来说,原型网络的简单性和有效性使其成为一种有前途的few-shot学习方法。

 

 

 

 

这篇关于《Prototypical Networks for Few-shot Learning 》论文概述的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/599563

相关文章

水位雨量在线监测系统概述及应用介绍

在当今社会,随着科技的飞速发展,各种智能监测系统已成为保障公共安全、促进资源管理和环境保护的重要工具。其中,水位雨量在线监测系统作为自然灾害预警、水资源管理及水利工程运行的关键技术,其重要性不言而喻。 一、水位雨量在线监测系统的基本原理 水位雨量在线监测系统主要由数据采集单元、数据传输网络、数据处理中心及用户终端四大部分构成,形成了一个完整的闭环系统。 数据采集单元:这是系统的“眼睛”,

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

Java 创建图形用户界面(GUI)入门指南(Swing库 JFrame 类)概述

概述 基本概念 Java Swing 的架构 Java Swing 是一个为 Java 设计的 GUI 工具包,是 JAVA 基础类的一部分,基于 Java AWT 构建,提供了一系列轻量级、可定制的图形用户界面(GUI)组件。 与 AWT 相比,Swing 提供了许多比 AWT 更好的屏幕显示元素,更加灵活和可定制,具有更好的跨平台性能。 组件和容器 Java Swing 提供了许多

【编程底层思考】垃圾收集机制,GC算法,垃圾收集器类型概述

Java的垃圾收集(Garbage Collection,GC)机制是Java语言的一大特色,它负责自动管理内存的回收,释放不再使用的对象所占用的内存。以下是对Java垃圾收集机制的详细介绍: 一、垃圾收集机制概述: 对象存活判断:垃圾收集器定期检查堆内存中的对象,判断哪些对象是“垃圾”,即不再被任何引用链直接或间接引用的对象。内存回收:将判断为垃圾的对象占用的内存进行回收,以便重新使用。

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

Java 多线程概述

多线程技术概述   1.线程与进程 进程:内存中运行的应用程序,每个进程都拥有一个独立的内存空间。线程:是进程中的一个执行路径,共享一个内存空间,线程之间可以自由切换、并发执行,一个进程最少有一个线程,线程实际数是在进程基础之上的进一步划分,一个进程启动之后,进程之中的若干执行路径又可以划分成若干个线程 2.线程的调度 分时调度:所有线程轮流使用CPU的使用权,平均分配时间抢占式调度

论文翻译:ICLR-2024 PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS

PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS https://openreview.net/forum?id=KS8mIvetg2 验证测试集污染在黑盒语言模型中 文章目录 验证测试集污染在黑盒语言模型中摘要1 引言 摘要 大型语言模型是在大量互联网数据上训练的,这引发了人们的担忧和猜测,即它们可能已

OmniGlue论文详解(特征匹配)

OmniGlue论文详解(特征匹配) 摘要1. 引言2. 相关工作2.1. 广义局部特征匹配2.2. 稀疏可学习匹配2.3. 半稠密可学习匹配2.4. 与其他图像表示匹配 3. OmniGlue3.1. 模型概述3.2. OmniGlue 细节3.2.1. 特征提取3.2.2. 利用DINOv2构建图形。3.2.3. 信息传播与新的指导3.2.4. 匹配层和损失函数3.2.5. 与Super

BERT 论文逐段精读【论文精读】

BERT: 近 3 年 NLP 最火 CV: 大数据集上的训练好的 NN 模型,提升 CV 任务的性能 —— ImageNet 的 CNN 模型 NLP: BERT 简化了 NLP 任务的训练,提升了 NLP 任务的性能 BERT 如何站在巨人的肩膀上的?使用了哪些 NLP 已有的技术和思想?哪些是 BERT 的创新? 1标题 + 作者 BERT: Pre-trainin