使用变分编解码器实现自动图像生成

2024-04-30 22:08

本文主要是介绍使用变分编解码器实现自动图像生成,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度学习不仅仅在擅长于从现有数据中发现规律,而且它能主动运用规律创造出现实世界没有的实例来。例如给网络输入大量的人脸图片,让它识别人脸特征,然后我们可以指导网络创建出现实世界中不存在的人脸图像,把深度学习应用在创造性生成上是当前AI领域非常热门的应用。

从本节开始,我们将接触神经网络在图像生成方面的应用。有两种专门构建的网络在图像生成上能实现良好效果,一种网络叫变分编解码器,另一种叫生成型对抗性网络。这两种网络不仅仅能有与图片生成,还能用于音乐,声音,以及文本生成,但是在图像生成的效果上表现最好,因此接下来我们看看如何构建相应网络实现生成功能。

图像生成的关键思想是,使用网络构造一个向量空间,空间中每一个向量都可以映射成一张真实图片。在网络中有一个模块,读入该向量后,能够经过一系列运算把向量转换成一张图片所对应的二维向量,这个模块在编解码器网络里称为解码器。

编解码器网络的运行流畅如下:

屏幕快照 2019-02-15 下午5.44.51.png

首先我们把大量图片输入到网络中,网络识别图片并抽取图片中蕴含的规律,它把这些规律进行编码,以向量的形式存储,向量的长度越大,它就能存储越多的图片信息。接着网络的解码器模块解读编码向量,由于向量存储的是所有图片共同展现的人脸特征,而不是某个具体人的人脸特征,因此解码器解读编码向量后,就能根据向量蕴含的人脸特征进行绘图,最终构造出原来训练图片里没有的人脸图案,但这个人脸图案的特征与训练图片里面的人脸特征有相关性。

其实我们在前面章节已经接触过特征向量。在前面讲解单词向量时,所谓的单词向量就是一种特征向量。向量空间中,某个方向,也就是向量里面的某些分量可能记录了训练数据的某一方面的特征,对应人脸图片来说,向量可能有一部分分量用来记录笑容特征,某些分量可能记录了眼睛特征,某些分量可能记录了头发特征,所有这些特征综合起来就可能形成一张人脸。由于解码器能够识别向量中不同分量代表的信息,因此它把向量拆分解读之后,再按照向量分量表达的信息来绘制像素点,最终就可以完成一张人脸图片的绘制。

编解码器网络发明与2013和2014年,它能够把高维数据所展现的特征编码成低维向量,然后再把低维向量转换为原来数据所表示的高维向量。但这种还原并非原封不动的还原,而是把低维向量编码的信息展现出来。编解码网络有点像压缩和解压,把解码器模块把输入数据转变成另一种数据量较小的数据格式,而解码器再把该数据格式还原成输入数据,然而编解码器网络可不是简单的进行数据压缩和解压。

屏幕快照 2019-02-16 下午4.55.15.png

如上图,编解码器网络本质上是在学习输入图片像素点的统计信息,知道了像素点在统计上的分布规律后,它再按照相应的分布规律产生像素点,于是产生的图片与输入图片很像,但因为是根据统计规律随机产生的,因此生成的图片会产生某些变异。当我们把大量图片输入网络进行学习时,网络的编码器统计图片像素点变化的均值和方差,以及变化特征,这些特征编码成中间向量格式,然后解码器读取该向量,用随机方法把还原图片像素点的变化规律。

接下来我们看看代码的实现:

