生成型对抗性网络入门实战一波流

2024-04-30 21:58

本文主要是介绍生成型对抗性网络入门实战一波流,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前几节用代码介绍了生成型对抗性网络的实现,但后来我觉得代码的实现七拐八弯,很多不必要的烦琐会增加读者的理解负担,于是花时间把代码进行强力精简,希望由此能帮助有需要的读者更顺利的入门生成型对抗性网络。

顾名思义,该网络有一种“对抗”性质。它实际上由两个子网络组成,一个网络叫生成者,一个网络叫鉴别者,后者类似于老师的作用。根据我们自己的学习经验得知,老师的作用除了告诉你“怎么做”之外,最重要的是告诉你“错在哪”,人本身有强大的模仿能力但却没有足够的纠错能力,如果在学习时有老师及时指出或纠正你的错误,那么你的学习效果将大大增加。鉴别者网络其实就是生成者的老师,他有两个个功能,一个功能是学习特定目标的内在特征,另一个功能是校正生成者的错误,让生成者不断提升对学习目标的认知能力。

举个具体实例,学生跟老师学画画,那么学生就是生成者,老师就是鉴别者。跟普通的师徒不同在于,老师一开始也不懂如何画画,他先自学一段时间,等到掌握了一定技巧后,他让学生自己先画,然后他根据自己当前的能力指出学生那里画错,学生改正后自己的能力也得到提升。接着老师继续升级自己的绘画技能,只有自己水平提高了才能更好的指导学生,于是老师自己不断进步,然后被他调教的学生也在不断进步,当老师成为大师后,如果学生画出来的话老师也挑不出错误,那么学生也成为了大师。

我们看看网络的结构图:

gan.png

我们看看如何在数学上执行“把错误信息传递给生成者”,网络本质上是一个函数,他接收输入数据然后给出输出,真实图像其实对应二维数组,鉴别者网络接收该数组后输出一个值,0表示图像来自生成者,1表示图像来自真实图像。一开始我们将真实图像输入鉴别者网络,调整期内部参数,让输出结果尽可能趋近与1,然后将生成者生成的图片输入鉴别者网络,调整其内部参数让它输出结果尽可能接近0,这样生成的图像和真实图像相应的信息就会被“寄存”在鉴别者网络的内部参数。

鉴别者如何“调教”生成者呢,这里需要借鉴间套函数求导的思路。对于函数D(G(z))中的变量z求导时结果为D’(G(z))*G’(z),如果我们把G对应生成者,D对应鉴别者,那么D’(G(z))就等价于鉴别者网络告诉生成者“错在哪”,G’(z)对应生成者自己知道错在哪,于是两种信息结合在一起就能让生成者调整内部参数,使得它的输出越来越能通过鉴别者的识别,由于鉴别者经过训练后能准确识别真实图像,如果生成者的生成图像能通过识别,那意味着生成者的生成结果越来越接近真实图像,接下来我们看看代码实现,首先我们使用谷歌提供的一笔画图像数据来进行训练,其获取路径在本课堂附件或是如下链接:

链接:https://pan.baidu.com/s/11Urnrd8QoALLnxaDlu0YPA 密码:1qqk

首先使用代码加载图片资源:

import numpy as np
import os
from os import walk
def  load_data(path):txt_name_list = []for (dirpath, dirnames, filenames) in walk(path) :#遍历给定目录下所有文件和子目录for f in filenames:if f != '.DS_Store':txt_name_list.append(f)breakslice_train = int(80000/len(txt_name_list))i = 0seed = np.random.randint(1, 10e6)for txt_name in txt_name_list:txt_path = os.path.join(path, txt_name) #获得文件完全路径x = np.load(txt_path)#加载npy文件x = (x.astype('float32') - 127.5) / 127.5 #将数值转换为[0,1]之间x = x.reshape(x.shape[0], 28, 28, 1) #将数值转换为图片规格y = [i] * len(x)np.random.seed(seed)np.random.shuffle(x)np.random.seed(seed)np.random.shuffle(y)x = x[: slice_train]y = y[: slice_train]if i != 0:xtotal = np.concatenate((x, xtotal), axis = 0)ytotal = np.concatenate((y, ytotal), axis = 0)else:xtotal = xytotal = yi += 1return xtotal, ytotalpath = '/content/drive/My Drive/camel/dataset'
(x_train, y_train) = load_data(path)
print(x_train.shape)
import matplotlib.pyplot as plt
print(np.shape(x_train[200, :,:,:]))
plt.imshow(x_train[200, :,:,0], cmap = 'gray')

