【论文阅读】Virtual Adversarial Training: a Regularization Method for SL and SSL

本文主要是介绍【论文阅读】Virtual Adversarial Training: a Regularization Method for SL and SSL,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

《Virtual Adversarial Training: a Regularization Method for Supervised and Semi-supervised Learning》

1. 摘要

We propose a new regularization method based on virtual adversarial loss: a new measure of local smoothness of the
output distribution. Virtual adversarial loss is defined as the robustness of the model’s posterior distribution against local perturbation
around each input data point. Our method is similar to adversarial training, but differs from adversarial training in that it determines the
adversarial direction based only on the output distribution and that it is applicable to a semi-supervised setting. Because the directions
in which we smooth the model are virtually adversarial, we call our method virtual adversarial training (VAT).

我们提出了一种基于虚拟对抗损失的新正则化方法:一种新的输出分布局部平滑度度量。虚拟对抗损失被定义为模型的后验分布对每个输入数据点周围的局部扰动的鲁棒性。我们的方法类似于对抗训练,但与对抗训练的不同之处在于它仅根据输出分布确定对抗方向,并且适用于 半监督 设置。因为我们平滑模型的方向是虚拟对抗性的,所以我们将我们的方法称为虚拟对抗性训练 (VAT)。

2. 前置知识

这里我想在宏观上理清现在论文的思路。一致性正则化(Consistency Regularization),我们可以理解为如果对一个未标记的数据应用实际的扰动,则预测不应发生显著变化。实际上,一致性正则化是让拟合的函数平滑,在函数的邻域(扰动)内函数值值不会有太大的波动。其中,针对扰动就有不同的方向,对输入值的扰动(LadderNet),对模型的扰动(MeanTeacher)。其中,这两个方向不是独立的,存在在一个方法存在两个的情况。比如说, Π \Pi Π Model通过对输入值增强,可以相当于是一次对输入值的扰动,而在模型本身存在Dropout,所以也存在对模型的扰动。这其中,还有一个很特立独行的方向。我们上面谈到了扰动,其中扰动是有好坏的,也就是说对输出值的影响程度。比如说,对于输入值 X i X_i Xi,模型的输出值是 y i ^ \hat{y_i} yi^,但是现在当我们对输入值进行一次扰动时, X i ~ = X i + η i \tilde{X_i} = X_i + \eta_i Xi~=Xi+ηi,模型输出变成了 y i ~ \tilde{y_i} yi~。我们期望的是模型对于两次输出是一样的,或者说相差不大。值得注意的是,这其实与所加扰动是相关的,扰动有好坏之分。也就是论文中提到的扰动是有方向性的。而对输入添加这种优质的扰动之后,模型输出显著变化,这种就是对抗样本。我们使用对抗样本进行训练,就变成了对抗训练。

前面的理解是从论文中的思路出发的。我其实还有一个标新立异的方向。相比于监督学习,半监督学习想要利用没有标签的数据,其中有种思路就是在优化目标种加入无监督损失。怎么定义这个损失?,这些论文其实是在回答这个问题。这个其实和强化学习种的策略梯度(Policy Gradient)有异曲同工之妙。

3. 符号系统

