迁移学习之领域自适应(domain adaptation)

2024-09-01 01:36

本文主要是介绍迁移学习之领域自适应(domain adaptation),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

        比如有一堆有标注的训练数据,这些数 据来自源领域,用这些数据训练出一个模型,这个模型可以用在不一样的领域。在训练的时 候,我们必须要对测试数据所在的目标领域有一些了解。

        随着了解的程度不同,领域自适应的方法也不同。如果目标领域上有一大堆有标签的数 据,这种情况其实不需要做领域自适应,直接用目标领域的数据训练。如果目标领域上有一点 有标签的数据,这种情况可以用领域自适应,可以用这些有标注的数据微调在源领域上训练 出来的模型。这边的微调跟BERT的微调很像,已经有一个在源领域上训练好的模型,只要 拿目标领域的数据跑个两、三个回合就足够了。在这一种情况下,需要注意的问题是,因为目 标领域的数据量非常少,所以要小心不要过拟合,不要在目标领域的数据上迭代太多次。在目 标数据上迭代太多次,可能会过拟合到目标领域的少量数据上,在真正的测试集的表现可能 就不好。(为了避免过拟合的情况,有很多的解决方法,比如调小一点学习率。要让微调前、后的 模型的参数不要差很多,或者让微调前、后的模型的输入跟输出的关系,不要差很多等 等。)

        下面主要介绍下在目标领域上有大量未标注的数据的这种情况。这种情况其实是很符合 实际会发生的情况。比如在实验室里面训练了一个模型,并想要把它用在真实的场景里面,于 是将模型上线。上线后的模型确实有一些人来用,但得到的反馈很差,大家嫌弃系统正确率很 低。这种情况就可以用领域自适应的技术,因为系统已经上线后会有人使用,就可以收集到一 大堆未标注的数据。这些未标注的数据可以用在源领域上训练一个模型,并用在目标领域。

        最基本的想法如图1 所示,训练一个特征提取器(feature extractor)。特征提取器也 是一个网络,这个网络输入是一张图片,输出是一个特征向量。虽然源领域与目标领域的图像 不一样,但是特征提取器会把它们不一样的部分去除,只提取出它们共同的部分。虽然源领域 和目标领域的图片的颜色不同,但特征提取器可以学习到把颜色的信息滤掉,忽略颜色。源领 域和目标领域的图片通过特征提取器以后,其得到的特征是没有差异的,分布相同。通过特征提取器可以在源领域上训练一个模型,直接用在目标领域上。通过领域对抗训练(domain adversarial training)可以得到领域无关的表示。

图1 通过特征提取器过滤颜色信息

        一般的分类器可分成特征提取器和标签预测器(labelpredictor)两个部分。图像的分类 器输入一张图像,输出分类的结果。假设图像的分类器有10层,前5层是特征提取器,后5 层是标签预测器。前5层可看成特征提取器,一个图像通过前5层,其输出是一个向量;如 果使用卷积神经网络,其输出是特征映射,但特征映射“拉直”也可以看做是一个向量,该向量 再输入到后面5层(标签预测器)来产生类别。

        图2 给出了特征提取器和标签预测器的训练过程。对于源领域上标注的数据,把源领 域的数据“丢”进去,这跟训练一个一般的分类器一样,它通过特征提取器,再通过标签预测器, 可以产生正确的答案。但不一样的地方是,目标领域的一堆数据是没有任何标注的,把这些图 片“丢”到图像分类器,把特征提取器的输出拿出来看,希望源领域的图片“丢”进去的特征跟目 标领域的图片“丢”进去的特征相同。图2 中蓝色的点表示源领域图片的特征,红色的点表 示目标领域图片的特征,通过领域对抗训练让蓝色的点跟红色的点分不出差异。