上面代码执行后生成图像如下:

camel.png

我们的任务就是训练生成者网络,让它学会绘制上面分割的图像。下面我们看看两个网络的实现代码:

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import timefrom IPython import displayBUFFER_SIZE = 80000
BATCH_SIZE = 256
EPOCHS = 100
# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)class Model(tf.keras.Model):def  __init__(self):super(Model, self).__init__()self.model_name = "Model"self.model_layers = []def  call(self, x):x = tf.convert_to_tensor(x, dtype = tf.float32)for layer in self.model_layers:x = layer(x)return x   class  Generator(Model):def  __init__(self):super(Generator, self).__init__()self.model_name = "generator"self.generator_layers = []self.generator_layers.append(tf.keras.layers.Dense(7*7*256, use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Reshape((7, 7, 256)))self.generator_layers.append(tf.keras.layers.Conv2DTranspose(128, (5, 5), padding = 'same', use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(64, (5,5), strides = (2,2),padding = 'same',use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(1, (5,5), strides = (2,2),padding = 'same',use_bias = False, activation = 'tanh'))self.model_layers = self.generator_layers  def  create_variables(self, z_dim):x =  np.random.normal(0, 1, (1, z_dim))x = self.call(x)class Discriminator(Model):def __init__(self):super(Discriminator, self).__init__()self.model_name = "discriminator"self.discriminator_layers = []self.discriminator_layers.append(tf.keras.layers.Conv2D(64, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Dropout(0.3))self.discriminator_layers.append(tf.keras.layers.Conv2D(128, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Dropout(0.3))self.discriminator_layers.append(tf.keras.layers.Flatten())self.discriminator_layers.append(tf.keras.layers.Dense(1)) self.model_layers = self.discriminator_layersdef  create_variables(self): #必须要调用一次call网络才会实例化x = np.expand_dims(x_train[200, :,:,:], axis = 0)self.call(x)

代码中的网络层需要简单描述一下,Conv2D实际上是将维度高,数量大的数据转换为维度第,数量小的数据,例如给定一个含有100个元素的向量,如果将其乘以维度为(80, 100)的矩阵,那么所得结果就是含有80个元素的向量,于是向量的维度或分量个数减少了,因此它的作用是将输入的二维数据不断缩小,抽取其内在规律的“精华”,而Conv2DTranspose相反,它增大输入数据的维度或分量个数,例如一维向量含有80个分量,那么乘以维度为(100,80)的数组后得到含有100个分量的向量,该函数做的就是这个工作,只不过用于相乘的矩阵里面的分量要经过训练得到。

接下来我们看训练流程:

