Fashion MNIST 图片重建与生成(VAE)

2023-10-30 16:50

本文主要是介绍Fashion MNIST 图片重建与生成(VAE),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前面只能利用AE来重建图片,不是生成图片。这里利用VAE模型完成图片的重建与生成。

一、数据集的加载以及预处理

# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

和AE一样这里只需要数据集的图片数据x,不需要标签y

二、网络模型的构建

输入为 Fashion MNIST 图片向量,经过 3 个全连接层后得到隐向量𝐳的均值与方差分别用两
个输出节点数为 20 的全连接层表示, FC2 的 20 个输出节点表示 20 个特征分布的均值向量
FC3 的 20 个输出节点表示 20 个特征分布的取log后的方差向量通过Reparameterization Trick 采样获得长度为 20 的隐向量𝐳,并通过 FC4 和 FC5 重建出样本图片

class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_var

Encoder 的输入先通过共享层 FC1,然后分别通过 FC2 与 FC3 网络,获得隐向量分布的均值向量与方差的log向量值Decoder 接受采样后的隐向量𝐳,并解码为图片输出。
 

在 VAE 的前向计算过程中,首先通过编码器获得输入的隐向量𝐳的分布,然后利用Reparameterization Trick 实现的 reparameterize 函数采样获得隐向量𝐳,Reparameterize 函数接受均值与方差参数,并从正态分布𝒩(0, 𝐼)中采样获得𝜀,通过z = 𝜇 + 𝜎 ⊙ 𝜀方式返回采样隐向量, 最后通过解码器即可恢复重建的图片向量。 

Reparameterization Trick原因:编码器输出正态分布的均值𝜇和方差𝜎2,解码器的输入采样自𝒩(𝜇, 𝜎2)。由于采样操作的存在,导致梯度传播是不连续的,无法通过梯度下降算法端到端式地训练 VAE 网络。

它通过z = u + \sigma \odot \varepsilon方式采样隐变量z,\frac{\partial z}{\partial u}\frac{\partial z}{\partial \sigma }是连续可导的,从而将梯度传播连接起来

三、网络装配与训练

网络模型建立以后,给网络选择一定的优化器,设置学习率,就可以进行模型训练。

model = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))

在VAE模型中代价函数:\pounds (\theta ,\phi ) = -\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) + \mathbb{E}_{z-q}[logp_{\theta }(x|z)]

 当𝑞 (z |𝑥)和𝑝(z )都假设为正态分布时:\mathbb{D_{KL}}(q_{\phi }(z|x)||p(z)) = log(\frac{\sigma _{2}}{\sigma _{1}}) + \frac{\sigma _{1}^{2} + (u_{1}-u_{2})^2} {2\sigma _{2}^{2}}-\frac{1}{2}

当𝑞 ( |𝑥)为正态分布𝒩(𝜇1, 𝜎1), 𝑝( )为正态分布𝒩(0,1)时,即𝜇2 = 0, 𝜎2 =1,此时
\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) = -log\sigma_{1} + 0.5\sigma _{1}^{2} + 0.5u_{1}^{2} - 0.5

 而 max\mathbb{E}_{zq}[logp_{\theta }(x|z)],该项可以基于自编码器中的重建误差函数实现

所以,损失函数 = 自编码器重建误差函数 + KL散度

四、测试

图片生成只利用到解码器网络,首先从先验分布𝒩(0, 𝐼)中采样获得隐向量,再通过解码器获得图片向量,最后 Reshape 为图片矩阵。

      # 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

结果:

图片重建的效果是要略好于图片生成的,这也说明了图片生成是更为复杂的任务, VAE 模型虽然具有图片生成的能力,但是生成的效果仍然不够优秀,人眼还是能够较轻松地分辨出机器生成的和真实的图片样本

五、程序

# -*- codeing = utf-8 -*-
# @Time : 12:03
# @Author:Paranipd
# @File : VAE_test.py
# @Software:PyCharmimport os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras import datasets, Sequential, layers, metrics, optimizers, lossestf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2')def save_image(imgs, name):# 创建 280x280 大小图片阵列new_im = Image.new('L', (280, 280))index = 0for i in range(0, 280, 28):  # 10 行图片阵列for j in range(0, 280, 28):  # 10 列图片阵列im = imgs[index]im = Image.fromarray(im, mode='L')new_im.paste(im, (i, j))  # 写入对应位置index += 1# 保存图片阵列new_im.save(name)h_dim = 20
z_dim = 10
batches = 512# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_varmodel = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))# 评估# 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

