使用’推土距离‘构建强悍的WGAN

2024-04-30 21:58

本文主要是介绍使用’推土距离‘构建强悍的WGAN,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

读者读到此处时或许会有一个感触,网络训练的目的是让网络在接收输入数据后,它输出的结果在给定衡量标准上变得越来越好,由此“衡量标准”设计的好坏对网络训练最终结果产生至关重要的作用。

回想上一节,当我们把N张数据图片输入到网络后,网络会输出一个含有N个分量的向量,接着我们先构造一个含有N个1的向量,然后判断网络得出的向量与构造的含有N个1的向量是否足够“接近”。

算法判断两个向量是否接近的标准是“交叉熵”,也就是image.png
其中image.png
对应构造的含有N个1的向量中对应的分量,也就是无论i取什么值都有:
image.pngimage.png
则是网络接收第i张图片后输出其为真实图片对应的概率。当输入图片比较复杂时,使用交叉熵来衡量输出结果的好坏在数学上有严重缺陷,简单的说交叉熵不能够精确的衡量网络是否已经有效的识别出图片特征,这里我们介绍另一种衡量方法叫“推土距离”。推土距离的定义如下,假设地面两处位置上有两个形状不同的土堆,如下图所示:
image.png
,P和Q分布表示两处土堆,每个长条方块可以看做是一个小沙丘,你的任务是使用推土机将P中某个沙丘上的土搬到另一个沙丘,使得最后土堆P的形状和Q的形状一模一样。显然沙土的搬运方法有很多种,一种搬运法如图下图所示:
image.png
上图,箭头表示把沙土从箭头起始的沙丘搬运到箭头所指向的沙丘,当然还可以有另外的搬运法,如下图所示:
image.png
如图17-7所示,将土堆从箭头起始的沙丘搬运到箭头指向的沙丘,所得结果也能使土堆P向土堆Q转换,但如果我们考虑到搬运的成本,如果将搬运土堆的重量乘以土堆移动的距离作为一次搬运成本,那么不难看第一章图所示的搬运法比1第二张图所示的搬运法更节省。

所谓搬图距离就是所有可行的搬土方法中能实现成本最小的那种搬运方法,使用W(P,Q)来标记。不难看出P和Q其实可以对应两种不同的概率分布,因此推土距离本质上就是将给定概率分布P转换成概率分布Q,并且要求转换所产生的成本要尽可能小。我们可以通过下图对“推土距离”进行更形象的理解:
image.png
上图中,在P和Q之间对应一个二维矩阵,每一行对应将土堆P对应沙丘中的沙土晕倒Q中对应列所示沙丘的距离,方块的颜色越深表示表示运送沙土的数量越多,使用符号image.png
来表示上图所示矩阵,注意到它的每一行所有元素加总对应P中所在沙丘的含土量,每一列对应Q中相应沙丘的含土量,因此使用image.png
表示将土堆中Xp对应沙丘运送到Xq对应沙丘的土量,使用image.png
表示两个沙丘的距离,那么一个搬运方案就可以使用公式image.png
来表示。而推土距离就是所有可行方案中拥有最小成本那种,使用image.png
来表示,其中符号image.png
表示所有可行搬运方案的集合,推土距离是数学最优化领域中非常复杂的难题。接下来我们看看WGAN网络的数学原理,我们就可以使用搬图距离来衡量网络输出结果的好坏,算法将使用下面公式来描述Discriminator网络的损失函数:
image.png
要说明该公式能表示G,D之间的推土距离需要相当复杂的推导,在此我们暂时忽略。公式看起来似乎很复杂,读者不必要被它吓到,它要做的事情很简单。在17.1.1节中,如果图形来自于数据集,那么算法就构造全是1的向量,如果图像来自生成者网络,那么算法就 构造全是0的向量。

根据上面公式我们对算法做一些小修改,如果图像来自生成者网络,那么构造分量全是-1的分量,这意味着算法将训练Discriminator网络,使得它接收N张来自数据集的图片,输出的N个结果的平均值要尽可能大。

在公式中还有一个约束条件需要注意,那就是:
image.png
满足该条件的函数必须具备如下性质:
image.png
也就是说如果把Discriminator网络看做一个函数,那么网络输出数据的特性必须满足上面公式。但是在实践上我们无法直接构造一个网络使得它的特性满足上面公式,因此算法使用一种便宜之计就是将Discriminator网络内部参数的值限定在区间(-1,1)。“偏移之计”的做法其实并不能让鉴别者网络满足约束条件,只不过它能让算法取得较好的结果,在后面会给出更好的处理方法。接下来我们看看WGAN网络的实现,首先我们要加载训练所需的图片数据:

