从GAN到WGAN(01/2)

2024-06-09 17:28
文章标签 01 gan wgan

本文主要是介绍从GAN到WGAN(01/2),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

从GAN到WGAN

文章目录

  • 一、说明
  • 二、Kullback-Leibler 和 Jensen-Shannon 背离
  • 三、生成对抗网络 (GAN)
  • 四、D 的最优值是多少?
  • 五、什么是全局最优?
  • 六、损失函数代表什么?
  • 七、GAN中的问题

一、说明

生成对抗网络 (GAN) 在许多生成任务中显示出巨大的效果,以复制现实世界的丰富内容,如图像、人类语言和音乐。它的灵感来自博弈论:两个模型,一个生成器和一个批评家,在相互竞争的同时使彼此更强大。然而,训练GAN模型是相当具有挑战性的,因为人们面临着训练不稳定或收敛失败等问题。

在这里,我想解释一下生成对抗网络框架背后的数学原理,为什么很难训练,最后介绍一个旨在解决训练难点的GAN修改版本。

二、Kullback-Leibler 和 Jensen-Shannon 背离

在我们开始仔细研究 GAN 之前,让我们首先回顾一下用于量化两个概率分布之间相似性的两个指标。

  1. KL(Kullback-Leibler)散度衡量一个概率分布如何偏离第二个预期概率分布
    .
    D K L ( p ∥ q ) = ∫ x p ( x ) log ⁡ p ( x ) q ( x ) d x D_{KL}(p \| q) = \int_x p(x) \log \frac{p(x)}{q(x)} dx DKL(pq)=xp(x)logq(x)p(x)dx

在以下情况下达到最小零点: p ( x ) = = q ( x ) p(x)==q(x) p(x)==q(x) 处处成立。
该公式表明 KL 散度是不对称的。当 p ( x ) p(x) p(x)接近零而 q ( x ) q(x) q(x)仍然显着大于零时, q q q的影响被忽略。当尝试测量两个同等重要的分布之间的相似性时,这可能会导致有问题的结果。

  1. Jensen-Shannon 散度是两个概率分布之间相似性的另一种度量,以.JS发散是对称的(耶!),而且更平滑。如果您有兴趣阅读有关 KL 背离和 JS 背离之间比较的更多信息,请查看这篇 Quora 帖子。
    D J S ( p ∥ q ) = 1 2 D K L ( p ∥ p + q 2 ) + 1 2 D K L ( q ∥ p + q 2 ) D_{JS}(p \| q) = \frac{1}{2} D_{KL}(p \| \frac{p + q}{2}) + \frac{1}{2} D_{KL}(q \| \frac{p + q}{2}) DJS(pq)=21DKL(p2p+q)+21DKL(q2p+q) 在这里插入图片描述
    图 1.给定两个高斯分布,平均值=0 且 std=1 且平均值=1,标准度=1。两个分布的平均值标记为.吉隆坡背离是不对称的,但 JS 发散是对称的。

一些人认为(Huszar,2015)GANs取得巨大成功背后的一个原因是将损失函数从传统最大似然方法中的非对称KL散度转换为对称JS散度。我们将在下一节中详细讨论这一点。

三、生成对抗网络 (GAN)

GAN由两个模型组成:

  • 鉴别器:估计给定样本来自真实数据集的概率。它作为评论家工作,并经过优化以区分假样品和真样品。
  • 生成器: 输出给定噪声变量输入的合成样本 (带来潜在的产出多样性)。它被训练为捕获真实的数据分布,以便其生成样本尽可能真实,或者换句话说,可以欺骗鉴别器提供高概率。
  • 在这里插入图片描述

图 2.生成对抗网络的架构。(图片来源:www.kdnuggets.com/2017/01/generative-…-learning.html)
这两个模型在训练过程中相互竞争:生成器在努力欺骗鉴别者,而批评者模型
正在努力不被骗。两种模型之间这种有趣的零和博弈激励了两者改进其功能。

鉴于

SymbolMeaningNotes
p z p_z pzData distribution over noise input zUsually, just uniform.
p g p_g pgThe generator’s distribution over data x
p r p_r prData distribution over real sample x

一方面,我们的目标是通过最大化 E x ∼ p r ( x ) [ log ⁡ D ( x ) ] \mathbb{E}_{x \sim p_{r}(x)} [\log D(x)] Expr(x)[logD(x)] 来确保判别器 D D D 对真实数据做出准确的决策。另一方面,对于假样本 G ( z ) , z ∼ p z ( z ) G(z), z \sim p_z(z) G(z),zpz(z),判别器应该输出一个接近于零的概率 D ( G ( z ) ) D(G(z)) D(G(z)),这是通过最大化 E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{z \sim p_{z}(z)} [\log (1 - D(G(z)))] Ezpz(z)[log(1D(G(z)))]
另一方面,生成器经过训练以增强 D D D 为虚假示例分配高概率的可能性,从而最小化 E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{z \sim p_{z}(z)} [\log (1 - D(G(z)))] Ezpz(z)[log(1D(G(z)))]
当整合这两个方面时,D 和 G 进行极小极大博弈,其目标是优化后续损失函数:
min ⁡ G max ⁡ D L ( D , G ) = E x ∼ p r ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] = E x ∼ p r ( x ) [ log ⁡ D ( x ) ] + E x ∼ p g ( x ) [ log ⁡ ( 1 − D ( x ) ] \begin{aligned} \min_G \max_D L(D, G) & = \mathbb{E}_{x \sim p_{r}(x)} [\log D(x)] + \mathbb{E}_{z \sim p_z(z)} [\log(1 - D(G(z)))] \\ & = \mathbb{E}_{x \sim p_{r}(x)} [\log D(x)] + \mathbb{E}_{x \sim p_g(x)} [\log(1 - D(x)] \end{aligned} GminDmaxL(D,G)=Expr(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]=Expr(x)[logD(x)]+Expg(x)[log(1D(x)]
E x ∼ p r ( x ) [ log ⁡ D ( x ) ] \mathbb{E}_{x \sim p_{r}(x)} [\log D(x)] Expr(x)[logD(x)]在梯度下降更新期间对 G 没有影响。

四、D 的最优值是多少?

现在我们有一个明确定义的损失函数。让我们首先检查一下什么是D最佳值
.
L ( G , D ) = ∫ x ( p r ( x ) log ⁡ ( D ( x ) ) + p g ( x ) log ⁡ ( 1 − D ( x ) ) ) d x L(G, D) = \int_x \bigg( p_{r}(x) \log(D(x)) + p_g (x) \log(1 - D(x)) \bigg) dx L(G,D)=x(pr(x)log(D(x))+pg(x)log(1D(x)))dx

由于我们感兴趣的是确定 D(x) 的最佳值以最大化 L(G,D),因此我们将贴上标签
x ~ = D ( x ) , A = p r ( x ) , B = p g ( x ) \tilde{x} = D(x), A=p_{r}(x), B=p_g(x) x~=D(x),A=pr(x),B=pg(x)
然后是积分内部的东西(我们可以安全地忽略积分,因为对所有可能的值进行采样)为:
f ( x ~ ) = A l o g x ~ + B l o g ( 1 − x ~ ) d f ( x ~ ) d x ~ = A 1 l n 10 1 x ~ − B 1 l n 10 1 1 − x ~ = 1 l n 10 ( A x ~ − B 1 − x ~ ) = 1 l n 10 A − ( A + B ) x ~ x ~ ( 1 − x ~ ) \begin{aligned} f(\tilde{x}) & = A log\tilde{x} + B log(1-\tilde{x}) \\ \frac{d f(\tilde{x})}{d \tilde{x}} & = A \frac{1}{ln10} \frac{1}{\tilde{x}} - B \frac{1}{ln10} \frac{1}{1 - \tilde{x}} \\ & = \frac{1}{ln10} (\frac{A}{\tilde{x}} - \frac{B}{1-\tilde{x}}) \\ & = \frac{1}{ln10} \frac{A - (A + B)\tilde{x}}{\tilde{x} (1 - \tilde{x})} \\ \end{aligned} f(x~)dx~df(x~)=Alogx~+Blog(1x~)=Aln101x~1Bln1011x~1=ln101(x~A1x~B)=ln101x~(1x~)A(A+B)x~

因此,设置 d f ( x ~ ) d x ~ = 0 \frac{d f(\tilde{x})}{d \tilde{x}} = 0 dx~df(x~)=0,我们得到鉴别器的最佳值:

D ∗ ( x ) = x ~ ∗ = A A + B = p r ( x ) p r ( x ) + p g ( x ) ∈ [ 0 , 1 ] D^*(x) = \tilde{x}^* = \frac{A}{A + B} = \frac{p_{r}(x)}{p_{r}(x) + p_g(x)} \in [0, 1] D(x)=x~=A+BA=pr(x)+pg(x)pr(x)[0,1]
.
Once the generator is trained to its optimal,
p g p_g pg gets very close to p r p_r pr. When p g = p r p_g = p_{r} pg=pr, D ∗ ( x ) D^*(x) D(x) becomes 1/2.
.

一旦生成器被训练到最佳状态, p g p_g pg 非常接近 p r p_r pr。当 p g = p r p_g = p_{r} pg=pr时, D ∗ ( x ) D^*(x) D(x)变为1/2。

五、什么是全局最优?

当两者都G和D处于最佳值,我们有 p g = p r p_g = p_{r} pg=pr D ∗ ( x ) = 1 / 2 D^*(x)=1/2 D(x)=1/2。损失函数变为:

L ( G , D ∗ ) = ∫ x ( p r ( x ) log ⁡ ( D ∗ ( x ) ) + p g ( x ) log ⁡ ( 1 − D ∗ ( x ) ) ) d x = log ⁡ 1 2 ∫ x p r ( x ) d x + log ⁡ 1 2 ∫ x p g ( x ) d x = − 2 log ⁡ 2 \begin{aligned} L(G, D^*) &= \int_x \bigg( p_{r}(x) \log(D^*(x)) + p_g (x) \log(1 - D^*(x)) \bigg) dx \\ &= \log \frac{1}{2} \int_x p_{r}(x) dx + \log \frac{1}{2} \int_x p_g(x) dx \\ &= -2\log2 \end{aligned} L(G,D)=x(pr(x)log(D(x))+pg(x)log(1D(x)))dx=log21xpr(x)dx+log21xpg(x)dx=2log2

六、损失函数代表什么?

根据上一节中列出的公式,JS 之间的背离 p g p_g pg p r p_r pr,可以计算为:
D J S ( p r ∥ p g ) = 1 2 D K L ( p r ∣ ∣ p r + p g 2 ) + 1 2 D K L ( p g ∣ ∣ p r + p g 2 ) = 1 2 ( log ⁡ 2 + ∫ x p r ( x ) log ⁡ p r ( x ) p r + p g ( x ) d x ) + 1 2 ( log ⁡ 2 + ∫ x p g ( x ) log ⁡ p g ( x ) p r + p g ( x ) d x ) = 1 2 ( log ⁡ 4 + L ( G , D ∗ ) ) \begin{aligned} D_{JS}(p_{r} \| p_g) =& \frac{1}{2} D_{KL}(p_{r} || \frac{p_{r} + p_g}{2}) + \frac{1}{2} D_{KL}(p_{g} || \frac{p_{r} + p_g}{2}) \\ =& \frac{1}{2} \bigg( \log2 + \int_x p_{r}(x) \log \frac{p_{r}(x)}{p_{r} + p_g(x)} dx \bigg) + \\& \frac{1}{2} \bigg( \log2 + \int_x p_g(x) \log \frac{p_g(x)}{p_{r} + p_g(x)} dx \bigg) \\ =& \frac{1}{2} \bigg( \log4 + L(G, D^*) \bigg) \end{aligned} DJS(prpg)===21DKL(pr∣∣2pr+pg)+21DKL(pg∣∣2pr+pg)21(log2+xpr(x)logpr+pg(x)pr(x)dx)+21(log2+xpg(x)logpr+pg(x)pg(x)dx)21(log4+L(G,D))

因此 L ( G , D ∗ ) = 2 D J S ( p r ∥ p g ) − 2 log ⁡ 2 L(G, D^*) = 2D_{JS}(p_{r} \| p_g) - 2\log2 L(G,D)=2DJS(prpg)2log2
当判别器最优时,生成对抗网络 (GAN) 的损失函数使用 Jensen-Shannon 散度来衡量生成的数据分布 p g p_g pg 与真实样本分布 p r p_r pr 之间的相似性。最优生成器 G ∗ G^* G 复制真实数据分布,导致最小损失 L ( G ∗ , D ∗ ) = − 2 log ⁡ 2 L(G^*, D^*) = -2\log2 L(G,D)=2log2,与前面的方程一致。

GAN 的其他变体:存在许多 GAN 变体,专为各种环境或特定任务而定制。例如,在半监督学习中,一种方法涉及修改鉴别器以生成实际的类标签 1 、 … 、 K − 1 1、\ldots、K-1 1K1,以及单个假类标签 K K K。生成器的目标是欺骗鉴别器分配小于 K K K 的分类标签。

七、GAN中的问题

(见系列下文)…

这篇关于从GAN到WGAN(01/2)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1045821

相关文章

hdu 2602 and poj 3624(01背包)

01背包的模板题。 hdu2602代码: #include<stdio.h>#include<string.h>const int MaxN = 1001;int max(int a, int b){return a > b ? a : b;}int w[MaxN];int v[MaxN];int dp[MaxN];int main(){int T;int N, V;s

集中式版本控制与分布式版本控制——Git 学习笔记01

什么是版本控制 如果你用 Microsoft Word 写过东西,那你八成会有这样的经历: 想删除一段文字,又怕将来这段文字有用,怎么办呢?有一个办法,先把当前文件“另存为”一个文件,然后继续改,改到某个程度,再“另存为”一个文件。就这样改着、存着……最后你的 Word 文档变成了这样: 过了几天,你想找回被删除的文字,但是已经记不清保存在哪个文件了,只能挨个去找。真麻烦,眼睛都花了。看

01 Docker概念和部署

目录 1.1 Docker 概述 1.1.1 Docker 的优势 1.1.2 镜像 1.1.3 容器 1.1.4 仓库 1.2 安装 Docker 1.2.1 配置和安装依赖环境 1.3镜像操作 1.3.1 搜索镜像 1.3.2 获取镜像 1.3.3 查看镜像 1.3.4 给镜像重命名 1.3.5 存储,载入镜像和删除镜像 1.4 Doecker容器操作 1.4

生成对抗网络(GAN网络)

Generative Adversarial Nets 生成对抗网络GAN交互式可视化网站 1、GAN 基本结构 GAN 模型其实是两个网络的组合: 生成器(Generator) 负责生成模拟数据; 判别器(Discriminator) 负责判断输入的数据是真实的还是生成的。 生成器要不断优化自己生成的数据让判别网络判断不出来,判别器也要优化自己让自己判断得更准确。 二者关系形成

深度学习--对抗生成网络(GAN, Generative Adversarial Network)

对抗生成网络(GAN, Generative Adversarial Network)是一种深度学习模型,由Ian Goodfellow等人在2014年提出。GAN主要用于生成数据,通过两个神经网络相互对抗,来生成以假乱真的新数据。以下是对GAN的详细阐述,包括其概念、作用、核心要点、实现过程、代码实现和适用场景。 1. 概念 GAN由两个神经网络组成:生成器(Generator)和判别器(D

滚雪球学MyBatis(01):教程导读

MyBatis简介 前言 欢迎回到我们的MyBatis系列教程。在上期的内容中,我们详细介绍了MyBatis的基本概念、特点以及它与其他ORM框架(如Hibernate)的对比。我们还探讨了MyBatis在数据访问层中的优势,并解释了为什么选择MyBatis作为我们的持久化框架。在阅读了上期的内容后,相信大家对MyBatis有了初步的了解。 在本期内容中,我们将深入探讨MyBatis的基本配

python+selenium2轻量级框架设计-01框架结构

接下来会介绍一个比较简单的框架结构,先看一下分类 config文件夹里放的是配置文件 framework文件夹里面放的是公共类,常用类,还有读配置文件类、日志类、截图类、发送邮件、生成测试报告、操作读取数据库、读取Excel等,后面几篇会一一介绍 logs文件夹存放生成的日志文件 pageobject存放页面类包括元素的定位等 screenshots文件放的是生成的截图 test_

python+selenium2学习笔记POM设计模式-01模式简介

Page Object模式是Selenium中的一种测试设计模式,主要是将每一个页面设计为一个Class,其中包含页面中需要测试的元素(按钮,输入框,标题 等),这样在Selenium测试页面中可以通过调用页面类来获取页面元素,这样巧妙的避免了当页面元素id或者位置变化时,需要改测试页面代码的情况。 当页面元素id变化时,只需要更改测试页Class中页面的属性即可。 Page Object模式是

数据库学习01——mysql怎么创建数据库和表

第一步:创建数据库 使用 create database 语句,后跟要创建的数据库名称: CREATE DATABASE dbname; 例如,要创建名为 my_db 的数据库,请输入: CREATE DATABASE my_db ; 使用 show databases; 语句检查数据库是否已创建: 第二步:创建表 使用 create table 语句,后跟要创建的表名和列定

【DL--01】深度学习 揭开DL的神秘面纱

什么是深度学习 深度学习=深度神经网络+机器学习 人工智能 > 机器学习 > 表示学习 > 深度学习 神经元模型 输入信号、加权求和、加偏置、激活函数、输出 全连接层 输入信号、输入层、隐层(多个神经元)、输出层(多个输出,每个对应一个分类)、目标函数(交叉熵) 待求的参数:连接矩阵W、偏置b 训练方法:随机梯度下降,BP算法(后向传播) Python中深度学习实现:Ke