WGAN 算法

2024-09-05 05:20
文章标签 算法 wgan

本文主要是介绍WGAN 算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

        因为要进行 MinMax 操作,所以 GAN 是很不好训练的。我们接下来介绍一个 GAN 训练的小技巧,就是著名的Wasserstein GAN(Wasserstein Generative Adversarial Network)。在讲这个之前,我们分析下JS散度有什么问题。首先,JS散度的两个输入PG 和Pdata 之间的重叠部分往往非常少。这个其实也是可以预料到的,我们从不同的角度来看: 图片其实是高维空间里低维的流形,因为在高维空间中随便采样一个点,它通常都没有办法 构成一个人物的头像,所以人物头像的分布,在高维的空间中其实是非常狭窄的。换个角度 解释,如果是以二维空间为例,图片的分布可能就是二维空间的一条线,也就是PG和Pdata 都是二维空间中的两条直线。而二维空间中的两条直线,除非它们刚好重合,否则它们相交 的范围是几乎可以忽略的。从另一个角度解释,我们从来都不知道PG和Pdata的具体分布, 因为其源于采样,所以也许它们是有非常小的重叠分布范围。比如采样的点不够多,就算是这 两个分布实际上很相似,也很难有任何的重叠的部分。

        所以以上的问题就会对于JS分布造成以下问题:首先,对于两个没有重叠的分布,JS散 度的值都为Log2,与具体的分布无关。就算两个分布都是直线,但是它们的距离不一样,得 到的JS 散度的值就都会是Log2,如图1 所示。所以JS散度的值并不能很好地反映两个 分布的差异。另外,对于两个有重叠的分布,JS散度的值也不一定能够很好地反映两个分布 的差异。因为JS散度的值是有上限的,所以当两个分布的重叠部分很大时,JS散度不好区 分不同分布间的差异。所以既然从JS散度中,看不出来分布的差异。那么在训练的时候,我 们就很难知道我们的生成器有没有在变好,我们也很难知道我们的判别器有没有在变好。所 以我们需要一个更好的衡量两个分布差异的指标。

图1 JS 散度的局限性

        我们从更直观的实际操作角度来说明,当使用JS散度训练一个二分类的分类器,来去分 辨真实和生成的图片时,会发现实际上正确率几乎都是100%。原因在于采样的图片根本就没 几张,对于判别器来说,采样256张真实图片和256张假的图片它直接用硬背的方法都可以 把这两组图片分开。所以实际上如果用二分类的分类器训练判别器下去,识别正确率都会是 100%。根本就没有办法看出生成器有没有越来越好。所以过去尤其是在还没有WGAN这样 的技术时,训练GAN真的就很像盲盒。根本就不知道训练的时候,模型有没有越来越好,所 以旧时的做法是每次更新几次生成器后,就需要把图片打印可视化出来看。然后就要一边吃 饭,一边看图片生成的结果,然后跑一跑就发现内存报错了就需要重新再来,所以过去训练 GAN 真的是很辛苦的。这也不像我们在训练一般的网络的时候,有个损失函数,然后那个损失值随着训练的过程慢慢变小,当我们看到损失值慢慢变小时,我们就放心网络有在训练。但 是对于GAN 而言,我们根本就没有这样的指标,所以我们需要一个更好的衡量两个分布差 异的指标。否则只能够用人眼看,用人眼守在电脑前面看,发现结果不好,就重新用一组超参 数调一下网络。

        既然是JS 散度的问题,肯定有人就想问说会不会换一个衡量两个分布相似程度的方式, 就可以解决这个问题了呢?是的,于是就有了Wasserstein,或使用Wasserstein 距离的想法。 Wasserstein 距离的想法如下,假设两个分布分别为 P 和 Q,我们想要知道这两个分布的差 异,我们可以想像有一个推土机,它可以把P 这边的土堆挪到Q这边,那么推土机平均走的 距离就是Wasserstein 距离。在这个例子里面,我们假设P 集中在一个点,Q集中在一个点, 对推土机而言,假设它要把P 这边的土挪到Q这边,那它要平均走的距离就是D,那么P 和Q的Wasserstein 距离就是 D。但是如果 P 和 Q 的分布不是集中在一个点,而是分布在 一个区域,那么我们就要考虑所有的可能性,也就是所有的推土机的走法,然后看平均走的 距离是多少,这个平均走的距离就是Wasserstein 距离。Wasserstein 距离可以想象为有一个 推土机在推土,所以Wasserstein 距离也称为推土机距离(Earth Mover’s Distance,EMD)。 所以Wasserstein 距离的定义如图2 所示。

图2 Wasserstein 距离的定义

        但是如果是更复杂的分布,要算Wasserstein 距离就有点困难了,如图 3 所示。假设 两个分布分别是P 和Q,我们要把P 变成Q,那有什么样的做法呢?我们可以把P 的土 搬到Q来,也可以反过来把Q的土搬到P。所以当我们考虑比较复杂分布的时候,两种分 布计算距离就有很多不同的方法,即不同的“移动”方式,从中计算算出来的距离,即推土机平 均走的距离就不一样。对于左边这个例子,推土机平均走的距离比较少;右边这个例子因为 舍近求远,所以推土机平均走的距离比较大。那两个分布P 和Q的Wasserstein 距离会有 很多不同的值吗?这样的话,我们就不知道到底要用哪一个值来当作是Wasserstein 距离了。 为了让Wasserstein 距离只有一个值,我们将距离定义为穷举所有的“移动”方式,然后看哪一 个推土的方法可以让平均的距离最小。那个最小的值才是Wasserstein距离。所以其实要计算 Wasserstein 距离挺麻烦的,因为里面还要解一个优化问题。

图 3 Wasserstein 距离的可视化理解

        我们这里先避开这个问题,先来看看Wasserstein距离有什么好处,如图4所示。假设 两个分布PG 和Pdata 它们的距离是 d0,那在这个例子中,Wasserstein 距离算出来就是 d0。 同样的,假设两个分布PG和Pdata 它们的距离是d1,那在这个例子中,Wasserstein 距离算 出来就是d1。假设d1小于d0,那d1 的Wasserstein距离就会小于d0。所以Wasserstein距离 可以很好地反映两个分布的差异。从左到右我们的生成器越来越进步,但是如果同时观察判别 器,你会发现你观察不到任何规律。因为对于判别器而言,每一个例子算出来的JS散度,都是 一样的Log2,所以判别器根本就看不出来这边的分布有没有变好。但是如果换成Wasserstein 距离,由左向右的时候我们会知道,我们的生成器做得越来越好。所以我们的Wasserstein距 离越小,对应的生成器就越好。这就是为什么我们要用Wasserstein距离的原因,我们换一个 计算差异的方式,就可以解决JS距离有可能带来的问题。

图 4 Wasserstein 距离与 JS 距离的对比

        可以再举一个演化的例子——人类眼睛的生成。人类的眼睛是非常复杂的,它是由其他 原始的眼睛演化而来的。比如说有一些细胞具备有感光的能力,这可以看做是最原始的眼睛。 那这些最原始的眼睛怎么变成最复杂的眼睛呢?它只是一些感光的细胞在皮肤上经过一系列 的突变产生更多的感光细胞,中间有很多连续的步骤。举例来说,感光的细胞可能会出现在 一个比较凹陷的地方,皮肤凹陷下去,这样感光细胞可以接受来自不同方向的光源。然后慢 慢地把凹陷的地方保护住并在里面放一些液体,最后就变成了人的眼睛。所以这个过程是一 个连续的过程,是一个从简单到复杂的过程。当使用WGAN时,使用Wasserstein距离来衡 量分布间的偏差的时候,其实就制造了类似的效果。本来两个分布PG0 和Pdata 距离非常遥 远,你要它一步从开始就直接跳到结尾,这是非常困难的。但是如果用Wasserstein距离,你 可以让PG0 和Pdata 慢慢挪近到一起,可以让它们的距离变小一点,然后再变小一点,最后 就可以让它们对齐在一起。所以这就是为什么我们要用Wasserstein距离的原因,因为它可以 让我们的生成器一步一步地变好,而不是一下子就变好。

        所以WGAN实际上就是用Wasserstein距离来取代JS距离,这个GAN就叫做WGAN。 那接下来的问题是,Wasserstein 距离是要如何计算呢?我们可以看到,Wasserstein 距离的定 义是一个最优化的问题,如图5所示。这里我们简化过程直接介绍结果,也就是解图中最 大化问题的解,解出来以后所得到的值就是Wasserstein距离,即PG0和Pdata 的Wasserstein 距离。我们观察一下图5 的公式,即我们要找一个函数D,这个函数D 是一个函数,我 们可以想像成是一个神经网络,这个神经网络的输入是x,输出是D(x)。X 如果是从Pdata 采样来的,我们要计算它的期望值Ex∼Pdata ,如果X 是从PG 采样来的,那我们要计算它的 期望值Ex∼PG ,然后再乘上一个负号,所以如果要最大化这个目标函数就会达成。如果X 如 果是从Pdata 采样得到的,那么判别器的输出要越大越好,如果X 如果是从PG采样得到的, 从生成器采样出来的输出应该要越小越好。 

图 5 Wasserstein 距离的计算 

        此外还有另外一个限制。函数D必须要是一个1-Lipschitz的函数。我们可以想像成,如 果有一个函数的斜率是有上限的(足够平滑,变化不剧烈),那这个函数就是1-Lipschitz 的 函数。如果没有这个限制,只看大括号里面的值只单纯要左边的值越大越好,右边的值越小 越好,那么在蓝色的点和绿色的点,也就是真正的图像和生成的图像没有重叠的时候,我们 可以让左边的值无限大,右边的值无限小,这样的话,这个目标函数就可以无限大。这时整个 训练过程就根本就没有办法收敛。所以我们要加上这个限制,让这个函数是一个1-Lipschitz 的函数,这样的话,左边的值无法无限大,右边的值无法无限小,所以这个目标函数就可以收 敛。所以当判别器够平滑的时候,假设真实数据和生成数据的分布距离比较近,那就没有办 法让真实数据的期望值非常大,同时生成的值非常小。因为如果让真实数据的期望值非常大, 同时生成的值非常小,那它们中间的差距很大,判别器的更新变化就很剧烈,它就不平滑了, 也就不是1-Lipschitz 了。

        那接下来的问题就是如何确保判别器一定符合 1-Lipschitz 函数的限制呢?其实最早刚 提出WGAN 的时候也没有什么好想法。最早的一篇WGAN的文章做了一个比较粗糙的处 理,就是训练网络时,把判别器的参数限制在一个范围内,如果超过这个范围,就把梯度下 降更新后的权重设为这个范围的边界值。但其实这个方法并不一定真的能够让判别器变成1 Lipschitz 函数。虽然它可以让判别器变得平滑,但是它并没有真的去解这个优化问题,它并 没有真的让判别器符合这个限制。

        后来就有了一些其它的方法,例如说有一篇文章叫做ImprovedWGAN,它就是使用了梯度惩罚(gradient penalty)的方法,这个方法可以让判别器变成 1-Lipschitz 函数。具体来 说,如图6所示,假设蓝色区域是真实数据的分布,橘色是生成数据的分布,在真实数据这 边采样一个数据,在生成数据这边取一个样本,然后在这两个点之间取一个中间的点,然后计 算这个点的梯度,使之接近于1。就是在判别器的目标函数里面,加上一个惩罚项,这个惩罚 项就是判别器的梯度的范数减去1的平方,这个惩罚项的系数是一个超参数,这个超参数可 以让你的判别器变得越平滑。在Improved WGAN 之后,还有 Improved Improved WGAN, 就是把这个限制再稍微改一改。另外还有方法是将判别器的参数限制在一个范围内,让它是 1-Lipschitz 函数,这个叫做谱归一化。总之,这些方法都可以让判别器变成1-Lipschitz 函数, 但是这些方法都有一个问题,就是它们都是在判别器的目标函数里面加了一个惩罚项,这个 惩罚项的系数是一个超参数,这个超参数会让你的判别器变得越平滑。 

图6  ImprovedWGAN 的梯度惩罚

这篇关于WGAN 算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1138062

相关文章

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

【数据结构】——原来排序算法搞懂这些就行,轻松拿捏

前言:快速排序的实现最重要的是找基准值,下面让我们来了解如何实现找基准值 基准值的注释:在快排的过程中,每一次我们要取一个元素作为枢纽值,以这个数字来将序列划分为两部分。 在此我们采用三数取中法,也就是取左端、中间、右端三个数,然后进行排序,将中间数作为枢纽值。 快速排序实现主框架: //快速排序 void QuickSort(int* arr, int left, int rig

poj 3974 and hdu 3068 最长回文串的O(n)解法(Manacher算法)

求一段字符串中的最长回文串。 因为数据量比较大,用原来的O(n^2)会爆。 小白上的O(n^2)解法代码:TLE啦~ #include<stdio.h>#include<string.h>const int Maxn = 1000000;char s[Maxn];int main(){char e[] = {"END"};while(scanf("%s", s) != EO

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

dp算法练习题【8】

不同二叉搜索树 96. 不同的二叉搜索树 给你一个整数 n ,求恰由 n 个节点组成且节点值从 1 到 n 互不相同的 二叉搜索树 有多少种?返回满足题意的二叉搜索树的种数。 示例 1: 输入:n = 3输出:5 示例 2: 输入:n = 1输出:1 class Solution {public int numTrees(int n) {int[] dp = new int

Codeforces Round #240 (Div. 2) E分治算法探究1

Codeforces Round #240 (Div. 2) E  http://codeforces.com/contest/415/problem/E 2^n个数,每次操作将其分成2^q份,对于每一份内部的数进行翻转(逆序),每次操作完后输出操作后新序列的逆序对数。 图一:  划分子问题。 图二: 分而治之,=>  合并 。 图三: 回溯:

最大公因数:欧几里得算法

简述         求两个数字 m和n 的最大公因数,假设r是m%n的余数,只要n不等于0,就一直执行 m=n,n=r 举例 以18和12为例 m n r18 % 12 = 612 % 6 = 06 0所以最大公因数为:6 代码实现 #include<iostream>using namespace std;/