from keras.models import Model
from keras import layers
import numpy as np
import keras#输入图片为28*28的灰度图
img_shape = (28, 28, 1)
batch_size = 16
#将输入图片编码为只含有2个分量的向量
latent_dim = 2input_img = keras.Input(shape = img_shape)
#设计编码器部分
x = layers.Conv2D(32, 3, padding = 'same', activation = 'relu')(input_img)
x = layers.Conv2D(64, 3, padding = 'same', activation = 'relu', strides = (2,2))(x)
x = layers.Conv2D(64, 3, padding = 'same', activation = 'relu')(x)
x = layers.Conv2D(64, 3, padding = 'same', activation = 'relu')(x)shape = K.int_shape(x)
#把x压扁成一维向量
x = layers.Flatten()(x)
x = layers.Dense(32, activation = 'relu')(x)
#统计输入图片像素点统计规律上的均值
z_mean = layers.Dense(latent_dim)(x)
#统计输入图片像素点统计规律上的方差
z_log_var = layers.Dense(latent_dim)(x)'''
均值和方差决定了像素点的变化规律,在统计上发现,大量的事物在数据上的变化都遵守正太分布,一旦掌握
了其数值变化的方差和均值,我们就掌握了它变化的规律。在实现解码器时,我们也认为输入图片的像素点
同样符合正太分布,下面函数根据上面得到的均值和方差构造正太分布,然后从这个分布中进行抽样构成要
还原图片的像素点
'''
def  sampling(args):z_mean, z_log_var = args#构造一个随机值,然后使用它到给定正太分布中生成一个结果,这类似于丢一个骰子然后看点数epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),mean = 0, stddev = 1.)return  z_mean + K.exp(z_log_var) * epsilonz = layers.Lambda(sampling)([z_mean, z_log_var])

上面是编码器的实现,从这里我们看到,深度学习其本质并没有什么神奇的魔力,它本质是对大量的输入数据进行数理统计,由此就能掌握事物的变化规律。我们再看看解码器的实现:

#解码过程是对编码过程的逆运算
decoder_input = layers.Input(K.int_shape(z)[1:])
x = layers.Dense(np.prod(shape[1:]), activation = 'relu')(decoder_input)
#恢复为向量压扁前的格式
x = layers.Reshape(shape[1:])(x)
#对编码器的卷积运输进行逆操作
x = layers.Conv2DTranspose(32, 3, padding = 'same', activation = 'relu',strides = (2, 2))(x)
x = layers.Conv2D(1, 3, padding = 'same', activation = 'sigmoid')(x)
decoder = Model(decoder_input, x)z_decoded = decoder(z)

接下来我们设置网络的损失函数:

'''
我们定义网络的损失框架没有提供,因此我们自己动手写
'''
class  CustomVariationLayer(keras.layers.Layer):def  vae_loss(self, x, z_decoded):x = K.flatten(x)z_decoded = K.flatten(z_decoded)#计算生成二维数组与输入图片二维数组对应元素的差方和xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)#计算网络生成像素点统计分布与输入图片像素点变化分布的差异k1_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis = -1)return K.mean(xent_loss + k1_loss)def  call(self, inputs):x = inputs[0]z_decoded = inputs[1]loss = self.vae_loss(x, z_decoded)self.add_loss(loss, inputs = inputs)return xy = CustomVariationLayer()([input_img, z_decoded])

损失函数要计算两部分,一部分是网络解码得到的二维数组与输入图片二维数组对应元素的差方和,第二是网络构造的二维数组,其元素变化规律与输入图片元素变化规律在统计上的差异,也就是我们希望网络生成的二维数组,其元素变化在统计上的均值与方差和输入图片像素点在统计上的均值和方差要尽可能的小。这里涉及到数理统计方面的知识,不了解可以直接忽略掉。

最后我们看看网络的训练过程:

from keras.datasets import mnistvae = Model(input_img, y)
vae.compile(optimizer = 'rmsprop', loss = None)
vae.summary()(x_train, _), (x_test, y_test) = mnist.load_data()x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))vae.fit(x = x_train, y = None,shuffle = True,epochs = 10,batch_size = batch_size,validation_data = (x_test, None))

训练后,我们看看网络对输入图片的还原效果:

import matplotlib.pyp![
](https://upload-images.jianshu.io/upload_images/2849961-5d9959da5b4d37e8.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240)
lot as plt
from scipy.stats import norm#一次呈现15*15个数字
n = 15
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))for i, yi in enumerate(grid_x):for j, xi in enumerate(grid_y):z_sample = np.array([[xi, yi]])z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)x_decoded = decoder.predict(z_sample, batch_size = batch_size)digit = x_decoded[0].reshape(digit_size, digit_size)figure[i * digit_size: (i+1) * digit_size,j * digit_size: (j+1) * digit_size] = digitplt.figure(figsize = (10, 10))
plt.imshow(figure, cmap = 'Greys_r')
plt.show()

