深入理解GAN网络

2024-09-06 03:04
文章标签 深入 理解 网络 gan

本文主要是介绍深入理解GAN网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Generative Adversarial Networks创造性地提出了对抗训练来代替人工指定的loss。之前的文章初步理解了一下,感觉还是不到位,在这里再稍微深入一点。

交叉熵cross entropy

鉴别器是GAN中重要的一部分,鉴别器和生成器交替训练的过程就是adversarial training的过程。而鉴别器的本质是一个二分类网络,所以要理解gan loss,就首先要熟悉二分类中经典的交叉熵。

交叉熵中的“交叉”指的是对数部分使用的概率分布和加权求和部分使用的概率分布不同。具体而言,对数部分使用的是预测的结果,而加权求和部分使用的是真实标签的分布。所以二分类的交叉熵为:

注意上面公式中的.表示乘积,并且等式右边应该整体取负数。

鉴别器discriminator 

当鉴别器输入为真实数据, 

预测结果标签交叉熵的负数
D(X)1

当鉴别器输入为生成器的输出,

预测结果标签交叉熵的负数
D(G(Z))0

这两个交叉熵取负的曲线如下所示:

因为上述曲线没有考虑负号,所以应该是取最大值时损失函数最小,结合横坐标取值范围是0~1,所以鉴别器的优化目标是上图中箭头所指的位置。对应公式为:

生成器Generator 

生成器和鉴别器做对抗,所以优化目标正好是上式最小:

但其实生成器只参与上式中的第二项。可以看到,为了使第二项最小,D(G(Z))要逼近1,这意味着生成器的目标是让鉴别器认为生成器的结果是真。

The Generator needs to fool the discriminator by generating images as real as possible.

需要说明的是,虽然生成器G的目标是log(1 − D(G(z)))和D(G(z))=1是一致的,但一开始生成器很弱,鉴别器可以轻松区分开来,所以在训练前期可以使用D(G(z))来获得更大的梯度。

进一步理解

交替训练

把生成器和鉴别器的优化目标合并在一起,可以写为:

这个公式优美地表示了生成器和鉴别器的优化方向是共线的,只不过方向正好相反。既然方向相反,那就不能同步训练,而是要交替进行:

The adversarial loss can be optimized by gradient descent. But while training a GAN we do not train the generator and discriminator simultaneously, while training the generator we freeze the discriminator and vice-versa!

不平衡

整体上看D和G是交替训练,但是比例却不同。D更新时有k step,G却只有1 step。这样只要G更新足够慢,D就可以一直保持在最优解附近。

这其实和受限玻尔兹曼机(Restricted Boltzmann Machine)有关。相关中文资料,英文资料。

受限玻尔兹曼机由多伦多大学的 Geoff Hinton 等人提出,是一种可通过输入数据集学习概率分布的随机生成神经网络。随机指的是如果神经元被激活时会产生随机行为;受限指的是同一层的不同节点之间是不进行通信的,而不同层之间会有信息双向流动,并且权值相同。

但是到这里仍然没有为什么D需要训练k step,有更懂的朋友可以解释一下。我的理解是D对于整个GAN网络的效果很关键,是需要维持在一个很高的水平。但一开始G很弱,D可以很轻松区分开来,没必要训练那么多次。所以一种可取的策略是先训练生成器k1次,然后再训练判别器k2次,然后再交替训练。

损失函数

我们知道有很多映射关系是无法准确写出表达式的,而CNN的本质是基于 Universal Approximation Theorem,可以实现对任意函数的拟合。为了去衡量CNN的好坏,还是需要Explicit loss function来指导CNN中的梯度更新。

Explicit loss function的问题是它需要人工指定,并且是和任务相关的,比如:

L1 loss (Absolute error)

Used for regression task

L2 loss (Squared error)

Similar to L1 but more sensitive to outliers.

Cross-entropy loss

Usually used in classification tasks.

Dice loss (IoU)

Used in segmentation tasks.

KL Divergence: 

For measuring the difference between two distributions.

当任务更加复杂时,比如对黑白图像上色,传统的loss 函数会遇到各种问题。如MSE倾向于求平均的特性,预测结果是几种可能性的加权结果,带来的问题就是要么图像变模糊,要么出现不切实际的结果(两个可能性是真实存在的,强行加权求和的结果是不存在的)。

既然loss 函数也是函数,那么和之前的思想一样,也可以再使用一个CNN来拟合。这其实就是GAN的思想。判别网络潜在学习到的损失函数隐藏在网络之中,不同的问题这个函数就不一样,所以说可以自动学习这个潜在的损失函数 learned loss function,而不用人为地根据任务去调整。

学习分布

GAN新加了一个CNN去拟合损失函数,而这个CNN是一个分类网络。分类网络的损失函数是交叉熵,衡量两个分布的差异,所以可以认为整个GAN其实就是在学习一个概率分布。

D的本质是比较两个分布的差异,所以G的输出本质也是一种分布,而G的输入可以是均匀分布的噪声,所以G的学习目标其实是学习分布的迁移。可以看到,G的输入是均匀分布的z,G的输出x的分布绿色曲线一点点在向真实正态分布靠近。

所以说GAN是在学习给定GT的分布。

https://www.slideshare.net/slideshow/improved-trainings-of-wasserstein-gans-wgangp/104146363#4

reference:

1. https://zhuanlan.zhihu.com/p/149186719

2.Understanding GANs — Deriving the Adversarial loss from scratch

3.https://openai.com/index/generative-models/

4.https://medium.com/vitalify-asia/gans-as-a-loss-function-72d994dde4fb

这篇关于深入理解GAN网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1140821

相关文章

一文带你理解Python中import机制与importlib的妙用

《一文带你理解Python中import机制与importlib的妙用》在Python编程的世界里,import语句是开发者最常用的工具之一,它就像一把钥匙,打开了通往各种功能和库的大门,下面就跟随小... 目录一、python import机制概述1.1 import语句的基本用法1.2 模块缓存机制1.

深入理解C语言的void*

《深入理解C语言的void*》本文主要介绍了C语言的void*,包括它的任意性、编译器对void*的类型检查以及需要显式类型转换的规则,具有一定的参考价值,感兴趣的可以了解一下... 目录一、void* 的类型任意性二、编译器对 void* 的类型检查三、需要显式类型转换占用的字节四、总结一、void* 的

深入理解Redis大key的危害及解决方案

《深入理解Redis大key的危害及解决方案》本文主要介绍了深入理解Redis大key的危害及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、背景二、什么是大key三、大key评价标准四、大key 产生的原因与场景五、大key影响与危

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

深入理解C++ 空类大小

《深入理解C++空类大小》本文主要介绍了C++空类大小,规定空类大小为1字节,主要是为了保证对象的唯一性和可区分性,满足数组元素地址连续的要求,下面就来了解一下... 目录1. 保证对象的唯一性和可区分性2. 满足数组元素地址连续的要求3. 与C++的对象模型和内存管理机制相适配查看类对象内存在C++中,规

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

认识、理解、分类——acm之搜索

普通搜索方法有两种:1、广度优先搜索;2、深度优先搜索; 更多搜索方法: 3、双向广度优先搜索; 4、启发式搜索(包括A*算法等); 搜索通常会用到的知识点:状态压缩(位压缩,利用hash思想压缩)。

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor