Fashion MNIST 图片重建与生成(VAE)

2023-10-30 16:50

本文主要是介绍Fashion MNIST 图片重建与生成(VAE),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前面只能利用AE来重建图片,不是生成图片。这里利用VAE模型完成图片的重建与生成。

一、数据集的加载以及预处理

# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

和AE一样这里只需要数据集的图片数据x,不需要标签y

二、网络模型的构建

输入为 Fashion MNIST 图片向量,经过 3 个全连接层后得到隐向量𝐳的均值与方差分别用两
个输出节点数为 20 的全连接层表示, FC2 的 20 个输出节点表示 20 个特征分布的均值向量
FC3 的 20 个输出节点表示 20 个特征分布的取log后的方差向量通过Reparameterization Trick 采样获得长度为 20 的隐向量𝐳,并通过 FC4 和 FC5 重建出样本图片

class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_var

Encoder 的输入先通过共享层 FC1,然后分别通过 FC2 与 FC3 网络,获得隐向量分布的均值向量与方差的log向量值Decoder 接受采样后的隐向量𝐳,并解码为图片输出。
 

在 VAE 的前向计算过程中,首先通过编码器获得输入的隐向量𝐳的分布,然后利用Reparameterization Trick 实现的 reparameterize 函数采样获得隐向量𝐳,Reparameterize 函数接受均值与方差参数,并从正态分布𝒩(0, 𝐼)中采样获得𝜀,通过z = 𝜇 + 𝜎 ⊙ 𝜀方式返回采样隐向量, 最后通过解码器即可恢复重建的图片向量。 

Reparameterization Trick原因:编码器输出正态分布的均值𝜇和方差𝜎2,解码器的输入采样自𝒩(𝜇, 𝜎2)。由于采样操作的存在,导致梯度传播是不连续的,无法通过梯度下降算法端到端式地训练 VAE 网络。

它通过z = u + \sigma \odot \varepsilon方式采样隐变量z,\frac{\partial z}{\partial u}\frac{\partial z}{\partial \sigma }是连续可导的,从而将梯度传播连接起来

三、网络装配与训练

网络模型建立以后,给网络选择一定的优化器,设置学习率,就可以进行模型训练。

model = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))

在VAE模型中代价函数:\pounds (\theta ,\phi ) = -\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) + \mathbb{E}_{z-q}[logp_{\theta }(x|z)]

 当𝑞 (z |𝑥)和𝑝(z )都假设为正态分布时:\mathbb{D_{KL}}(q_{\phi }(z|x)||p(z)) = log(\frac{\sigma _{2}}{\sigma _{1}}) + \frac{\sigma _{1}^{2} + (u_{1}-u_{2})^2} {2\sigma _{2}^{2}}-\frac{1}{2}

当𝑞 ( |𝑥)为正态分布𝒩(𝜇1, 𝜎1), 𝑝( )为正态分布𝒩(0,1)时,即𝜇2 = 0, 𝜎2 =1,此时
\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) = -log\sigma_{1} + 0.5\sigma _{1}^{2} + 0.5u_{1}^{2} - 0.5

 而 max\mathbb{E}_{zq}[logp_{\theta }(x|z)],该项可以基于自编码器中的重建误差函数实现

所以,损失函数 = 自编码器重建误差函数 + KL散度

四、测试

图片生成只利用到解码器网络,首先从先验分布𝒩(0, 𝐼)中采样获得隐向量,再通过解码器获得图片向量,最后 Reshape 为图片矩阵。

      # 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

结果:

图片重建的效果是要略好于图片生成的,这也说明了图片生成是更为复杂的任务, VAE 模型虽然具有图片生成的能力,但是生成的效果仍然不够优秀,人眼还是能够较轻松地分辨出机器生成的和真实的图片样本

五、程序

# -*- codeing = utf-8 -*-
# @Time : 12:03
# @Author:Paranipd
# @File : VAE_test.py
# @Software:PyCharmimport os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras import datasets, Sequential, layers, metrics, optimizers, lossestf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2')def save_image(imgs, name):# 创建 280x280 大小图片阵列new_im = Image.new('L', (280, 280))index = 0for i in range(0, 280, 28):  # 10 行图片阵列for j in range(0, 280, 28):  # 10 列图片阵列im = imgs[index]im = Image.fromarray(im, mode='L')new_im.paste(im, (i, j))  # 写入对应位置index += 1# 保存图片阵列new_im.save(name)h_dim = 20
z_dim = 10
batches = 512# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_varmodel = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))# 评估# 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