class GAN():def  __init__(self, z_dim):self.epoch = 0self.z_dim = z_dim  #关键向量的维度#设置生成者和鉴别者网络的优化函数self.discriminator_optimizer = tf.train.AdamOptimizer(1e-4)self.generator_optimizer = tf.train.AdamOptimizer(1e-4)self.generator = Generator()self.generator.create_variables(z_dim)self.discriminator = Discriminator()self.discriminator.create_variables()self.seed = tf.random.normal([16, z_dim])def train_discriminator(self, image_batch):'''训练鉴别师网络,它的训练分两步骤,首先是输入正确图片,让网络有识别正确图片的能力。然后使用生成者网络构造图片,并告知鉴别师网络图片为假,让网络具有识别生成者网络伪造图片的能力'''with tf.GradientTape(watch_accessed_variables=False) as tape: #只修改鉴别者网络的内部参数tape.watch(self.discriminator.trainable_variables)noise = tf.random.normal([len(image_batch), self.z_dim])start = time.time()true_logits = self.discriminator(image_batch, training = True)gen_imgs = self.generator(noise, training = True) #让生成者网络根据关键向量生成图片fake_logits = self.discriminator(gen_imgs, training = True)d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(true_logits), logits = true_logits)d_loss_fake =  tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(fake_logits), logits = fake_logits)d_loss = d_loss_real + d_loss_fakegrads = tape.gradient(d_loss , self.discriminator.trainable_variables)self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables)) #改进鉴别者网络内部参数 def  train_generator(self, batch_size): #训练生成者网络'''生成者网络训练的目的是让它生成的图像尽可能通过鉴别者网络的审查'''with tf.GradientTape(watch_accessed_variables=False) as tape: #只能修改生成者网络的内部参数不能修改鉴别者网络的内部参数tape.watch(self.generator.trainable_variables)noise = tf.random.normal([batch_size, self.z_dim])gen_imgs = self.generator(noise, training = True) #生成伪造的图片d_logits = self.discriminator(gen_imgs,training = True)verify_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(d_logits),logits = d_logits)grads = tape.gradient(verify_loss, self.generator.trainable_variables) #调整生成者网络内部参数使得它生成的图片尽可能通过鉴别者网络的识别self.generator_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))@tf.functiondef  train_step(self, image_batch):self.train_discriminator(image_batch)self.train_generator(len(image_batch))def  train(self, epochs, run_folder):#启动训练流程for  epoch in range(EPOCHS):start = time.time()self.epoch = epochfor image_batch in train_dataset:self.train_step(image_batch)display.clear_output(wait=True)self.sample_images(run_folder) #将生成者构造的图像绘制出来self.save_model(run_folder) #存储两个网络的内部参数print("time for epoc:{} is {} seconds".format(epoch, time.time() - start))def  sample_images(self, run_folder): #绘制生成者构建的图像predictions = self.generator(self.seed)predictions = predictions.numpy()fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('/content/drive/My Drive/camel/images/sample{:04d}.png'.format(self.epoch))plt.show()def  save_model(self, run_folder): #保持网络内部参数self.discriminator.save_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.save_weights(os.path.join(run_folder, 'generator.h5'))def  load_model(self, run_folder):self.discriminator.load_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.load_weights(os.path.join(run_folder, 'generator.h5'))gan = GAN(z_dim = 100)
gan.train(epochs = EPOCHS, run_folder = '/content/drive/My Drive/camel')          

注意到train_discriminator函数中,训练鉴别者网络时它需要接受两种数据,一种来自真实图像,一种来自生成者网络的图像,它要训练的识别真实图像时返回值越来越接近于1,识别生成者图像时输出结果越来越接近0.在train_generator函数中,代码先让生成者生成图像,然后把生成的图像输入鉴别者,这就类似于前面提到的间套函数,然后调整生成者内部参数,使得它生成的数据输入鉴别者后,后者输出的结果要尽可能的接近1,如此一来生成者产生的图像才可能越来越接近真实图像。这里还需要非常注意的是在调用网络时,一定要将training参数设置为True,这是因为我们在构造网络时使用了两个特殊网络层,分别是BatchNormalization,和Dropout,这两个网络层对网络的训练稳定性至关重要,如果不设置training参数为True,框架就不会执行这两个网络对应的运算,这样就会导致训练识别,笔者在开始时没有注意这个问题,因此在调试上浪费了很多时间。

上面代码运行半个小时后输出结果如下:

屏幕快照 2020-03-16 下午6.04.30.png

从生成图片结果看,生成者构造的图片与前面加载显示的真实图片其实没有太大区别。

更详细的讲解和代码调试演示过程,请点击链接](https://study.163.com/provider/7600199/course.htm?share=2&shareId=7600199)

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于生成型对抗性网络入门实战一波流的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950023

相关文章

Golang使用minio替代文件系统的实战教程

《Golang使用minio替代文件系统的实战教程》本文讨论项目开发中直接文件系统的限制或不足,接着介绍Minio对象存储的优势,同时给出Golang的实际示例代码,包括初始化客户端、读取minio对... 目录文件系统 vs Minio文件系统不足:对象存储:miniogolang连接Minio配置Min

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

详解Java中如何使用JFreeChart生成甘特图

《详解Java中如何使用JFreeChart生成甘特图》甘特图是一种流行的项目管理工具,用于显示项目的进度和任务分配,在Java开发中,JFreeChart是一个强大的开源图表库,能够生成各种类型的图... 目录引言一、JFreeChart简介二、准备工作三、创建甘特图1. 定义数据集2. 创建甘特图3.

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

Linux 网络编程 --- 应用层

一、自定义协议和序列化反序列化 代码: 序列化反序列化实现网络版本计算器 二、HTTP协议 1、谈两个简单的预备知识 https://www.baidu.com/ --- 域名 --- 域名解析 --- IP地址 http的端口号为80端口,https的端口号为443 url为统一资源定位符。CSDNhttps://mp.csdn.net/mp_blog/creation/editor

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount