用少量样本就能训练神经网络

2024-03-18 19:20

本文主要是介绍用少量样本就能训练神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

朋友们,如需转载请标明出处:人工智能AI技术的博客_CSDN博客-python系列教程,人工智能,程序人生领域博主

在大多数时候,你是没有足够的图像来训练深度神经网络的。本文将教你如何从小样本数据快速学习你的模型。

  为什么我们关心小样本学习?

1980年, Kunihiko Fukushima 提出了第一个卷积神经网络。从那时起,由于计算能力的不断提高和机器学习社区的巨大努力,深度学习算法在与计算机视觉相关的任务上从未停止过提高它们的性能。2015年,何凯明和他在微软的团队报告说,他们的模型在对来自 ImageNet 的图像进行分类时表现优于人类。在那时候,有人可能会说,计算机在利用数十亿图像来解决特定任务方面变得比我们更强。欢呼!

然而,如果你不是 Google 或者 Facebook,你就不可能总是能够用这么多的图像来构建一个数据集。当您从事计算机视觉工作时,有时您必须对每个标签只有一个或两个样本的图像进行分类。在这场比赛中,人类仍将被打败。给婴儿看一张大象的照片,从现在起他们永远不会认不出大象。如果你对 Resnet50 做同样的事情,你可能会对结果感到失望。从少数的样本中学习的这个问题,被称为小样本学习(few-shot learning)。

近几年来,小样本学习的问题引起了研究界的广泛关注,并形成了许多优雅的解决方案。目前最流行的解决方案是使用元学习(meta-learning),又称为:learning to learn。如果你想知道它是什么,以及它是如何用于小样本图像分类,请继续阅读。

  极少样本的分类任务

首先,我们需要定义N个类别,K张图片(译者注:针对每个类别)的分类任务。假设以下的场景:

1. 一个支持数据集,包含N个分类标签,针对每个标签有K个已分类的图片。

2. 一个查询数据集,包含Q张查询图片。

任务是利用支持数据集中的N*K个图片,将查询数据集中的图片分类为N个类别(译者注:可以理解为训练集有N*K个图片,将测试集在N个类别进行分类)。当K值很小时(通常K<10),我们称这种分类任务为极少样本分类任务(当K=1时,变成单样本分类任务)

图像样本不够用?元学习帮你解决

极少样本分类任务的一个例子:在支持集中,给定N=3(3类),每类K=2,即每种类别两张图片,我们希望将查询集中(查询集Q=4,即4张查询图片)的狗标注为拉普拉多狗,圣伯纳德狗或哈巴狗。即使你从未见过任何的哈巴狗、圣伯纳德狗或拉普拉多狗,这项任务对你来说也不困难。但使用AI来解决这个问题,我们需要进行一些元学习。

  元学习范例

1998年,Thrun和Pratt指出,对于一个指定的任务,一个算法“如果随着经验的增长,在该任务上的表现得到改进”,则认为该算法能够学习。与此同时,与此同时,对于一族待解决的多个任务,一个算法“如果随着经验和任务数量的增长,在每个任务上的表现得到改进”,则认为该算法能够学习如何学习,我们将后者称为元学习算法。它不学习如何解决一个特定的问题,但可以成功学习如何解决多个任务。每当它学会解决一个新的任务,它就越有能力解决其他新的任务:它学会如何学习。

如果我们希望解决一项任务T,会在一批训练任务{Ti}上训练元学习算法。算法在被训练解决这些任务的过程中得到的经验将被用于解决最终的任务T。

比如,考虑上个图像中提到的任务T。它的目标是通过使用3x2=6张已标记的同品种狗的图片,来识别(新的)图片是属于拉普拉多狗,圣伯纳德狗或哈巴狗。训练任务{Ti}中的某一项任务Ti可以是通过使用3x2=6的已标记的同品种狗图片中获取信息,将新图片标记为拳师狗、圣伯纳德狗或洛特维勒牧狗。元学习过程就是由一系列这样的每一次针对不同品种的狗的训练任务Ti所组成的。我们希望元学习模型“随着经验和任务数量的增长”得到不断地改进。最终,我们在T任务上评估模型。

图像样本不够用?元学习帮你解决

我们评估了拉布拉多犬、圣伯纳德犬和八哥的元学习模型,但我们只是在其他所有品种上进行训练。 

现在我们该怎么做?假设你想解决任务T(里面有拉布拉多,圣伯纳德和 八哥),那么你需要一个元训练数据集,里面有很多不同品种的狗。 你可以使用 Stanford Dogs 数据集(http://vision.stanford.edu/aditya86/ImageNetDogs),其中包含从ImageNet中提取的超过20k 只狗。我们将把这个数据集命名为D。注意,这个过程不需要包含任何拉布拉多、圣伯纳德或八哥。 

我们从D中采样了一批(如下),每集对应于一个 N-way K-shot 分类任务 Tᵢ 类似T(通常我们使用相同的N和K)。 模型解决了每一集(即标记了每一个查询集的图像)后,它的参数会更新,这通常是通过对查询集的分类不准确造成的损失进行反向跟踪来实现的。

这样,模型就可以跨任务学习准确地解决一个新的、不可见的少镜头分类任务。 标准的学习分类算法学习映射图像→标签,元学习算法通常学习映射支持集→c(.),其中c是映射查询→标签。

  元学习算法 

既然我们知道了算法元训练的含义,那么还有一个谜团:元学习模型是如何解决一个少镜头分类任务的?当然,解决方案不止一个。在这里,我们将关注最受欢迎的方案。  

  度量学习

度量学习的基本思想是学习数据点(如图像)之间的距离函数。事实证明,它对于解决少样本分类任务非常有用:度量学习算法不必在支持集(少量的带标签图像)上进行微调,而是通过将查询图像与带标签图像进行比较来对其进行分类。

图像样本不够用?元学习帮你解决

将查询图像(在右侧)与支持集的每个图像进行比较,它的标签取决于与其最接近的图像。当然,你不能逐个像素地比较图像,你要做的是在相关的特征空间中比较图像。为了清楚起见,让我们详细说明度量学习算法是如何解决少样本分类任务的(上面定义为带标签样本的支持集,以及我们要分类的查询图像集):

  1. 我们从支持集和查询集的所有图像中提取特征(通常使用卷积神经网络)。现在,我们在少样本分类任务中必须考虑的每个图像都由一个一维向量表示。

  2. 每个查询图像根据其与支持集图像的距离进行分类。对于距离函数和分类策略,可以有许多可能的设计选择。例如,欧氏距离和k-最近邻分类。

  3. 在元训练期间,在每一场景(episode)结束时,对由查询集的分类错误产生的损失值(通常是交叉熵损失)进行反向传播,从而更新CNN的参数。

每年都会提出几种度量学习算法来解决少样本图像分类问题,这其中的两个原因是:

  1. 他们凭经验可以做得很好;

  2. 唯一的限制就是你的想象力。有很多方法可以提取特性,甚至还有更多方法可以比较这些特性。我们现在将介绍一些现有的解决方案。

图像样本不够用?元学习帮你解决

匹配网络算法。对于支持集图像(左)和查询图像(底部),特征提取器是不同的。使用余弦相似性,将查询的嵌入特征与支持集中的每个图像进行比较。然后用softmax进行分类。上图来自Oriol等。

匹配网络(见上文)是第一个使用元学习的度量学习算法。在这个方法中,我们不会以同样的方式提取支持图像和查询图像的特征。来自 Google DeepMind 的Oriol Vinyals和他的团队有一个想法,即在特征提取过程中使用LSTM网络使所有图像交互。他们称之为完全上下文嵌入,因为你允许网络找到最合适的嵌入,这不仅知道要嵌入的图像,还知道支持集中的所有其他图像。这使得他们的模型比所有的图像都通过一个简单的CNN时表现得更好,但它也需要更多的时间和更大的GPU。

在最近的工作中,我们不会将查询图像与支持集中的每个图像进行比较。多伦多大学的研究人员提出了原型网络。在他们的度量学习算法中,从图像中提取特征后,我们为每个类计算一个原型。为此,他们使用类中每个图像嵌入的平均值。(但是你可以想出成千上万的方法来计算这些嵌入。为了反向传播,函数只需要是可微的。)一旦计算出原型,就可以计算查询图像到原型的欧式距离,从而对查询图像进行分类(见下图)。

图像样本不够用?元学习帮你解决

在原型网络中,我们将查询X标记为与其最接近的原型的标签。

尽管简单,但原型网络仍然可以产生最好的结果。更复杂的度量学习架构后来被开发出来,比如用神经网络来表示距离函数(而不是欧几里得距离)。这略微提高了准确性,但我相信到目前为止,原型理念在少样本图像分类的度量学习算法领域是最有价值的想法(如果你不同意,请留下评论)。

  模型无关的元学习

我们将以模型无关的元学习(MAML)结束这篇综述,MAML是目前最优雅和最有潜力的元学习算法之一。它基本上是最纯粹的元学习形式,通过神经网络进行两级反向传播。

该算法的核心思想是训练一个神经网络,使其能够仅用少量样本就能快速适应新的分类任务。下图将展示MAML如何在元训练的一个场景(即,从数据集D中采样得到的少样本分类任务Tᵢ)中工作的。假设你有一个用?参数化的神经网络M:

图像样本不够用?元学习帮你解决

用?参数化的MAML模型的元训练步骤:

  1. 创建M的副本(此处命名为f),并用?对其进行初始化(在图中,?₀=?)。

  2. 快速微调支持集上的f(只有少量梯度下降)。

  3. 在查询集上应用微调过的f。

  4. 在整个过程中,对分类错误造成的损失进行反向传播,并更新?。

然后,在下一场景中,我们创建一个更新后模型M的副本,我们在新的少样本分类任务上运行该过程,依此类推。

在元训练期间,MAML学习初始化参数,这些参数允许模型快速有效地适应新的少样本任务,其中这个任务有着新的、未知的类别。

公平地说,MAML目前在流行的少样本图像分类基准测试中的效果不如度量学习算法。由于训练分为两个层次,模型的训练难度很大,因此超参数搜索更为复杂。此外,元的反向传播意味着需要计算梯度的梯度,因此你必须使用近似值来在标准GPU上进行训练。出于这些原因,你可能更愿意在家里或工作中为你的项目使用度量学习算法。

但是,模型无关的元学习之所以如此令人兴奋,是因为它的模型是不可知的。这意味着它几乎可以应用于任何神经网络,适用于任何任务。掌握MAML意味着只需少量样本就能够训练任何神经网络以快速适应新的任务。MAML的作者Chelsea Finn和Sergey Levine将其应用于有监督的少样本分类,监督回归和强化学习。但是通过想象和努力研究,你可以用它把任何一个神经网络转换成少样本有效的神经网络!

这就是这次在元学习这个令人兴奋的世界里的旅行。少样本学习最近引起了计算机视觉研究的广泛关注,因此该领域的发展非常迅速(如果你在2020年阅读这篇文章,我建议你寻找更新的信息来源)。谁知道未来几年,神经网络会变得有多好,是否能一眼就学习到视觉概念?

这篇关于用少量样本就能训练神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/823378

相关文章

图神经网络模型介绍(1)

我们将图神经网络分为基于谱域的模型和基于空域的模型,并按照发展顺序详解每个类别中的重要模型。 1.1基于谱域的图神经网络         谱域上的图卷积在图学习迈向深度学习的发展历程中起到了关键的作用。本节主要介绍三个具有代表性的谱域图神经网络:谱图卷积网络、切比雪夫网络和图卷积网络。 (1)谱图卷积网络 卷积定理:函数卷积的傅里叶变换是函数傅里叶变换的乘积,即F{f*g}

MiniGPT-3D, 首个高效的3D点云大语言模型,仅需一张RTX3090显卡,训练一天时间,已开源

项目主页:https://tangyuan96.github.io/minigpt_3d_project_page/ 代码:https://github.com/TangYuan96/MiniGPT-3D 论文:https://arxiv.org/pdf/2405.01413 MiniGPT-3D在多个任务上取得了SoTA,被ACM MM2024接收,只拥有47.8M的可训练参数,在一张RTX

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering)

Spark MLlib模型训练—聚类算法 PIC(Power Iteration Clustering) Power Iteration Clustering (PIC) 是一种基于图的聚类算法,用于在大规模数据集上进行高效的社区检测。PIC 算法的核心思想是通过迭代图的幂运算来发现数据中的潜在簇。该算法适用于处理大规模图数据,特别是在社交网络分析、推荐系统和生物信息学等领域具有广泛应用。Spa

SigLIP——采用sigmoid损失的图文预训练方式

SigLIP——采用sigmoid损失的图文预训练方式 FesianXu 20240825 at Wechat Search Team 前言 CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录

Detectorn2预训练模型复现:数据准备、训练命令、日志分析与输出目录 在深度学习项目中,目标检测是一项重要的任务。本文将详细介绍如何使用Detectron2进行目标检测模型的复现训练,涵盖训练数据准备、训练命令、训练日志分析、训练指标以及训练输出目录的各个文件及其作用。特别地,我们将演示在训练过程中出现中断后,如何使用 resume 功能继续训练,并将我们复现的模型与Model Zoo中的

机器学习之监督学习(三)神经网络

机器学习之监督学习(三)神经网络基础 0. 文章传送1. 深度学习 Deep Learning深度学习的关键特点深度学习VS传统机器学习 2. 生物神经网络 Biological Neural Network3. 神经网络模型基本结构模块一:TensorFlow搭建神经网络 4. 反向传播梯度下降 Back Propagation Gradient Descent模块二:激活函数 activ

