生成型对抗性网络入门实战一波流

2024-04-30 21:58

本文主要是介绍生成型对抗性网络入门实战一波流,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前几节用代码介绍了生成型对抗性网络的实现,但后来我觉得代码的实现七拐八弯,很多不必要的烦琐会增加读者的理解负担,于是花时间把代码进行强力精简,希望由此能帮助有需要的读者更顺利的入门生成型对抗性网络。

顾名思义,该网络有一种“对抗”性质。它实际上由两个子网络组成,一个网络叫生成者,一个网络叫鉴别者,后者类似于老师的作用。根据我们自己的学习经验得知,老师的作用除了告诉你“怎么做”之外,最重要的是告诉你“错在哪”,人本身有强大的模仿能力但却没有足够的纠错能力,如果在学习时有老师及时指出或纠正你的错误,那么你的学习效果将大大增加。鉴别者网络其实就是生成者的老师,他有两个个功能,一个功能是学习特定目标的内在特征,另一个功能是校正生成者的错误,让生成者不断提升对学习目标的认知能力。

举个具体实例,学生跟老师学画画,那么学生就是生成者,老师就是鉴别者。跟普通的师徒不同在于,老师一开始也不懂如何画画,他先自学一段时间,等到掌握了一定技巧后,他让学生自己先画,然后他根据自己当前的能力指出学生那里画错,学生改正后自己的能力也得到提升。接着老师继续升级自己的绘画技能,只有自己水平提高了才能更好的指导学生,于是老师自己不断进步,然后被他调教的学生也在不断进步,当老师成为大师后,如果学生画出来的话老师也挑不出错误,那么学生也成为了大师。

我们看看网络的结构图:

gan.png

我们看看如何在数学上执行“把错误信息传递给生成者”,网络本质上是一个函数,他接收输入数据然后给出输出,真实图像其实对应二维数组,鉴别者网络接收该数组后输出一个值,0表示图像来自生成者,1表示图像来自真实图像。一开始我们将真实图像输入鉴别者网络,调整期内部参数,让输出结果尽可能趋近与1,然后将生成者生成的图片输入鉴别者网络,调整其内部参数让它输出结果尽可能接近0,这样生成的图像和真实图像相应的信息就会被“寄存”在鉴别者网络的内部参数。

鉴别者如何“调教”生成者呢,这里需要借鉴间套函数求导的思路。对于函数D(G(z))中的变量z求导时结果为D’(G(z))*G’(z),如果我们把G对应生成者,D对应鉴别者,那么D’(G(z))就等价于鉴别者网络告诉生成者“错在哪”,G’(z)对应生成者自己知道错在哪,于是两种信息结合在一起就能让生成者调整内部参数,使得它的输出越来越能通过鉴别者的识别,由于鉴别者经过训练后能准确识别真实图像,如果生成者的生成图像能通过识别,那意味着生成者的生成结果越来越接近真实图像,接下来我们看看代码实现,首先我们使用谷歌提供的一笔画图像数据来进行训练,其获取路径在本课堂附件或是如下链接:

链接:https://pan.baidu.com/s/11Urnrd8QoALLnxaDlu0YPA 密码:1qqk

首先使用代码加载图片资源:

import numpy as np
import os
from os import walk
def  load_data(path):txt_name_list = []for (dirpath, dirnames, filenames) in walk(path) :#遍历给定目录下所有文件和子目录for f in filenames:if f != '.DS_Store':txt_name_list.append(f)breakslice_train = int(80000/len(txt_name_list))i = 0seed = np.random.randint(1, 10e6)for txt_name in txt_name_list:txt_path = os.path.join(path, txt_name) #获得文件完全路径x = np.load(txt_path)#加载npy文件x = (x.astype('float32') - 127.5) / 127.5 #将数值转换为[0,1]之间x = x.reshape(x.shape[0], 28, 28, 1) #将数值转换为图片规格y = [i] * len(x)np.random.seed(seed)np.random.shuffle(x)np.random.seed(seed)np.random.shuffle(y)x = x[: slice_train]y = y[: slice_train]if i != 0:xtotal = np.concatenate((x, xtotal), axis = 0)ytotal = np.concatenate((y, ytotal), axis = 0)else:xtotal = xytotal = yi += 1return xtotal, ytotalpath = '/content/drive/My Drive/camel/dataset'
(x_train, y_train) = load_data(path)
print(x_train.shape)
import matplotlib.pyplot as plt
print(np.shape(x_train[200, :,:,:]))
plt.imshow(x_train[200, :,:,0], cmap = 'gray')

