生成型对抗性网络入门实战一波流

2024-04-30 21:58

本文主要是介绍生成型对抗性网络入门实战一波流,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前几节用代码介绍了生成型对抗性网络的实现,但后来我觉得代码的实现七拐八弯,很多不必要的烦琐会增加读者的理解负担,于是花时间把代码进行强力精简,希望由此能帮助有需要的读者更顺利的入门生成型对抗性网络。

顾名思义,该网络有一种“对抗”性质。它实际上由两个子网络组成,一个网络叫生成者,一个网络叫鉴别者,后者类似于老师的作用。根据我们自己的学习经验得知,老师的作用除了告诉你“怎么做”之外,最重要的是告诉你“错在哪”,人本身有强大的模仿能力但却没有足够的纠错能力,如果在学习时有老师及时指出或纠正你的错误,那么你的学习效果将大大增加。鉴别者网络其实就是生成者的老师,他有两个个功能,一个功能是学习特定目标的内在特征,另一个功能是校正生成者的错误,让生成者不断提升对学习目标的认知能力。

举个具体实例,学生跟老师学画画,那么学生就是生成者,老师就是鉴别者。跟普通的师徒不同在于,老师一开始也不懂如何画画,他先自学一段时间,等到掌握了一定技巧后,他让学生自己先画,然后他根据自己当前的能力指出学生那里画错,学生改正后自己的能力也得到提升。接着老师继续升级自己的绘画技能,只有自己水平提高了才能更好的指导学生,于是老师自己不断进步,然后被他调教的学生也在不断进步,当老师成为大师后,如果学生画出来的话老师也挑不出错误,那么学生也成为了大师。

我们看看网络的结构图:

gan.png

我们看看如何在数学上执行“把错误信息传递给生成者”,网络本质上是一个函数,他接收输入数据然后给出输出,真实图像其实对应二维数组,鉴别者网络接收该数组后输出一个值,0表示图像来自生成者,1表示图像来自真实图像。一开始我们将真实图像输入鉴别者网络,调整期内部参数,让输出结果尽可能趋近与1,然后将生成者生成的图片输入鉴别者网络,调整其内部参数让它输出结果尽可能接近0,这样生成的图像和真实图像相应的信息就会被“寄存”在鉴别者网络的内部参数。

鉴别者如何“调教”生成者呢,这里需要借鉴间套函数求导的思路。对于函数D(G(z))中的变量z求导时结果为D’(G(z))*G’(z),如果我们把G对应生成者,D对应鉴别者,那么D’(G(z))就等价于鉴别者网络告诉生成者“错在哪”,G’(z)对应生成者自己知道错在哪,于是两种信息结合在一起就能让生成者调整内部参数,使得它的输出越来越能通过鉴别者的识别,由于鉴别者经过训练后能准确识别真实图像,如果生成者的生成图像能通过识别,那意味着生成者的生成结果越来越接近真实图像,接下来我们看看代码实现,首先我们使用谷歌提供的一笔画图像数据来进行训练,其获取路径在本课堂附件或是如下链接:

链接:https://pan.baidu.com/s/11Urnrd8QoALLnxaDlu0YPA 密码:1qqk

首先使用代码加载图片资源:

import numpy as np
import os
from os import walk
def  load_data(path):txt_name_list = []for (dirpath, dirnames, filenames) in walk(path) :#遍历给定目录下所有文件和子目录for f in filenames:if f != '.DS_Store':txt_name_list.append(f)breakslice_train = int(80000/len(txt_name_list))i = 0seed = np.random.randint(1, 10e6)for txt_name in txt_name_list:txt_path = os.path.join(path, txt_name) #获得文件完全路径x = np.load(txt_path)#加载npy文件x = (x.astype('float32') - 127.5) / 127.5 #将数值转换为[0,1]之间x = x.reshape(x.shape[0], 28, 28, 1) #将数值转换为图片规格y = [i] * len(x)np.random.seed(seed)np.random.shuffle(x)np.random.seed(seed)np.random.shuffle(y)x = x[: slice_train]y = y[: slice_train]if i != 0:xtotal = np.concatenate((x, xtotal), axis = 0)ytotal = np.concatenate((y, ytotal), axis = 0)else:xtotal = xytotal = yi += 1return xtotal, ytotalpath = '/content/drive/My Drive/camel/dataset'
(x_train, y_train) = load_data(path)
print(x_train.shape)
import matplotlib.pyplot as plt
print(np.shape(x_train[200, :,:,:]))
plt.imshow(x_train[200, :,:,0], cmap = 'gray')

上面代码执行后生成图像如下:

camel.png

我们的任务就是训练生成者网络,让它学会绘制上面分割的图像。下面我们看看两个网络的实现代码:

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import timefrom IPython import displayBUFFER_SIZE = 80000
BATCH_SIZE = 256
EPOCHS = 100
# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)class Model(tf.keras.Model):def  __init__(self):super(Model, self).__init__()self.model_name = "Model"self.model_layers = []def  call(self, x):x = tf.convert_to_tensor(x, dtype = tf.float32)for layer in self.model_layers:x = layer(x)return x   class  Generator(Model):def  __init__(self):super(Generator, self).__init__()self.model_name = "generator"self.generator_layers = []self.generator_layers.append(tf.keras.layers.Dense(7*7*256, use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Reshape((7, 7, 256)))self.generator_layers.append(tf.keras.layers.Conv2DTranspose(128, (5, 5), padding = 'same', use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(64, (5,5), strides = (2,2),padding = 'same',use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(1, (5,5), strides = (2,2),padding = 'same',use_bias = False, activation = 'tanh'))self.model_layers = self.generator_layers  def  create_variables(self, z_dim):x =  np.random.normal(0, 1, (1, z_dim))x = self.call(x)class Discriminator(Model):def __init__(self):super(Discriminator, self).__init__()self.model_name = "discriminator"self.discriminator_layers = []self.discriminator_layers.append(tf.keras.layers.Conv2D(64, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Dropout(0.3))self.discriminator_layers.append(tf.keras.layers.Conv2D(128, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Dropout(0.3))self.discriminator_layers.append(tf.keras.layers.Flatten())self.discriminator_layers.append(tf.keras.layers.Dense(1)) self.model_layers = self.discriminator_layersdef  create_variables(self): #必须要调用一次call网络才会实例化x = np.expand_dims(x_train[200, :,:,:], axis = 0)self.call(x)

代码中的网络层需要简单描述一下,Conv2D实际上是将维度高,数量大的数据转换为维度第,数量小的数据,例如给定一个含有100个元素的向量,如果将其乘以维度为(80, 100)的矩阵,那么所得结果就是含有80个元素的向量,于是向量的维度或分量个数减少了,因此它的作用是将输入的二维数据不断缩小,抽取其内在规律的“精华”,而Conv2DTranspose相反,它增大输入数据的维度或分量个数,例如一维向量含有80个分量,那么乘以维度为(100,80)的数组后得到含有100个分量的向量,该函数做的就是这个工作,只不过用于相乘的矩阵里面的分量要经过训练得到。

接下来我们看训练流程:

class GAN():def  __init__(self, z_dim):self.epoch = 0self.z_dim = z_dim  #关键向量的维度#设置生成者和鉴别者网络的优化函数self.discriminator_optimizer = tf.train.AdamOptimizer(1e-4)self.generator_optimizer = tf.train.AdamOptimizer(1e-4)self.generator = Generator()self.generator.create_variables(z_dim)self.discriminator = Discriminator()self.discriminator.create_variables()self.seed = tf.random.normal([16, z_dim])def train_discriminator(self, image_batch):'''训练鉴别师网络,它的训练分两步骤,首先是输入正确图片,让网络有识别正确图片的能力。然后使用生成者网络构造图片,并告知鉴别师网络图片为假,让网络具有识别生成者网络伪造图片的能力'''with tf.GradientTape(watch_accessed_variables=False) as tape: #只修改鉴别者网络的内部参数tape.watch(self.discriminator.trainable_variables)noise = tf.random.normal([len(image_batch), self.z_dim])start = time.time()true_logits = self.discriminator(image_batch, training = True)gen_imgs = self.generator(noise, training = True) #让生成者网络根据关键向量生成图片fake_logits = self.discriminator(gen_imgs, training = True)d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(true_logits), logits = true_logits)d_loss_fake =  tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(fake_logits), logits = fake_logits)d_loss = d_loss_real + d_loss_fakegrads = tape.gradient(d_loss , self.discriminator.trainable_variables)self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables)) #改进鉴别者网络内部参数 def  train_generator(self, batch_size): #训练生成者网络'''生成者网络训练的目的是让它生成的图像尽可能通过鉴别者网络的审查'''with tf.GradientTape(watch_accessed_variables=False) as tape: #只能修改生成者网络的内部参数不能修改鉴别者网络的内部参数tape.watch(self.generator.trainable_variables)noise = tf.random.normal([batch_size, self.z_dim])gen_imgs = self.generator(noise, training = True) #生成伪造的图片d_logits = self.discriminator(gen_imgs,training = True)verify_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(d_logits),logits = d_logits)grads = tape.gradient(verify_loss, self.generator.trainable_variables) #调整生成者网络内部参数使得它生成的图片尽可能通过鉴别者网络的识别self.generator_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))@tf.functiondef  train_step(self, image_batch):self.train_discriminator(image_batch)self.train_generator(len(image_batch))def  train(self, epochs, run_folder):#启动训练流程for  epoch in range(EPOCHS):start = time.time()self.epoch = epochfor image_batch in train_dataset:self.train_step(image_batch)display.clear_output(wait=True)self.sample_images(run_folder) #将生成者构造的图像绘制出来self.save_model(run_folder) #存储两个网络的内部参数print("time for epoc:{} is {} seconds".format(epoch, time.time() - start))def  sample_images(self, run_folder): #绘制生成者构建的图像predictions = self.generator(self.seed)predictions = predictions.numpy()fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('/content/drive/My Drive/camel/images/sample{:04d}.png'.format(self.epoch))plt.show()def  save_model(self, run_folder): #保持网络内部参数self.discriminator.save_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.save_weights(os.path.join(run_folder, 'generator.h5'))def  load_model(self, run_folder):self.discriminator.load_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.load_weights(os.path.join(run_folder, 'generator.h5'))gan = GAN(z_dim = 100)
gan.train(epochs = EPOCHS, run_folder = '/content/drive/My Drive/camel')          

注意到train_discriminator函数中,训练鉴别者网络时它需要接受两种数据,一种来自真实图像,一种来自生成者网络的图像,它要训练的识别真实图像时返回值越来越接近于1,识别生成者图像时输出结果越来越接近0.在train_generator函数中,代码先让生成者生成图像,然后把生成的图像输入鉴别者,这就类似于前面提到的间套函数,然后调整生成者内部参数,使得它生成的数据输入鉴别者后,后者输出的结果要尽可能的接近1,如此一来生成者产生的图像才可能越来越接近真实图像。这里还需要非常注意的是在调用网络时,一定要将training参数设置为True,这是因为我们在构造网络时使用了两个特殊网络层,分别是BatchNormalization,和Dropout,这两个网络层对网络的训练稳定性至关重要,如果不设置training参数为True,框架就不会执行这两个网络对应的运算,这样就会导致训练识别,笔者在开始时没有注意这个问题,因此在调试上浪费了很多时间。

上面代码运行半个小时后输出结果如下:

屏幕快照 2020-03-16 下午6.04.30.png

从生成图片结果看,生成者构造的图片与前面加载显示的真实图片其实没有太大区别。

更详细的讲解和代码调试演示过程,请点击链接](https://study.163.com/provider/7600199/course.htm?share=2&shareId=7600199)

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于生成型对抗性网络入门实战一波流的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950023

相关文章

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

PowerShell中15个提升运维效率关键命令实战指南

《PowerShell中15个提升运维效率关键命令实战指南》作为网络安全专业人员的必备技能,PowerShell在系统管理、日志分析、威胁检测和自动化响应方面展现出强大能力,下面我们就来看看15个提升... 目录一、PowerShell在网络安全中的战略价值二、网络安全关键场景命令实战1. 系统安全基线核查

从入门到精通MySQL联合查询

《从入门到精通MySQL联合查询》:本文主要介绍从入门到精通MySQL联合查询,本文通过实例代码给大家介绍的非常详细,需要的朋友可以参考下... 目录摘要1. 多表联合查询时mysql内部原理2. 内连接3. 外连接4. 自连接5. 子查询6. 合并查询7. 插入查询结果摘要前面我们学习了数据库设计时要满

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

从原理到实战深入理解Java 断言assert

《从原理到实战深入理解Java断言assert》本文深入解析Java断言机制,涵盖语法、工作原理、启用方式及与异常的区别,推荐用于开发阶段的条件检查与状态验证,并强调生产环境应使用参数验证工具类替代... 目录深入理解 Java 断言(assert):从原理到实战引言:为什么需要断言?一、断言基础1.1 语

从入门到精通C++11 <chrono> 库特性

《从入门到精通C++11<chrono>库特性》chrono库是C++11中一个非常强大和实用的库,它为时间处理提供了丰富的功能和类型安全的接口,通过本文的介绍,我们了解了chrono库的基本概念... 目录一、引言1.1 为什么需要<chrono>库1.2<chrono>库的基本概念二、时间段(Durat

Java MQTT实战应用

《JavaMQTT实战应用》本文详解MQTT协议,涵盖其发布/订阅机制、低功耗高效特性、三种服务质量等级(QoS0/1/2),以及客户端、代理、主题的核心概念,最后提供Linux部署教程、Sprin... 目录一、MQTT协议二、MQTT优点三、三种服务质量等级四、客户端、代理、主题1. 客户端(Clien

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

解析C++11 static_assert及与Boost库的关联从入门到精通

《解析C++11static_assert及与Boost库的关联从入门到精通》static_assert是C++中强大的编译时验证工具,它能够在编译阶段拦截不符合预期的类型或值,增强代码的健壮性,通... 目录一、背景知识:传统断言方法的局限性1.1 assert宏1.2 #error指令1.3 第三方解决