左右互搏:生成型对抗性网络的强大威力

2024-04-30 22:08

本文主要是介绍左右互搏:生成型对抗性网络的强大威力,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

生成型对抗性网络,简称GEN,在2014年时被发明。它与上一节介绍的VAE也就是编解码网络一样,擅长于图像构造,然而它的功能比VAE要强大不少,我们现在时常听到AI合成网络主播,类似功能的实现绝大多数都基于我们这次要探讨的对抗性网络。

生成型对抗性网络一个非常显著的特点是左右互搏。它由两个子网络构成,一个子网络叫generator,它负责构造图片或相应数据,另一个网络叫discriminator,负责判断前者构造数据的质量。如果前者构造的图像不够好,那么后者就传达一个负反馈给前者,于是前者根据反馈调整自身参数,让下一次生成的图片质量得以提升,它就是靠这种体内自循环的方式不断提升自己构造图片的能力。

举个例子,假设有个画家想伪造毕加索的名画,他一开始并不知道如何模仿毕加索的笔法,于是他按照自己的直觉对着毕加索一幅画进行临摹,然后把绘制结果交给一个与他串通好的绘画交易商,后者对毕加索的画颇有研究,看了临摹后给画家反馈说颜色用的太浅了。画家拿到反馈后再次临摹,这次他加深了颜色的深度,于是第二次临摹的质量比第一次好了一些。交易商看了后再次给他反馈说线条太粗了,于是画家根据反馈再次改进,这种循环不断进行,每一次循环画家模仿的记忆就变得更好,直到足够次数的改进后,画家模仿出的画与毕加索的真迹再也无法区分出来。

在这里画家就是generator,而交易商就是discriminator。在网络运行商,generator接收一个随机向量,然后输出对应一副图画的二维数组。discriminator接收二维数组,然后判断这二维数组是来自训练数据还是来自generator,如果generator生成的二维数组使得discriminator无法区分是来自训练数据还是generator生成的,整个流程结束,此时generator产生的图像与来自训练数据的图像已经相像得无法分辨了,对抗性生成型网络的运行流程如下:

屏幕快照 2019-02-22 上午11.39.43.png

discriminator网络会输入大量训练数据进行训练,让它掌握训练数据图像特征。generator网络接收一个随机向量,然后生成一张图片给discriminator判断,如果后者判断输入图片是伪造的,它会给generator一个负反馈,然后generator根据反馈修正自身参数从而改进生成的图片质量,这个流程反复进行直到generator生成的图片被discriminator接受为止,此时generator生成的图片质量与训练discriminator所用的图片质量几乎一模一样。

我们看一个GEN用于生成图片的实例:

屏幕快照 2019-02-22 上午11.44.36.png

上图中左边是真实人物图像,右边是GAN网络生成的图像,你是否感觉到网络的构造能力非常惊人。GAN网络与其他网络不通之处在于,它训练过程非常困难,因为它是两个子网络互相联动,因此网络训练时,如果调整不好,整个网络状态会一直剧烈波动无法达到平衡态。

我们接下来将尝试开发一个形态最简单的GAN网络叫DCGAN,其中子网络generator由多个卷积层组成,而discrimator由多个反卷积层组成。我们选取数据集CIFAR10对网络进行训练,它包含50000张格式为32*32的RGB图片,我们从中间抽取出所有青蛙图片训练网络,让网络学会如何无中生有的构造出以假乱真的青蛙图片。

首先我们看看generator网络的实现:

import  keras
from  keras  import  layers
import numpy as np#输入generator网络的随机向量长度
latent_dim = 32
#generator输出格式为[32, 32 , 3]的数组,它对应一张图片
height = 32
width = 32
channels = 3generator_input = keras.Input(shape = (latent_dim, ))x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 128))(x)
x = layers.Conv2D(256, 5, padding= 'same')(x)
#我们使用激活函数LeakyReLu而不是以前的Relu,前者有利于网络训练时趋于稳定
x = layers.LeakyReLU()(x)#卷积网络层
x = layers.Conv2DTranspose(256, 4, strides = 2, padding = 'same')(x)
x = layers.LeakyReLU()(x)x = layers.Conv2D(256, 5, padding = 'same')(x)
x = layers.LeakyReLU()(x)x = layers.Conv2D(256, 5, padding = 'same')(x)
x = layers.LeakyReLU()(x)#使用激活函数tanh而不是sigmoid,因为它有利于网络在训练时趋于稳定
x = layers.Conv2D(channels, 7, activation = 'tanh', padding = 'same')(x)generator = keras.models.Model(generator_input, x)
generator.summary()

上面代码运行后结果如下:

屏幕快照 2019-02-22 下午5.02.40.png

接下来我们看看discriminator网络的实现:

