【RL】Wasserstein距离-GAN背后的直觉

2023-11-21 00:40

本文主要是介绍【RL】Wasserstein距离-GAN背后的直觉,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、说明

        在本文中,我们将阅读有关Wasserstein GANs的信息。具体来说,我们将关注以下内容:i)什么是瓦瑟斯坦距离?,ii)为什么要使用它?iii) 我们如何使用它来训练 GAN?

二、Wasserstein距离概念

        Wasserstein距离,又称为Earth Mover's Distance (EMD),是衡量两个概率分布之间的差异程度的一种数学方式。它考虑了分布之间的距离和它们之间的“传输成本”。

        简单来说,Wasserstein距离将两个分布看作“堆积在地图上的土堆”,并计算将一个堆移到另一个的最小成本。这个距离度量的优点是它能够处理非均匀分布,并且能够考虑分布的形状和结构。

        Wasserstein距离在机器学习领域中应用非常广泛,特别是在生成模型中用来评估生成器生成的图像与真实图像之间的差异。

图1:学习区分两个高斯时的最优判别器和批评者[1]。

2.1 瓦瑟施泰因距离

        Wasserstein 距离(地球移动器的距离)是给定度量空间上两个概率分布之间的距离度量。直观地说,它可以被视为将一个分布转换为另一个分布所需的最小功,其中功被定义为必须移动的分布的质量和要移动的距离的乘积。在数学上,它被定义为:

方程 1:瓦瑟斯坦 分布P_r和P_g之间的距离。

        在方程1中,Π(P_r,P_g)是x和y上所有联合分布的集合,使得边际分布等于P_r和P_g。 γ(x, y)可以看作是必须从x移动到y才能将P_r转换为P_g的质量量[1]。因此,瓦瑟斯坦距离是最佳运输计划的成本。

2.2 瓦瑟斯坦距离 vs. 詹森-香农分歧

        最初的GAN目标被证明是Jensen-Shannon分歧的最小化[2]。JS背离定义为:

方程 2:P_r 和 P_g 之间的 JS 背离 P_m = (P_r + P_g)/2

        

        与JS相比,Wasserstein距离具有以下优点:

  • Wasserstein 距离是连续的,几乎可以在任何地方微分,这使我们能够训练模型达到最佳状态。
  • 随着鉴别器的变好,JS散度局部饱和,因此梯度变为零并消失。
  • Wasserstein 距离是一个有意义的度量,即当分布彼此靠近时,它收敛到 0,当它们越来越远时发散。
  • 作为目标函数的 Wasserstein 距离比使用 JS 散度更稳定。当使用Wasserstein距离作为目标函数时,模式崩溃问题也得到了缓解。

        从图 1 我们清楚地看到,最佳GAN鉴别器饱和并导致梯度消失,而优化Wasserstein距离的WGAN评论家在整个过程中具有稳定的梯度。

        有关数学证明和更详细的研究,请查看此处的论文!

三、瓦瑟斯坦·GAN

        现在可以清楚地看到,优化 Wasserstein 距离比优化 JS 散度更有意义,还需要注意的是,方程 1 中定义的 Wasserstein 距离非常棘手[3],因为我们不可能计算所有 γ ∈Π(Pr ,Pg) 的下界(最大下界)。然而,从坎托罗维奇-鲁宾斯坦二元性中,我们有,

公式3:1-利普希茨条件下的瓦瑟斯坦距离。

        这里我们有 W(P_r, P_g) 作为所有 1-Lipschitz 函数 f: X → R 的上确界(最低上限)。

        K-利普希茨连续性:给定 2 个度量空间 (X, d_X) 和 (Y, d_Y),变换函数 f: X → Y 是 K-利普希茨连续的,如果

公式3:K-Lipschitz连续性。

        其中d_X和d_Y是各自度量空间中的距离函数。当一个函数是 K-Lipschitz 时,从方程 2 开始,我们最终得到 K ∙ W(P_r, P_g)。

        现在,如果我们有一系列参数化函数 {f_w},其中 w∈W 是 K-Lipschitz 连续的,我们可以有

公式 4

即,w∈W 最大化方程 4 给出瓦瑟斯坦距离乘以一个常数。

四、WGAN评论家

        为此,WGAN引入了一个批评者,而不是我们在GAN中了解到的鉴别器。批评者网络在设计上类似于判别器网络,但通过优化找到将最大化方程 4 的 w* 来预测 Wasserstein 距离。为此,批评家的客观功能如下:

公式5:批评家客观函数。

       在这里,为了在函数f上强制执行Lipschitz连续性,作者诉诸于将权重w限制在一个紧凑的空间内。这是通过将砝码夹紧到一个小范围(论文中的[-1e-2,1e-2][1])来完成的。

鉴别器和批评者之间的区别在于,鉴别器经过训练以正确识别P_r样本和P_g样本,批评家估计P_r和P_g之间的Wasserstein距离。

这是训练批评家的python代码。

for ix in n_critic_steps:opt_critic.zero_grad()real_images = data[0].float().to(device)# * Generate imagesnoise = sample_noise()fake_images = netG(noise)# * though they are name so, they are not logits!real_logits = netCritic(real_images)fake_logits = netCritic(fake_images)# * max E_{x~P_X}[C(x)] - E_{Z~P_Z}[C(g(z))]loss = -(real_logits.mean() - fake_logits.mean())loss.backward(retain_graph=True)opt_critic.step()# * Gradient clipplingfor p in netCritic.parameters():p.data.clamp_(-self.c, self.c)

五、WGAN生成器目标

        当然,发电机的目标是最小化P_r和P_g之间的瓦瑟斯坦距离。生成器试图找到最小化P_g和P_r之间的 Wasserstein 距离的 θ*。为此,生成器的目标函数如下:

        公式 6:生成器目标函数。

        在这里,WGAN生成器和标准生成器之间的主要区别再次在于,WGAN生成器试图最小化P_r和P_g之间的Wasserstein距离,而标准生成器试图用生成的图像欺骗鉴别器。

        以下是训练生成器的 python 代码:

opt_gen.zero_grad()noise = sample_noise()fake_images = netG(noise)# again, these are not logits.
fake_logits = netCritic(fake_images)# * - E_{Z~P_Z}[C(g(z))]
loss = -fake_logits.mean().view(-1)loss.backward()
opt_gen.step()

六、培训结果

fig2:WGAN训练的早期结果[3]。

        图例.2显示了训练WGAN的一些早期结果。请注意,图 2 中的图像是早期结果,一旦确认模型按预期训练,训练就会停止。

七、代码

        Wasserstein GAN的完整实现可以在这里找到[3]。

八、结论

        WGAN提供非常稳定的培训和有意义的培训目标。本文介绍并直观地解释了什么是 Wasserstein 距离,Wasserstein 距离相对于标准 GAN 使用的 Jensen-Shannon 散度的优势,以及如何使用 Wasserstein 距离来训练 WGAN。我们还看到了用于训练 Critic 和生成器的代码片段,以及早期训练模型的大量输出。尽管WGAN比标准GAN具有许多优势,但WGAN论文的作者明确承认,权重裁剪不是执行Lipschitz连续性的最佳方法[1]。为了解决这个问题,他们提出了带有梯度惩罚的Wasserstein GAN[4],我们将在后面的文章中讨论。

        如果您喜欢这个,请查看本系列的下一篇文章,其中讨论了 WGAN-GP!

这篇关于【RL】Wasserstein距离-GAN背后的直觉的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/398583

相关文章

线性代数|机器学习-P35距离矩阵和普鲁克问题

文章目录 1. 距离矩阵2. 正交普鲁克问题3. 实例说明 1. 距离矩阵 假设有三个点 x 1 , x 2 , x 3 x_1,x_2,x_3 x1​,x2​,x3​,三个点距离如下: ∣ ∣ x 1 − x 2 ∣ ∣ 2 = 1 , ∣ ∣ x 2 − x 3 ∣ ∣ 2 = 1 , ∣ ∣ x 1 − x 3 ∣ ∣ 2 = 6 \begin{equation} ||x

生成对抗网络(GAN网络)