上面代码执行后生成图像如下:

camel.png

我们的任务就是训练生成者网络,让它学会绘制上面分割的图像。下面我们看看两个网络的实现代码:

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import timefrom IPython import displayBUFFER_SIZE = 80000
BATCH_SIZE = 256
EPOCHS = 100
# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)class Model(tf.keras.Model):def  __init__(self):super(Model, self).__init__()self.model_name = "Model"self.model_layers = []def  call(self, x):x = tf.convert_to_tensor(x, dtype = tf.float32)for layer in self.model_layers:x = layer(x)return x   class  Generator(Model):def  __init__(self):super(Generator, self).__init__()self.model_name = "generator"self.generator_layers = []self.generator_layers.append(tf.keras.layers.Dense(7*7*256, use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Reshape((7, 7, 256)))self.generator_layers.append(tf.keras.layers.Conv2DTranspose(128, (5, 5), padding = 'same', use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(64, (5,5), strides = (2,2),padding = 'same',use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(1, (5,5), strides = (2,2),padding = 'same',use_bias = False, activation = 'tanh'))self.model_layers = self.generator_layers  def  create_variables(self, z_dim):x =  np.random.normal(0, 1, (1, z_dim))x = self.call(x)class Discriminator(Model):def __init__(self):super(Discriminator, self).__init__()self.model_name = "discriminator"self.discriminator_layers = []self.discriminator_layers.append(tf.keras.layers.Conv2D(64, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Dropout(0.3))self.discriminator_layers.append(tf.keras.layers.Conv2D(128, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Dropout(0.3))self.discriminator_layers.append(tf.keras.layers.Flatten())self.discriminator_layers.append(tf.keras.layers.Dense(1)) self.model_layers = self.discriminator_layersdef  create_variables(self): #必须要调用一次call网络才会实例化x = np.expand_dims(x_train[200, :,:,:], axis = 0)self.call(x)

代码中的网络层需要简单描述一下,Conv2D实际上是将维度高,数量大的数据转换为维度第,数量小的数据,例如给定一个含有100个元素的向量,如果将其乘以维度为(80, 100)的矩阵,那么所得结果就是含有80个元素的向量,于是向量的维度或分量个数减少了,因此它的作用是将输入的二维数据不断缩小,抽取其内在规律的“精华”,而Conv2DTranspose相反,它增大输入数据的维度或分量个数,例如一维向量含有80个分量,那么乘以维度为(100,80)的数组后得到含有100个分量的向量,该函数做的就是这个工作,只不过用于相乘的矩阵里面的分量要经过训练得到。

接下来我们看训练流程:

class GAN():def  __init__(self, z_dim):self.epoch = 0self.z_dim = z_dim  #关键向量的维度#设置生成者和鉴别者网络的优化函数self.discriminator_optimizer = tf.train.AdamOptimizer(1e-4)self.generator_optimizer = tf.train.AdamOptimizer(1e-4)self.generator = Generator()self.generator.create_variables(z_dim)self.discriminator = Discriminator()self.discriminator.create_variables()self.seed = tf.random.normal([16, z_dim])def train_discriminator(self, image_batch):'''训练鉴别师网络,它的训练分两步骤,首先是输入正确图片,让网络有识别正确图片的能力。然后使用生成者网络构造图片,并告知鉴别师网络图片为假,让网络具有识别生成者网络伪造图片的能力'''with tf.GradientTape(watch_accessed_variables=False) as tape: #只修改鉴别者网络的内部参数tape.watch(self.discriminator.trainable_variables)noise = tf.random.normal([len(image_batch), self.z_dim])start = time.time()true_logits = self.discriminator(image_batch, training = True)gen_imgs = self.generator(noise, training = True) #让生成者网络根据关键向量生成图片fake_logits = self.discriminator(gen_imgs, training = True)d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(true_logits), logits = true_logits)d_loss_fake =  tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.zeros_like(fake_logits), logits = fake_logits)d_loss = d_loss_real + d_loss_fakegrads = tape.gradient(d_loss , self.discriminator.trainable_variables)self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables)) #改进鉴别者网络内部参数 def  train_generator(self, batch_size): #训练生成者网络'''生成者网络训练的目的是让它生成的图像尽可能通过鉴别者网络的审查'''with tf.GradientTape(watch_accessed_variables=False) as tape: #只能修改生成者网络的内部参数不能修改鉴别者网络的内部参数tape.watch(self.generator.trainable_variables)noise = tf.random.normal([batch_size, self.z_dim])gen_imgs = self.generator(noise, training = True) #生成伪造的图片d_logits = self.discriminator(gen_imgs,training = True)verify_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.ones_like(d_logits),logits = d_logits)grads = tape.gradient(verify_loss, self.generator.trainable_variables) #调整生成者网络内部参数使得它生成的图片尽可能通过鉴别者网络的识别self.generator_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))@tf.functiondef  train_step(self, image_batch):self.train_discriminator(image_batch)self.train_generator(len(image_batch))def  train(self, epochs, run_folder):#启动训练流程for  epoch in range(EPOCHS):start = time.time()self.epoch = epochfor image_batch in train_dataset:self.train_step(image_batch)display.clear_output(wait=True)self.sample_images(run_folder) #将生成者构造的图像绘制出来self.save_model(run_folder) #存储两个网络的内部参数print("time for epoc:{} is {} seconds".format(epoch, time.time() - start))def  sample_images(self, run_folder): #绘制生成者构建的图像predictions = self.generator(self.seed)predictions = predictions.numpy()fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')plt.axis('off')plt.savefig('/content/drive/My Drive/camel/images/sample{:04d}.png'.format(self.epoch))plt.show()def  save_model(self, run_folder): #保持网络内部参数self.discriminator.save_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.save_weights(os.path.join(run_folder, 'generator.h5'))def  load_model(self, run_folder):self.discriminator.load_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.load_weights(os.path.join(run_folder, 'generator.h5'))gan = GAN(z_dim = 100)
gan.train(epochs = EPOCHS, run_folder = '/content/drive/My Drive/camel')          

注意到train_discriminator函数中,训练鉴别者网络时它需要接受两种数据,一种来自真实图像,一种来自生成者网络的图像,它要训练的识别真实图像时返回值越来越接近于1,识别生成者图像时输出结果越来越接近0.在train_generator函数中,代码先让生成者生成图像,然后把生成的图像输入鉴别者,这就类似于前面提到的间套函数,然后调整生成者内部参数,使得它生成的数据输入鉴别者后,后者输出的结果要尽可能的接近1,如此一来生成者产生的图像才可能越来越接近真实图像。这里还需要非常注意的是在调用网络时,一定要将training参数设置为True,这是因为我们在构造网络时使用了两个特殊网络层,分别是BatchNormalization,和Dropout,这两个网络层对网络的训练稳定性至关重要,如果不设置training参数为True,框架就不会执行这两个网络对应的运算,这样就会导致训练识别,笔者在开始时没有注意这个问题,因此在调试上浪费了很多时间。

上面代码运行半个小时后输出结果如下:

屏幕快照 2020-03-16 下午6.04.30.png

从生成图片结果看,生成者构造的图片与前面加载显示的真实图片其实没有太大区别。

更详细的讲解和代码调试演示过程,请点击链接](https://study.163.com/provider/7600199/course.htm?share=2&shareId=7600199)

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于生成型对抗性网络入门实战一波流的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950023

相关文章

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Java利用docx4j+Freemarker生成word文档

《Java利用docx4j+Freemarker生成word文档》这篇文章主要为大家详细介绍了Java如何利用docx4j+Freemarker生成word文档,文中的示例代码讲解详细,感兴趣的小伙伴... 目录技术方案maven依赖创建模板文件实现代码技术方案Java 1.8 + docx4j + Fr

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis

MyBatis 动态 SQL 优化之标签的实战与技巧(常见用法)

《MyBatis动态SQL优化之标签的实战与技巧(常见用法)》本文通过详细的示例和实际应用场景,介绍了如何有效利用这些标签来优化MyBatis配置,提升开发效率,确保SQL的高效执行和安全性,感... 目录动态SQL详解一、动态SQL的核心概念1.1 什么是动态SQL?1.2 动态SQL的优点1.3 动态S

Pandas使用SQLite3实战

《Pandas使用SQLite3实战》本文主要介绍了Pandas使用SQLite3实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录1 环境准备2 从 SQLite3VlfrWQzgt 读取数据到 DataFrame基础用法:读

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为