#generator的输出就是discriminator的输入
discriminator_input = layers.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides = 2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides = 2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides = 2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)
#增加Dropout有利于网络训练时趋于稳定
x = layers.Dropout(0.4)(x)
x = layers.Dense(1, activation = 'sigmoid')(x)discriminator = keras.models.Model(discriminator_input, x)
discriminator.summary()discriminator_optimizer = keras.optimizers.RMSprop(lr = 0.0008, #因为网络训练时状态变化很剧烈,因此我们限定参数变化范围减少状态摇摆clipvalue = 1.0,#学习率也要不断变化以适应网络状态的改变decay = 1e-8)
#discriminator判断generator构造的图片是否为真
discriminator.compile(optimizer = discriminator_optimizer, loss = 'binary_crossentropy')

上面构造了discriminator网络,然后我们需要把两个网络连接成一个整体。网络训练的目的就是不断改进generator,让它生成的图片能骗过discriminator。两者连接成整体的代码如下:

'''
我们把generator和discriminator连成一个整体,在对整体进行训练时,
只更改generator网络的参数,discriminator的参数保持不变
'''
discriminator.trainable = False
gan_input = keras.Input(shape = (latent_dim, ))#将两个网络衔接在一起
gan_output = discriminator(generator(gan_input))gan = keras.models.Model(gan_input, gan_output)gan_optimizer = keras.optimizers.RMSprop(lr = 0.0004, clipvalue = 1.0,decay = 1e-8)
gan.compile(optimizer = gan_optimizer, loss = 'binary_crossentropy')

接着我们准备启动训练流程。训练流程分几步走,首先随机生成一个含有32个元素的一维向量,使用该向量输入generator网络,让它生成[32, 32 3]的二维数组;将生成的二维数组与来自训练图片对应的二维数组混合在一起;把混合的数据用于训练discriminator网络,其中来自训练数据的图片数组对应标签为True,来自generator产生的二维数组对应的标签为False;再次产生一个含有32个元素的一维向量,让generator产生对应的二维数组;让discriminator网络判断该二维数组是否为来自训练数据的图片,generator根据反馈修正参数改进二维数组的生成质量,这个过程一直持续到discriminator返回True为止。

我们看看相应代码:

import os
from keras.preprocessing import image(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()
#选出所有青蛙图片
x_train = x_train[y_train.flatten() == 6]x_train = x_train.reshape((x_train.shape[0], ) + (height, width, channels)).astype('float32') / 255.
iterations = 10000
batch_size = 20
save_dir = '/content/gdrive/My Drive/gen_imgs'start = 0
for step in range(iterations):random_latent_vectors = np.random.normal(size = (batch_size, latent_dim))#让generator产生对应图片的二维数组generated_images = generator.predict(random_latent_vectors)stop = start + batch_sizereal_images = x_train[start : stop]combined_images = np.concatenate([generated_images, real_images])labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])#这是一个让网络训练趋于稳定的小技巧,就是将给标签添加随机化噪音labels += 0.05 * np.random.random(labels.shape)#先训练discriminator识别真假图片d_loss = discriminator.train_on_batch(combined_images, labels)random_latent_vectors = np.random.normal(size = (batch_size, latent_dim))misleading_targets = np.zeros((batch_size, 1))#根据discriminator的反馈让generator改进自身参数a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)start += batch_sizeif start > len(x_train) - batch_size:start = 0if step % 100 == 0:gan.save_weights('gan.h5')print('discriminator loss: ', d_loss)print('adversarial loss: ', a_loss)img = image.array_to_img(generated_images[0] * 255. , scale = False)img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))img = image.array_to_img(real_images[0] * 255., scale = False)img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))

在没有GPU加持的情况下,上面代码的训练会较为缓慢,当网络训练成果后,我们看看网络构造的图片和来自训练数据集的图片有何区别:

generated_frog9900.png

real_frog9800.png

由于我们生成的图片很小不好观察,但把两只图片放在一起对比一下,上面图片是网络生成的青蛙图片,下边图片是来自训练数据集的图片,我们不难体会到,网络生成的图片跟来自训练数据集的图片几乎看不出区别来。

最近看到一则新闻说,搜狗与央视合作,使用人工智能合成新闻主播名叫小萌,她的原型来自于央视的一名主持人,这名AI合成主持人已经做到人眼看不出她是虚拟的,不论是举手投足还是细微表情的展现上,都与真人无异,我想搜狗所用的技术,应该就是我们今天谈到的生成型对抗性网络。

生成型对抗性网络是我们接触的所有类型网络中最为复杂的一种。它在训练过程中,只要参数稍微不对,整个网络就不能收敛,GAN网络的训练和开发几乎没有什么原理来指导,出现异常情况时,要靠开发者自身的经验和直觉去处理或调整,这里只能作为抛砖引玉之用,有兴趣的读者可以自行加大探索的力度。

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

更多内容,请点击进入csdn学院

这篇关于左右互搏:生成型对抗性网络的强大威力的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/950049

相关文章

MySQL中动态生成SQL语句去掉所有字段的空格的操作方法

《MySQL中动态生成SQL语句去掉所有字段的空格的操作方法》在数据库管理过程中,我们常常会遇到需要对表中字段进行清洗和整理的情况,本文将详细介绍如何在MySQL中动态生成SQL语句来去掉所有字段的空... 目录在mysql中动态生成SQL语句去掉所有字段的空格准备工作原理分析动态生成SQL语句在MySQL

Java利用docx4j+Freemarker生成word文档

《Java利用docx4j+Freemarker生成word文档》这篇文章主要为大家详细介绍了Java如何利用docx4j+Freemarker生成word文档,文中的示例代码讲解详细,感兴趣的小伙伴... 目录技术方案maven依赖创建模板文件实现代码技术方案Java 1.8 + docx4j + Fr

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

在java中如何将inputStream对象转换为File对象(不生成本地文件)

《在java中如何将inputStream对象转换为File对象(不生成本地文件)》:本文主要介绍在java中如何将inputStream对象转换为File对象(不生成本地文件),具有很好的参考价... 目录需求说明问题解决总结需求说明在后端中通过POI生成Excel文件流,将输出流(outputStre

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用