图2 训练特征提取器,让源领域和目标领域的特征无差异

        如图3 所示,我们要训练一个领域分类器。领域分类器是一个二元的分类器,其输入 是特征提取器输出的向量,其目标是判断这个向量是来自于源领域还是目标领域,而特征提 取器学习的目标是要去想办法骗过领域分类器。领域对抗训练非常像是生成对抗网络,特征 提取器可看成生成器,领域分类器可看成判别器。但在领域对抗训练里面,特征提取器优势太 大了,其要骗过领域分类器很容易。比如特征提取器可以忽略输入,永远都输出一个零向量。 这样做领域分类器的输入都是零向量,其也无法判断该向量的领域。但标签预测器也需要特 征判断输入的图片的类别,如果特征提取器只会输出零向量,标签预测器无法判断是哪一张 图片。特征提取器还是需要产生向量来让标签预测器可以输出正确的预测。因此特征提取器 不能永远都输出零向量。

图3 领域对抗训练

        假设标签预测器的参数为θp,领域分类器的参数为θd,特征提取器的参数为θf。源领域的图像是有标签的,所以可以计算它们的交叉熵来得出损失L。领域分类器要想办法判断 图片是源领域还是目标领域,这就是一个二元分类的问题,该分类问题的损失为Ld。我们要 去找一个θp,它可以让L越小越好,即

我们要去找一个θd,它可以让这个Ld 越小越好,即 

        标签预测器要让源领域的图像分类越正确越好,领域分类器要让领域的分类越正确越好。 而特征提取器站在标签预测器这边,它要去做领域分类器相反的事情,所以特征提取器的损 失是标签预测器的损失L减掉领域分类器的损失Ld,所以特征提取器的损失是L−Ld,找 一组参数θf 让L−Ld 的值越小越好,即 

        假设领域分类器的工作是把源领域跟目标领域分开,根据特征提取器的特征,来判断数据是 来自源领域还是目标领域,把源领域和目标领域的两组特征分开。而特征提取器的损失中是 −Ld,这意味着它要做的事情跟领域分类器相反。如果领域分类器根据某张图片的特征判断这 张图片属于源领域,而特征提取器要让领域分类器根据这张图片的特征判断这张图片属于目 标领域,这样做也就可以分开源领域和目标领域的特征。本来领域分类器是让Ld的值越小越 好,特征提取器要让Ld的值越大越好,其目的都是分开源领域跟目标领域的特征。以上是最 原始的领域对抗训练做法。 

        领域对抗训练最原始的论文做了如图4所示的四个从源领域到目标领域的任务。如果 用目标领域的图片训练,目标领域的图片测试,结果如表 1所示,每一个任务正确率都是 90% 以上。但如果用源领域训练,目标领域测试,结果比较差。如果使用领域对抗训练,正 确率会有明显的提升。

图4 领域对抗训练最原始论文使用的任务

 

表1 不同源领域和目标领域的数字图像分类的准确率

        领域对抗训练最早的论文发表在2015 年的 ICML 上面,其比生成对抗网络还要稍微晚 一点,不过它们几乎可以是同时期的作品。

        刚才这整套想法,有一个小小的问题。用蓝色的圆圈和三角形表示源领域上的两个类别, 用正方形来表示目标领域上无类别标签的数据。可以找一个边界去把源领域上的两个类别分 开。训练的目标是要让正方形的分布跟圆圈、三角形合起来的分布越接近越好。在图5(a) 所示的情况中,红色的点跟蓝色的点是挺对齐在一起的。在图5 (b)所示的情况中,红 色的点跟蓝色的点是分布挺接近的。虽然正方形的类别是未知的,但蓝色的圆圈跟蓝色的三 角形的决策边界是已知的,应该让正方形远离决策边界。因此两种情况相比,我们更希望在 图5 (b)的情况发生,而避免让在图5(a)的状况发生。

图5 决策边界

        让正方形远离边界(boundary)最简单的做法如图13.8所示。把很多无标注的图片先“丢” 到特征提取器,再“丢”到标签预测器,如果输出的结果集中在某个类别上,就是离边界远;如 果输出的结果每一个类别非常地接近,就是离边界近。除了上述比较的简单的方法外,还可以 使用DIRT-T、最大分类器差异(maximum classifier discrepancy)等方法。这些方法在 领域自适应中是不可或缺的。

图6 离边界越远越好

        目前为止都假设源领域跟目标领域的类别都要是一模一样,比如图像分类,源领域有老 虎、狮子跟狗,目标领域也应该要有老虎、狮子跟狗,但实际上目标领域是没有标签的,其里面的类别是未知的。如图7所示,实线的椭圆圈代表源领域里面有的东西,虚线的椭圆圈 代表目标领域里面有的东西。图7(a)中源领域里面的东西比较多,目标领域里面的东西 比较少;图7(b)中源领域里面的东西比较少,目标领域的东西比较多。图7(c)中两 者虽然有交集,但是各自都有独特的类别。

图 7 强制完全对齐源领域跟目标领域的问题

        强制把源领域跟目标领域完全对齐在一起是有问题的,比如图13.9(c)里面,要让源领 域的数据跟目标领域的数据的特征完全匹配,这意味是要让老虎去变得跟狗像,或者让老虎 变得跟狮子像,这样老虎这个类别就不能区分了。源领域跟目标领域有不同标签问题的解决 方法,可参考论文“Universal Domain Adaptation”。

        但是有一个可能是目标领域的数据不仅没有标签,而且还很少,比如目标领域只有一张 图片,也就无法跟源领域对齐。这种情况可使用测试时训练(TestingTime Training,TTT) 方法,读者可参考论文“Test-Time Training with Self-Supervision for Generalization under Distribution Shifts”。

        如果特征提取器是卷积神经网络,而不是线性层(linear layer)。领域分类器输入 是特征映射,特征映射本来就有空间的关系。把两个领域“拉”在一起会不会有影响隐空 间(latent space),让隐空间没能学到本来希望它学到的东西? A:会有影响。领域自适应训练需要同时做好两个方面的事:一方面要骗领域分类器, 另一方面是要让分类变正确。即不仅要把两个领域对齐在一起,还要让隐空间的分布 是正确的。比如我们觉得1跟7比较像,为了要让分类器做好,特征提取器会让1跟 7 比较像。因为要提高标签预测器的性能,所以隐表示(latentrepresentation)里面的 空间仍然是一个比较好的隐空间。但如果给领域分类器就是要骗过领域分类器,这件 事情的权重太大。模型就会学到只想骗过领域分类器,它就不会产生好的隐空间。

这篇关于迁移学习之领域自适应(domain adaptation)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1125571

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学

系统架构师考试学习笔记第三篇——架构设计高级知识(20)通信系统架构设计理论与实践

本章知识考点:         第20课时主要学习通信系统架构设计的理论和工作中的实践。根据新版考试大纲,本课时知识点会涉及案例分析题(25分),而在历年考试中,案例题对该部分内容的考查并不多,虽在综合知识选择题目中经常考查,但分值也不高。本课时内容侧重于对知识点的记忆和理解,按照以往的出题规律,通信系统架构设计基础知识点多来源于教材内的基础网络设备、网络架构和教材外最新时事热点技术。本课时知识

线性代数|机器学习-P36在图中找聚类

文章目录 1. 常见图结构2. 谱聚类 感觉后面几节课的内容跨越太大,需要补充太多的知识点,教授讲得内容跨越较大,一般一节课的内容是书本上的一章节内容,所以看视频比较吃力,需要先预习课本内容后才能够很好的理解教授讲解的知识点。 1. 常见图结构 假设我们有如下图结构: Adjacency Matrix:行和列表示的是节点的位置,A[i,j]表示的第 i 个节点和第 j 个

Node.js学习记录(二)

目录 一、express 1、初识express 2、安装express 3、创建并启动web服务器 4、监听 GET&POST 请求、响应内容给客户端 5、获取URL中携带的查询参数 6、获取URL中动态参数 7、静态资源托管 二、工具nodemon 三、express路由 1、express中路由 2、路由的匹配 3、路由模块化 4、路由模块添加前缀 四、中间件