上面代码运行后可以看到,网络学习了数字图片中像素点的分布规律后,按照规律构造还原会相应的数字图片,还原的图片与输入图片大致相同,但在细节上有些许差异:
1.png

本节的内容比较抽象,不好理解。因为它用到了很多数学知识,没有深厚的数学功底你很难掌握本节内容,这也是现在程序员很难转行到人工智能,特别是深度学习领域的根本原因,因为他们具备的是工程思维,而人工智能要求你具备深厚的数学基础以及科学研究思维,如果你理解不了本节内容不要紧,只要把代码敲一遍,看看结果,具有一个感性认识也就可以了。

更多内容,请点击进入csdn学院

更多技术信息,包括操作系统,编译器,面试算法,机器学习,人工智能,请关照我的公众号:
这里写图片描述

这篇关于使用变分编解码器实现自动图像生成的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/950047

相关文章

详解Vue如何使用xlsx库导出Excel文件

《详解Vue如何使用xlsx库导出Excel文件》第三方库xlsx提供了强大的功能来处理Excel文件,它可以简化导出Excel文件这个过程,本文将为大家详细介绍一下它的具体使用,需要的小伙伴可以了解... 目录1. 安装依赖2. 创建vue组件3. 解释代码在Vue.js项目中导出Excel文件,使用第三

Linux alias的三种使用场景方式

《Linuxalias的三种使用场景方式》文章介绍了Linux中`alias`命令的三种使用场景:临时别名、用户级别别名和系统级别别名,临时别名仅在当前终端有效,用户级别别名在当前用户下所有终端有效... 目录linux alias三种使用场景一次性适用于当前用户全局生效,所有用户都可调用删除总结Linux

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python脚本实现自动删除C盘临时文件夹

《Python脚本实现自动删除C盘临时文件夹》在日常使用电脑的过程中,临时文件夹往往会积累大量的无用数据,占用宝贵的磁盘空间,下面我们就来看看Python如何通过脚本实现自动删除C盘临时文件夹吧... 目录一、准备工作二、python脚本编写三、脚本解析四、运行脚本五、案例演示六、注意事项七、总结在日常使用

Java实现Excel与HTML互转

《Java实现Excel与HTML互转》Excel是一种电子表格格式,而HTM则是一种用于创建网页的标记语言,虽然两者在用途上存在差异,但有时我们需要将数据从一种格式转换为另一种格式,下面我们就来看看... Excel是一种电子表格格式,广泛用于数据处理和分析,而HTM则是一种用于创建网页的标记语言。虽然两

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

Java中Springboot集成Kafka实现消息发送和接收功能

《Java中Springboot集成Kafka实现消息发送和接收功能》Kafka是一个高吞吐量的分布式发布-订阅消息系统,主要用于处理大规模数据流,它由生产者、消费者、主题、分区和代理等组件构成,Ka... 目录一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者一、Kaf

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

Mysql虚拟列的使用场景

《Mysql虚拟列的使用场景》MySQL虚拟列是一种在查询时动态生成的特殊列,它不占用存储空间,可以提高查询效率和数据处理便利性,本文给大家介绍Mysql虚拟列的相关知识,感兴趣的朋友一起看看吧... 目录1. 介绍mysql虚拟列1.1 定义和作用1.2 虚拟列与普通列的区别2. MySQL虚拟列的类型2

使用MongoDB进行数据存储的操作流程

《使用MongoDB进行数据存储的操作流程》在现代应用开发中,数据存储是一个至关重要的部分,随着数据量的增大和复杂性的增加,传统的关系型数据库有时难以应对高并发和大数据量的处理需求,MongoDB作为... 目录什么是MongoDB?MongoDB的优势使用MongoDB进行数据存储1. 安装MongoDB