这篇关于Fashion MNIST 图片重建与生成(VAE)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/309284

相关文章

C#实现添加/替换/提取或删除Excel中的图片

《C#实现添加/替换/提取或删除Excel中的图片》在Excel中插入与数据相关的图片,能将关键数据或信息以更直观的方式呈现出来,使文档更加美观,下面我们来看看如何在C#中实现添加/替换/提取或删除E... 在Excandroidel中插入与数据相关的图片,能将关键数据或信息以更直观的方式呈现出来,使文档更

MybatisGenerator文件生成不出对应文件的问题

《MybatisGenerator文件生成不出对应文件的问题》本文介绍了使用MybatisGenerator生成文件时遇到的问题及解决方法,主要步骤包括检查目标表是否存在、是否能连接到数据库、配置生成... 目录MyBATisGenerator 文件生成不出对应文件先在项目结构里引入“targetProje

Python使用qrcode库实现生成二维码的操作指南

《Python使用qrcode库实现生成二维码的操作指南》二维码是一种广泛使用的二维条码,因其高效的数据存储能力和易于扫描的特点,广泛应用于支付、身份验证、营销推广等领域,Pythonqrcode库是... 目录一、安装 python qrcode 库二、基本使用方法1. 生成简单二维码2. 生成带 Log

C#中图片如何自适应pictureBox大小

《C#中图片如何自适应pictureBox大小》文章描述了如何在C#中实现图片自适应pictureBox大小,并展示修改前后的效果,修改步骤包括两步,作者分享了个人经验,希望对大家有所帮助... 目录C#图片自适应pictureBox大小编程修改步骤总结C#图片自适应pictureBox大小上图中“z轴

使用Python将长图片分割为若干张小图片

《使用Python将长图片分割为若干张小图片》这篇文章主要为大家详细介绍了如何使用Python将长图片分割为若干张小图片,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. python需求的任务2. Python代码的实现3. 代码修改的位置4. 运行结果1. Python需求

Python使用Pandas库将Excel数据叠加生成新DataFrame的操作指南

《Python使用Pandas库将Excel数据叠加生成新DataFrame的操作指南》在日常数据处理工作中,我们经常需要将不同Excel文档中的数据整合到一个新的DataFrame中,以便进行进一步... 目录一、准备工作二、读取Excel文件三、数据叠加四、处理重复数据(可选)五、保存新DataFram

SpringBoot生成和操作PDF的代码详解

《SpringBoot生成和操作PDF的代码详解》本文主要介绍了在SpringBoot项目下,通过代码和操作步骤,详细的介绍了如何操作PDF,希望可以帮助到准备通过JAVA操作PDF的你,项目框架用的... 目录本文简介PDF文件简介代码实现PDF操作基于PDF模板生成,并下载完全基于代码生成,并保存合并P

详解Java中如何使用JFreeChart生成甘特图

《详解Java中如何使用JFreeChart生成甘特图》甘特图是一种流行的项目管理工具,用于显示项目的进度和任务分配,在Java开发中,JFreeChart是一个强大的开源图表库,能够生成各种类型的图... 目录引言一、JFreeChart简介二、准备工作三、创建甘特图1. 定义数据集2. 创建甘特图3.

使用 Python 和 LabelMe 实现图片验证码的自动标注功能

《使用Python和LabelMe实现图片验证码的自动标注功能》文章介绍了如何使用Python和LabelMe自动标注图片验证码,主要步骤包括图像预处理、OCR识别和生成标注文件,通过结合Pa... 目录使用 python 和 LabelMe 实现图片验证码的自动标注环境准备必备工具安装依赖实现自动标注核心

Java操作xls替换文本或图片的功能实现

《Java操作xls替换文本或图片的功能实现》这篇文章主要给大家介绍了关于Java操作xls替换文本或图片功能实现的相关资料,文中通过示例代码讲解了文件上传、文件处理和Excel文件生成,需要的朋友可... 目录准备xls模板文件:template.xls准备需要替换的图片和数据功能实现包声明与导入类声明与