符号含义备注
x ∈ R I x \in R^I xRI输入向量
y ∈ Q y \in Q yQ向量 x x x对应的标签
D l = { x l ( n ) , y l ( n ) ∣ n = 1 , … , N l } \mathcal{D}_l = \{x_l^{(n)}, y_l^{(n)} | n = 1, \dots, N_l \} Dl={xl(n),yl(n)n=1,,Nl}有标签数据
D u l = { x u l ( n ′ ) , y u l ( n ′ ) ∣ n ′ = 1 , … , N u l } \mathcal{D}_{ul} = \{x_{ul}^{(n')}, y_{ul}^{(n')} | n' = 1, \dots, N_{ul} \} Dul={xul(n),yul(n)n=1,,Nul}无标签数据
θ \theta θ p ( y ∣ x , θ ) p(y|x,\theta) p(yx,θ)输入映射输出的分布
θ ^ \hat{\theta} θ^模型拟合的分布参数

4. 方法

4.1 Adversarial Training

这个概念是出自大佬的论文《EXPLAINING AND HARNESSING ADVERSARIAL EXAMPLES》。

L a d v ( x l , y l , θ ) : = D [ h ( y l ) , p ( y ∣ x l + r a d v , θ ) ] (1) L_{adv}(x_l,y_l, \theta):=D[h(y_l), p(y|x_l+r_{adv},\theta)] \tag{1} Ladv(xl,yl,θ):=D[h(yl),p(yxl+radv,θ)](1)
where r a d v : = arg max ⁡ r ; ∥ r ∥ ≤ ϵ D [ h ( y l ) , p ( y ∣ x l + r , θ ) ] (2) r_{adv}:=\argmax_{r;\|r\|\leq \epsilon}D[h(y_l), p(y|x_l+r,\theta)] \tag{2} radv:=r;rϵargmaxD[h(yl),p(yxl+r,θ)](2)

体会: D D D可以理解为损失计算函数。式子1,2总体思想是使用对模型扰动最大的扰动对模型进行训练,这个就被称为是对抗训练

上面的式子有个问题,实际问题中式子2很难求解。推出式子3,线性逼近。

r a d v ≈ ϵ g ∥ g ∥ 2 , where g = ∇ x l D [ h ( y l ) , p ( y ∣ x l , θ ) ] (3) r_{adv}\approx \epsilon\frac{g}{\|g\|_2}, \text{where}\, g= \nabla_{x_l}D[h(y_l), p(y|x_l,\theta)] \tag{3} radvϵg2g,whereg=xlD[h(yl),p(yxl,θ)](3)
这个理解为普通的反向传播就很好理解了, ϵ \epsilon ϵ理解为学习率。

这里还有一个变化,就是当式子3中的范数变成 L ∞ L_{\infty} L
r a d v ≈ ϵ sign ( g ) r_{adv} \approx \epsilon \text{sign}(g) radvϵsign(g)
Notice: 这里的 sign \text{sign} sign不太确定是不是一个已有的函数。

4.2 Virtual Adversarial Training

相比于Adversarial TrainingVirtual Adversarial Training(下文称VAT)意在接触对数据真实标签的依赖。Loss Function可以被写为:

D [ q ( y ∣ x ∗ ) , p ( y ∣ x ∗ + r q a d v , θ ) ] (4) D[q(y|x_∗), p(y|x_*+r_{qadv},\theta)] \tag{4} D[q(yx),p(yx+rqadv,θ)](4)
where r q a d v : = arg max ⁡ r ; ∥ r ∥ ≤ ϵ D [ q ( y ∣ x ∗ ) , p ( y ∣ x ∗ + r , θ ) ] (5) r_{qadv}:=\argmax_{r;\|r\|\leq \epsilon}D[q(y|x_∗), p(y|x_*+r,\theta)] \tag{5} rqadv:=r;rϵargmaxD[q(yx),p(yx+r,θ)](5)

其中, q ( y ∣ x ∗ ) q(y|x_∗) q(yx)在本文中用模型的本次预测替代,也就是说 q ( y ∣ x ∗ ) = p ( y ∣ x ∗ ) q(y|x_∗) = p(y|x_*) q(yx)=p(yx)

L D S ( X ∗ , θ ) : = D [ p ( y ∣ x ∗ , θ ^ ) , p ( y ∣ x ∗ + r v a d v , θ ) ] (6) LDS(X_*, \theta) :=D[p(y|x_∗, \hat{\theta}), p(y|x_*+r_{vadv},\theta)] \tag{6} LDS(X,θ):=D[p(yx,θ^),p(yx+rvadv,θ)](6)
where r v a d v : = arg max ⁡ r ; ∥ r ∥ ≤ ϵ D [ p ( y ∣ x ∗ , θ ^ ) , p ( y ∣ x ∗ + r , θ ) ] (7) r_{vadv}:=\argmax_{r;\|r\|\leq \epsilon}D[p(y|x_∗, \hat{\theta}), p(y|x_*+r,\theta)] \tag{7} rvadv:=r;rϵargmaxD[p(yx,θ^),p(yx+r,θ)](7)

Notice: 这里 θ ^ \hat{\theta} θ^在原文中没有明说,我理解的是模型当前的参数,应该涉及到后面的反向传播,这里区分是相当于一个常数,不参与反向更新。

最后是整个Loss:
ℓ ( D l , θ ) + α R v a d v ( D l , D u l , θ ) (8) \mathcal{\ell}(\mathcal{D}_l,\theta)+\alpha\mathcal{R}_{vadv}(\mathcal{D}_l,\mathcal{D}_{ul} ,\theta)\tag{8} (Dl,θ)+αRvadv(Dl,Dul,θ)(8)
where
R v a d v ( D l , D u l , θ ) : = 1 N l + N u l ∑ x ∗ ∈ D l , D u l L D S ( X ∗ , θ ) (9) \mathcal{R}_{vadv}(\mathcal{D}_l,\mathcal{D}_{ul} ,\theta):=\frac{1}{N_l+N_{ul}}\sum_{x_*\in\mathcal{D}_l,\mathcal{D}_{ul}}LDS(X_*, \theta)\tag{9} Rvadv(Dl,Dul,θ):=Nl+Nul1xDl,DulLDS(X,θ)(9)

等式5中,关于 r r r的计算在实际算法中不能这样暴力求解,所以本次论文中关于这个问题进行了推导,我这里就不重述了,直接说结论,有兴趣的可以看论文。

r v a d v ≈ ϵ g ∥ g ∥ 2 , where g = ∇ r D [ p ( y ∣ x , θ ^ ) , p ( y ∣ x + r , θ ^ ) ] ∣ r = ϵ d (3) r_{vadv}\approx \epsilon\frac{g}{\|g\|_2}, \text{where}\, g= \nabla_{r}D[p(y|x,\hat{\theta}), p(y|x+r,\hat{\theta})]|_{r=\epsilon d} \tag{3} rvadvϵg2g,whereg=rD[p(yx,θ^),p(yx+r,θ^)]r=ϵd(3)

在这里插入图片描述

5. 代码

# D, 计算两个分布的KL散度
def kl_div_with_logit(q_logit, p_logit):q = F.softmax(q_logit, dim=1)logq = F.log_softmax(q_logit, dim=1)logp = F.log_softmax(p_logit, dim=1)qlogq = ( q *logq).sum(dim=1).mean(dim=0)qlogp = ( q *logp).sum(dim=1).mean(dim=0)return qlogq - qlogp# 求解单位向量 $\frac{g}{\|g\|_2}$
def _l2_normalize(d):d = d.numpy()d /= (np.sqrt(np.sum(d ** 2, axis=(1, 2, 3))).reshape((-1, 1, 1, 1)) + 1e-16)return torch.from_numpy(d)def vat_loss(model, ul_x, ul_y, xi=1e-6, eps=2.5, num_iters=1):# 一个正态分布的随机变量, size=ul_x.size()d = torch.Tensor(ul_x.size()).normal_()# find r_adv, num_iters在原文中等于一就很好了for i in range(num_iters):d = xi *_l2_normalize(d)d = Variable(d.cuda(), requires_grad=True)y_hat = model(ul_x + d)delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)delta_kl.backward()# 值得注意的是,这里仅仅是执行了backward计算反向梯度,没有执行step反向更新。# 获得d的梯度d = d.grad.data.clone().cpu()# 清空模型的梯度model.zero_grad()# 获得d的单位向量d = _l2_normalize(d)d = Variable(d.cuda())r_adv = eps * d# compute ldsy_hat = model(ul_x + r_adv.detach())delta_kl = kl_div_with_logit(ul_y.detach(), y_hat)return delta_kl

这篇关于【论文阅读】Virtual Adversarial Training: a Regularization Method for SL and SSL的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/385077

相关文章

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟 开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚 第一站:海量资源,应有尽有 走进“智听

AI hospital 论文Idea

一、Benchmarking Large Language Models on Communicative Medical Coaching: A Dataset and a Novel System论文地址含代码 大多数现有模型和工具主要迎合以患者为中心的服务。这项工作深入探讨了LLMs在提高医疗专业人员的沟通能力。目标是构建一个模拟实践环境,人类医生(即医学学习者)可以在其中与患者代理进行医学

2014 Multi-University Training Contest 8小记

1002 计算几何 最大的速度才可能拥有无限的面积。 最大的速度的点 求凸包, 凸包上的点( 注意不是端点 ) 才拥有无限的面积 注意 :  凸包上如果有重点则不满足。 另外最大的速度为0也不行的。 int cmp(double x){if(fabs(x) < 1e-8) return 0 ;if(x > 0) return 1 ;return -1 ;}struct poin

2014 Multi-University Training Contest 7小记

1003   数学 , 先暴力再解方程。 在b进制下是个2 , 3 位数的 大概是10000进制以上 。这部分解方程 2-10000 直接暴力 typedef long long LL ;LL n ;int ok(int b){LL m = n ;int c ;while(m){c = m % b ;if(c == 3 || c == 4 || c == 5 ||

2014 Multi-University Training Contest 6小记

1003  贪心 对于111...10....000 这样的序列,  a 为1的个数,b为0的个数,易得当 x= a / (a + b) 时 f最小。 讲串分成若干段  1..10..0   ,  1..10..0 ,  要满足x非递减 。  对于 xi > xi+1  这样的合并 即可。 const int maxn = 100008 ;struct Node{int

论文翻译:arxiv-2024 Benchmark Data Contamination of Large Language Models: A Survey

Benchmark Data Contamination of Large Language Models: A Survey https://arxiv.org/abs/2406.04244 大规模语言模型的基准数据污染:一项综述 文章目录 大规模语言模型的基准数据污染:一项综述摘要1 引言 摘要 大规模语言模型(LLMs),如GPT-4、Claude-3和Gemini的快

论文阅读笔记: Segment Anything

文章目录 Segment Anything摘要引言任务模型数据引擎数据集负责任的人工智能 Segment Anything Model图像编码器提示编码器mask解码器解决歧义损失和训练 Segment Anything 论文地址: https://arxiv.org/abs/2304.02643 代码地址:https://github.com/facebookresear

模版方法模式template method

学习笔记,原文链接 https://refactoringguru.cn/design-patterns/template-method 超类中定义了一个算法的框架, 允许子类在不修改结构的情况下重写算法的特定步骤。 上层接口有默认实现的方法和子类需要自己实现的方法

消除安卓SDK更新时的“https://dl-ssl.google.com refused”异常的方法

消除安卓SDK更新时的“https://dl-ssl.google.com refused”异常的方法   消除安卓SDK更新时的“https://dl-ssl.google.com refused”异常的方法 [转载]原地址:http://blog.csdn.net/x605940745/article/details/17911115 消除SDK更新时的“

论文翻译:ICLR-2024 PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS

PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS https://openreview.net/forum?id=KS8mIvetg2 验证测试集污染在黑盒语言模型中 文章目录 验证测试集污染在黑盒语言模型中摘要1 引言 摘要 大型语言模型是在大量互联网数据上训练的,这引发了人们的担忧和猜测,即它们可能已