这篇关于Fashion MNIST 图片重建与生成(VAE)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/309284

相关文章

Python使用PIL库将PNG图片转换为ICO图标的示例代码

《Python使用PIL库将PNG图片转换为ICO图标的示例代码》在软件开发和网站设计中,ICO图标是一种常用的图像格式,特别适用于应用程序图标、网页收藏夹图标等场景,本文将介绍如何使用Python的... 目录引言准备工作代码解析实践操作结果展示结语引言在软件开发和网站设计中,ICO图标是一种常用的图像

SpringBoot集成图片验证码框架easy-captcha的详细过程

《SpringBoot集成图片验证码框架easy-captcha的详细过程》本文介绍了如何将Easy-Captcha框架集成到SpringBoot项目中,实现图片验证码功能,Easy-Captcha是... 目录SpringBoot集成图片验证码框架easy-captcha一、引言二、依赖三、代码1. Ea

nginx生成自签名SSL证书配置HTTPS的实现

《nginx生成自签名SSL证书配置HTTPS的实现》本文主要介绍在Nginx中生成自签名SSL证书并配置HTTPS,包括安装Nginx、创建证书、配置证书以及测试访问,具有一定的参考价值,感兴趣的可... 目录一、安装nginx二、创建证书三、配置证书并验证四、测试一、安装nginxnginx必须有"-

Java实战之利用POI生成Excel图表

《Java实战之利用POI生成Excel图表》ApachePOI是Java生态中处理Office文档的核心工具,这篇文章主要为大家详细介绍了如何在Excel中创建折线图,柱状图,饼图等常见图表,需要的... 目录一、环境配置与依赖管理二、数据源准备与工作表构建三、图表生成核心步骤1. 折线图(Line Ch

如何使用CSS3实现波浪式图片墙

《如何使用CSS3实现波浪式图片墙》:本文主要介绍了如何使用CSS3的transform属性和动画技巧实现波浪式图片墙,通过设置图片的垂直偏移量,并使用动画使其周期性地改变位置,可以创建出动态且具有波浪效果的图片墙,同时,还强调了响应式设计的重要性,以确保图片墙在不同设备上都能良好显示,详细内容请阅读本文,希望能对你有所帮助...

Python脚本实现图片文件批量命名

《Python脚本实现图片文件批量命名》这篇文章主要为大家详细介绍了一个用python第三方库pillow写的批量处理图片命名的脚本,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录前言源码批量处理图片尺寸脚本源码GUI界面源码打包成.exe可执行文件前言本文介绍一个用python第三方库pi

Python爬虫selenium验证之中文识别点选+图片验证码案例(最新推荐)

《Python爬虫selenium验证之中文识别点选+图片验证码案例(最新推荐)》本文介绍了如何使用Python和Selenium结合ddddocr库实现图片验证码的识别和点击功能,感兴趣的朋友一起看... 目录1.获取图片2.目标识别3.背景坐标识别3.1 ddddocr3.2 打码平台4.坐标点击5.图

浅析如何使用Swagger生成带权限控制的API文档

《浅析如何使用Swagger生成带权限控制的API文档》当涉及到权限控制时,如何生成既安全又详细的API文档就成了一个关键问题,所以这篇文章小编就来和大家好好聊聊如何用Swagger来生成带有... 目录准备工作配置 Swagger权限控制给 API 加上权限注解查看文档注意事项在咱们的开发工作里,API

Python利用PIL进行图片压缩

《Python利用PIL进行图片压缩》有时在发送一些文件如PPT、Word时,由于文件中的图片太大,导致文件也太大,无法发送,所以本文为大家介绍了Python中图片压缩的方法,需要的可以参考下... 有时在发送一些文件如PPT、Word时,由于文件中的图片太大,导致文件也太大,无法发送,所有可以对文件中的图

java获取图片的大小、宽度、高度方式

《java获取图片的大小、宽度、高度方式》文章介绍了如何将File对象转换为MultipartFile对象的过程,并分享了个人经验,希望能为读者提供参考... 目China编程录Java获取图片的大小、宽度、高度File对象(该对象里面是图片)MultipartFile对象(该对象里面是图片)总结java获取图片