多云架构下大模型训练的存储稳定性探索

一、多云架构与大模型训练的融合 (一)多云架构的优势与挑战 多云架构为大模型训练带来了诸多优势。首先,资源灵活性显著提高,不同的云平台可以提供不同类型的计算资源和存储服务,满足大模型训练在不同阶段的需求。例如,某些云平台可能在 GPU 计算资源上具有优势,而另一些则在存储成本或性能上表现出色,企业可以根据实际情况进行选择和组合。其次,扩展性得以增强,当大模型的规模不断扩大时,单一云平

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04

基于深度学习 卷积神经网络resnext50的中医舌苔分类系统

项目概述 本项目旨在通过深度学习技术,特别是利用卷积神经网络(Convolutional Neural Networks, CNNs)中的ResNeXt50架构,实现对中医舌象图像的自动分类。该系统不仅能够识别不同的舌苔类型,还能够在PyQt5框架下提供一个直观的图形用户界面(GUI),使得医生或患者能够方便地上传舌象照片并获取分析结果。 技术栈 深度学习框架:采用PyTorch或其他

图神经网络(2)预备知识

1. 图的基本概念         对于接触过数据结构和算法的读者来说,图并不是一个陌生的概念。一个图由一些顶点也称为节点和连接这些顶点的边组成。给定一个图G=(V,E),  其 中V={V1,V2,…,Vn}  是一个具有 n 个顶点的集合。 1.1邻接矩阵         我们用邻接矩阵A∈Rn×n表示顶点之间的连接关系。 如果顶点 vi和vj之间有连接,就表示(vi,vj)  组成了