import numpy as np
import os
from keras.datasets import cifar10
def  load_cifar10(label):#加载keras代码库自带的cifar数据集,里面是各种物体的图片(x_train, y_train), (x_test, y_test) = cifar10.load_data()train_mask = [y[0] == label for y in y_train] #将给定标签的图片挑选出来test_mask = [y[0] == label for y in y_test]x_data = np.concatenate([x_train[train_mask], x_test[test_mask]] )y_data = np.concatenate([y_train[train_mask], y_test[test_mask]])x_data = (x_data.astype('float32') - 127.5) / 127.5 return (x_data, y_data)
CIFAR_HORSE_LABEL = 7 #图片类别由标签值对应,7对应所有马的图片   
(x_train, y_train) = load_cifar10(CIFAR_HORSE_LABEL)#加载所有马图片
import matplotlib.pyplot as plt
plt.imshow((x_train[150, :,:, :] + 1) / 2)

代码将keras库附带的数据集cifar加载到内存,该数据集对应了多种物品的的图片,每种特定物品使用标签值就行区分,代码中使用的标签值7对应所有马的图片,后面实现的WGAN将专门使用马的图片来训练,因此训练结束后网络会学会如何绘制马的图片,上面代码运行后所得结果如下图所示:
image.png
接下来构造生成者和鉴别者网络并将其拼接成一个整体:

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
BUFFER_SIZE = 6000
BATCH_SIZE = 256
EPOCHS = 12000
# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
class Model(tf.keras.Model):def  __init__(self):super(Model, self).__init__()self.model_name = "Model"self.model_layers = []def  call(self, x):x = tf.convert_to_tensor(x, dtype = tf.float32)for layer in self.model_layers:x = layer(x)return x   
class  Generator(Model):def  __init__(self):super(Generator, self).__init__()self.model_name = "generator"self.generator_layers = []self.generator_layers.append(tf.keras.layers.Dense(4*4*128, use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization(momentum = 0.8))                        self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Reshape((4, 4, 128)))self.generator_layers.append(tf.keras.layers.UpSampling2D())self.generator_layers.append(tf.keras.layers.Conv2D(128, (5, 5),strides = (1,1), padding = 'same', use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.UpSampling2D()) #upSampling2D将数据通过复制的方式扩大一倍self.generator_layers.append(tf.keras.layers.Conv2D(64, (5,5), strides = (1,1),padding = 'same',use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization(momentum = 0.8))self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.UpSampling2D())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(32, (5,5), strides = (1,1),padding = 'same',use_bias = False))self.generator_layers.append(tf.keras.layers.BatchNormalization())self.generator_layers.append(tf.keras.layers.LeakyReLU())self.generator_layers.append(tf.keras.layers.Conv2DTranspose(3, (5,5), strides = (1,1),padding = 'same',use_bias = False, activation = 'tanh'))self.model_layers = self.generator_layers #最终输出数据的规格为(32,32,3)def  create_variables(self, z_dim):x =  np.random.normal(0, 1, (1, z_dim))x = self.call(x)
class Discriminator(Model):def __init__(self):#鉴别者网络卷积层的规格为(32, 64,128, 128)super(Discriminator, self).__init__()self.model_name = "discriminator"self.discriminator_layers = []self.discriminator_layers.append(tf.keras.layers.Conv2D(32, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Conv2D(64, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Dropout(0.3))self.discriminator_layers.append(tf.keras.layers.Conv2D(128, (5,5), strides = (2,2),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Conv2D(128, (5,5), strides = (1,1),padding = 'same'))self.discriminator_layers.append(tf.keras.layers.LeakyReLU())self.discriminator_layers.append(tf.keras.layers.Flatten())self.discriminator_layers.append(tf.keras.layers.Dense(1, activation = "tanh")) self.model_layers = self.discriminator_layersdef  create_variables(self): #必须要调用一次call网络才会实例化x = np.expand_dims(x_train[200, :,:,:], axis = 0)self.call(x)

读者需要注意,在代码实现中,鉴别者和生成者网络跟上一节有一些明显差异,首先鉴别者网络的卷积层输出规格变为(32, 64, 128, 128),同时去掉了Dorpout网络层,生成者网络使用Upsampling2D来扩展数据规格。
此处需要展开说明Upsampling2D网络层的操作流程,它的作用与17.1.1节使用的Conv2DTranspose一样,都是将输入数据的规格扩大一倍,但做法不同,它仅仅是将输入二维数组的元素进行复制,具体操作如下:
image.png
接下来看看网络训练过程的实现,训练流程与17.1.1节大同小异,但是有几个要点需要注意:
···
class GAN():
def init(self, z_dim):
self.epoch = 0
self.z_dim = z_dim #关键向量的维度
#设置生成者和鉴别者网络的优化函数
self.discriminator_optimizer = tf.optimizers.Adam(0.0002)
self.generator_optimizer = tf.optimizers.Adam(0.0002)
self.generator = Generator()
self.generator.create_variables(z_dim)
self.discriminator = Discriminator()
self.discriminator.create_variables()
self.seed = tf.random.normal([16, z_dim])
self.d_loss = []
self.d_loss_real = []
self.d_loss_fake = []
self.g_loss = []
self.discriminator_trains = 5
self.image_batch_count = 0
def train_discriminator(self, image_batch):
‘’’
训练鉴别师网络,它的训练分两步骤,首先是输入正确图片,让网络有识别正确图片的能力。
然后使用生成者网络构造图片,并告知鉴别师网络图片为假,让网络具有识别生成者网络伪造图片的能力
‘’’

    with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape: #只修改鉴别者网络的内部参数tape.watch(self.discriminator.trainable_variables)noise = tf.random.normal([len(image_batch), self.z_dim])true_logits = self.discriminator(image_batch, training = True)gen_imgs = self.generator(noise, training = True) #让生成者网络根据关键向量生成图片fake_logits = self.discriminator(gen_imgs, training = True)d_loss_real = tf.multiply(tf.ones_like(true_logits), true_logits)#根据推土距离将真图片的标签设置为1 d_loss_fake = tf.multiply(-tf.ones_like(fake_logits), fake_logits)#将伪造图片的标签设置为-1with tf.GradientTape(watch_accessed_variables=False) as iterploted_tape:t = tf.random.uniform(shape = (len(image_batch), 1, 1, 1)) #生成[0,1]区间的随机数interploted_imgs = tf.add(tf.multiply(1 - t, image_batch), tf.multiply(t, gen_imgs))iterploted_tape.watch(interploted_imgs)interploted_loss = self.discriminator(interploted_imgs)interploted_imgs_grads = iterploted_tape.gradient(interploted_loss, interploted_imgs)grad_norms = tf.norm(interploted_imgs_grads)penalty = 10 * tf.reduce_mean((grad_norms - 1) ** 2)d_loss = d_loss_real + d_loss_fake + penaltygrads = tape.gradient(d_loss , self.discriminator.trainable_variables)self.discriminator_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables)) #改进鉴别者网络内部参数 self.d_loss.append(d_loss)self.d_loss_real.append(d_loss_real)self.d_loss_fake.append(d_loss_fake)def  train_generator(self, batch_size): #训练生成者网络'''生成者网络训练的目的是让它生成的图像尽可能通过鉴别者网络的审查'''with tf.GradientTape(persistent=True,watch_accessed_variables=False) as tape: #只能修改生成者网络的内部参数不能修改鉴别者网络的内部参数tape.watch(self.generator.trainable_variables)noise = tf.random.normal([batch_size, self.z_dim])gen_imgs = self.generator(noise, training = True) #生成伪造的图片d_logits = self.discriminator(gen_imgs,training = True)g_loss = tf.multiply(tf.ones_like(d_logits), d_logits)#将标签设置为1grads = tape.gradient(g_loss, self.generator.trainable_variables) #调整生成者网络内部参数使得它生成的图片尽可能通过鉴别者网络的识别self.generator_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))self.g_loss.append(g_loss)def  train_step(self):train_dataset.shuffle(BUFFER_SIZE)image_batchs = train_dataset.take(self.discriminator_trains)for image_batch in image_batchs:#注意先训练鉴别者网络5回才训练生成者网络一回self.train_discriminator(image_batch)self.train_generator(256)def  train(self, epochs, run_folder):#启动训练流程for  epoch in range(EPOCHS):start = time.time()self.epoch = epochself.train_step()if  self.epoch % 10 == 0:display.clear_output(wait=True)self.sample_images(run_folder) #将生成者构造的图像绘制出来self.save_model(run_folder) #存储两个网络的内部参数print("time for epoc:{} is {} seconds".format(epoch, time.time() - start))def  sample_images(self, run_folder): #绘制生成者构建的图像predictions = self.generator(self.seed)predictions = predictions.numpy()predictions = 0.5 * (predictions + 1)predictions = np.clip(predictions, 0, 1)fig = plt.figure(figsize=(4,4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(predictions[i, :, :, :] )plt.axis('off')plt.savefig('/content/drive/My Drive/WGAN_GP/images/sample{:04d}.png'.format(self.epoch))plt.show()def  save_model(self, run_folder): #保持网络内部参数self.discriminator.save_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.save_weights(os.path.join(run_folder, 'generator.h5'))def  load_model(self, run_folder):self.discriminator.load_weights(os.path.join(run_folder, 'discriminator.h5'))self.generator.load_weights(os.path.join(run_folder, 'generator.h5'))

gan = GAN(z_dim = 100)
gan.train(epochs = EPOCHS, run_folder = ‘/content/drive/My Drive/WGAN_GP/checkpoints’)
···
代码与上节有几个重要区别,第一个区别是在train_discriminator和train_generator函数中,代码将伪造图片对应的标签从0改为-1,这种改动意味着鉴别者网络在对输入图片的真实性进行评估。

因此在训练鉴别者网络时,将真实图片的标签设置为1,将伪造图片的标签设置为-1,意味着算法想训练鉴别者网络,让它给真实图片赋予更高评分,给伪造图片赋予更低评分,生成者网络的目的是使得生成的图片尽可能的获得鉴别者网络的高评分。

训练代码中还有一个要点在于每次训练完鉴别者网络后,需要将网络内部参数的值剪切到位于区间[-0.01,0.01]之间,这种做法目的是让鉴别者网络作为一个函数能满足损失函数公式,问题在于这种做法与将网络变成
image.png
类型的函数牛马不相及。

算法作者提出算法时并不知道如何使鉴别者网络变成给定类型函数,剪切网络内部参数其实是一种权宜之计,是算法作者“试”出来的一种有效做法,就像爱迪生通过海量“遍历”从而找到钨丝作为灯丝那样,上面代码运行后生成者网络生成的马图片质量如图17-10所示:
image.png
来自数据集中的真实图片如下所示:
image.png
从生成图片与数据集图片比较来看,生成图片能准确的把握住马的轮廓形态,皮毛特征,也就是生成者网络非常准确的把握住马的内在关键特征,因此它能学会如何绘制出形象的马图片,网络存在的问题在于,其生成的图片较为模糊,在下一节我们将研究如何进一步改进WGAN网络。

更详细的讲解和代码调试演示过程,请点击链接

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于使用’推土距离‘构建强悍的WGAN的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950024

相关文章

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Linux alias的三种使用场景方式

《Linuxalias的三种使用场景方式》文章介绍了Linux中`alias`命令的三种使用场景:临时别名、用户级别别名和系统级别别名,临时别名仅在当前终端有效,用户级别别名在当前用户下所有终端有效... 目录linux alias三种使用场景一次性适用于当前用户全局生效,所有用户都可调用删除总结Linux

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

Mysql虚拟列的使用场景

《Mysql虚拟列的使用场景》MySQL虚拟列是一种在查询时动态生成的特殊列,它不占用存储空间,可以提高查询效率和数据处理便利性,本文给大家介绍Mysql虚拟列的相关知识,感兴趣的朋友一起看看吧... 目录1. 介绍mysql虚拟列1.1 定义和作用1.2 虚拟列与普通列的区别2. MySQL虚拟列的类型2

使用MongoDB进行数据存储的操作流程

《使用MongoDB进行数据存储的操作流程》在现代应用开发中,数据存储是一个至关重要的部分,随着数据量的增大和复杂性的增加,传统的关系型数据库有时难以应对高并发和大数据量的处理需求,MongoDB作为... 目录什么是MongoDB?MongoDB的优势使用MongoDB进行数据存储1. 安装MongoDB

关于@MapperScan和@ComponentScan的使用问题

《关于@MapperScan和@ComponentScan的使用问题》文章介绍了在使用`@MapperScan`和`@ComponentScan`时可能会遇到的包扫描冲突问题,并提供了解决方法,同时,... 目录@MapperScan和@ComponentScan的使用问题报错如下原因解决办法课外拓展总结@

mysql数据库分区的使用

《mysql数据库分区的使用》MySQL分区技术通过将大表分割成多个较小片段,提高查询性能、管理效率和数据存储效率,本文就来介绍一下mysql数据库分区的使用,感兴趣的可以了解一下... 目录【一】分区的基本概念【1】物理存储与逻辑分割【2】查询性能提升【3】数据管理与维护【4】扩展性与并行处理【二】分区的

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超

Linux使用fdisk进行磁盘的相关操作

《Linux使用fdisk进行磁盘的相关操作》fdisk命令是Linux中用于管理磁盘分区的强大文本实用程序,这篇文章主要为大家详细介绍了如何使用fdisk进行磁盘的相关操作,需要的可以了解下... 目录简介基本语法示例用法列出所有分区查看指定磁盘的区分管理指定的磁盘进入交互式模式创建一个新的分区删除一个存