Generative Adversarial Nets 生成对抗网络GAN交互式可视化网站 1、GAN 基本结构 GAN 模型其实是两个网络的组合: 生成器(Generator) 负责生成模拟数据; 判别器(Discriminator) 负责判断输入的数据是真实的还是生成的。 生成器要不断优化自己生成的数据让判别网络判断不出来,判别器也要优化自己让自己判断得更准确。 二者关系形成

深度学习--对抗生成网络(GAN, Generative Adversarial Network)

对抗生成网络(GAN, Generative Adversarial Network)是一种深度学习模型,由Ian Goodfellow等人在2014年提出。GAN主要用于生成数据,通过两个神经网络相互对抗,来生成以假乱真的新数据。以下是对GAN的详细阐述,包括其概念、作用、核心要点、实现过程、代码实现和适用场景。 1. 概念 GAN由两个神经网络组成:生成器(Generator)和判别器(D

重启顺风车的背后,是高德难掩的“野心”

以史鉴今,我们往往可以从今天的事情中,看到古人的智慧,也看到时代的进步。就如西汉后期文学家恒宽曾说的,“明者因时而变,知者随事而制”。 图源来自高德官方 近日,高德就展现了这样的智慧。在网约车市场陷入饱和状态时,高德审时度势,宣布重启顺风车业务,并在全国范围内大规模启动,首批覆盖珠三角、长三角及湖北省武汉市等共计65座城市,完成在出行服务领域的又一重要布局。 重启顺风车,增量市场的“蛋糕

模拟退火求n个点到某点距离和最短

/*找出一个点使得这个店到n个点的最长距离最短,即求最小覆盖圆的半径用一个点往各个方向扩展,如果结果更优,则继续以当前步长扩展,否则缩小步长*/#include<stdio.h>#include<math.h>#include<string.h>const double pi = acos(-1.0);struct point {double x,y;}p[1010];int

黑神话:悟空》增加草地绘制距离MOD使游戏场景看起来更加广阔与自然,增强了游戏的沉浸式体验

《黑神话:悟空》增加草地绘制距离MOD为玩家提供了一种全新的视觉体验,通过扩展游戏中草地的绘制距离,增加了场景的深度和真实感。该MOD通过增加草地的绘制距离,使游戏场景看起来更加广阔与自然,增强了游戏的沉浸式体验。 增加草地绘制距离MOD安装 1、在%userprofile%AppDataLocalb1SavedConfigWindows目录下找到Engine.ini文件。 2、使用记事本编辑

黑神话悟空背后的技术揭秘与代码探秘

《重塑神话:黑神话悟空背后的技术揭秘与代码探秘》 引言 在国产游戏领域,《黑神话:悟空》无疑是一颗璀璨的明星,它不仅融合了深厚的中国文化元素,更在技术上实现了诸多突破,为玩家带来了前所未有的沉浸式体验。本文将深入剖析《黑神话:悟空》背后的关键技术,并通过代码案例展示其技术实现的魅力。 一、高精度动作捕捉技术 《黑神话:悟空》中的角色动作之所以如此逼真,得益于高精度动作捕捉技术的应用

SimD:基于相似度距离的小目标检测标签分配

摘要 https://arxiv.org/pdf/2407.02394 由于物体尺寸有限且信息不足,小物体检测正成为计算机视觉领域最具挑战性的任务之一。标签分配策略是影响物体检测精度的关键因素。尽管已经存在一些针对小物体的有效标签分配策略,但大多数策略都集中在降低对边界框的敏感性以增加正样本数量上,并且需要设置一些固定的超参数。然而,更多的正样本并不一定会带来更好的检测结果,事实上,过多的正样本

Matlab)实现HSV非等间隔量化--相似判断:欧式距离--输出图片-

%************************************************************************** %                                 图像检索——提取颜色特征 %HSV空间颜色直方图(将RGB空间转化为HS

微积分直觉:隐含微分

目录 一、介绍 二、梯子问题 三、结论 四、一个额外的例子 一、介绍         让我们想象一个半径为 5 的圆,以 xy 平面为中心。现在假设我们想在点 (3,4) 处找到一条切线到圆的斜率。         好吧,为了做到这一点,我们必须非常接近圆和切线之间的空间,并沿着该曲线迈出一小步。该步骤的 y 分量为 dy,x 分量为