Fashion MNIST 图片重建与生成(VAE)

2023-10-30 16:50

本文主要是介绍Fashion MNIST 图片重建与生成(VAE),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前面只能利用AE来重建图片,不是生成图片。这里利用VAE模型完成图片的重建与生成。

一、数据集的加载以及预处理

# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

和AE一样这里只需要数据集的图片数据x,不需要标签y

二、网络模型的构建

输入为 Fashion MNIST 图片向量,经过 3 个全连接层后得到隐向量𝐳的均值与方差分别用两
个输出节点数为 20 的全连接层表示, FC2 的 20 个输出节点表示 20 个特征分布的均值向量
FC3 的 20 个输出节点表示 20 个特征分布的取log后的方差向量通过Reparameterization Trick 采样获得长度为 20 的隐向量𝐳,并通过 FC4 和 FC5 重建出样本图片

class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_var

Encoder 的输入先通过共享层 FC1,然后分别通过 FC2 与 FC3 网络,获得隐向量分布的均值向量与方差的log向量值Decoder 接受采样后的隐向量𝐳,并解码为图片输出。
 

在 VAE 的前向计算过程中,首先通过编码器获得输入的隐向量𝐳的分布,然后利用Reparameterization Trick 实现的 reparameterize 函数采样获得隐向量𝐳,Reparameterize 函数接受均值与方差参数,并从正态分布𝒩(0, 𝐼)中采样获得𝜀,通过z = 𝜇 + 𝜎 ⊙ 𝜀方式返回采样隐向量, 最后通过解码器即可恢复重建的图片向量。 

Reparameterization Trick原因:编码器输出正态分布的均值𝜇和方差𝜎2,解码器的输入采样自𝒩(𝜇, 𝜎2)。由于采样操作的存在,导致梯度传播是不连续的,无法通过梯度下降算法端到端式地训练 VAE 网络。

它通过z = u + \sigma \odot \varepsilon方式采样隐变量z,\frac{\partial z}{\partial u}\frac{\partial z}{\partial \sigma }是连续可导的,从而将梯度传播连接起来

三、网络装配与训练

网络模型建立以后,给网络选择一定的优化器,设置学习率,就可以进行模型训练。

model = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))

在VAE模型中代价函数:\pounds (\theta ,\phi ) = -\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) + \mathbb{E}_{z-q}[logp_{\theta }(x|z)]

 当𝑞 (z |𝑥)和𝑝(z )都假设为正态分布时:\mathbb{D_{KL}}(q_{\phi }(z|x)||p(z)) = log(\frac{\sigma _{2}}{\sigma _{1}}) + \frac{\sigma _{1}^{2} + (u_{1}-u_{2})^2} {2\sigma _{2}^{2}}-\frac{1}{2}

当𝑞 ( |𝑥)为正态分布𝒩(𝜇1, 𝜎1), 𝑝( )为正态分布𝒩(0,1)时,即𝜇2 = 0, 𝜎2 =1,此时
\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) = -log\sigma_{1} + 0.5\sigma _{1}^{2} + 0.5u_{1}^{2} - 0.5

 而 max\mathbb{E}_{zq}[logp_{\theta }(x|z)],该项可以基于自编码器中的重建误差函数实现

所以,损失函数 = 自编码器重建误差函数 + KL散度

四、测试

图片生成只利用到解码器网络,首先从先验分布𝒩(0, 𝐼)中采样获得隐向量,再通过解码器获得图片向量,最后 Reshape 为图片矩阵。

      # 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

结果:

图片重建的效果是要略好于图片生成的,这也说明了图片生成是更为复杂的任务, VAE 模型虽然具有图片生成的能力,但是生成的效果仍然不够优秀,人眼还是能够较轻松地分辨出机器生成的和真实的图片样本

五、程序

# -*- codeing = utf-8 -*-
# @Time : 12:03
# @Author:Paranipd
# @File : VAE_test.py
# @Software:PyCharmimport os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras import datasets, Sequential, layers, metrics, optimizers, lossestf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2')def save_image(imgs, name):# 创建 280x280 大小图片阵列new_im = Image.new('L', (280, 280))index = 0for i in range(0, 280, 28):  # 10 行图片阵列for j in range(0, 280, 28):  # 10 列图片阵列im = imgs[index]im = Image.fromarray(im, mode='L')new_im.paste(im, (i, j))  # 写入对应位置index += 1# 保存图片阵列new_im.save(name)h_dim = 20
z_dim = 10
batches = 512# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_varmodel = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))# 评估# 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

这篇关于Fashion MNIST 图片重建与生成(VAE)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/309284

相关文章

Java使用Javassist动态生成HelloWorld类

《Java使用Javassist动态生成HelloWorld类》Javassist是一个非常强大的字节码操作和定义库,它允许开发者在运行时创建新的类或者修改现有的类,本文将简单介绍如何使用Javass... 目录1. Javassist简介2. 环境准备3. 动态生成HelloWorld类3.1 创建CtC

Java实现将HTML文件与字符串转换为图片

《Java实现将HTML文件与字符串转换为图片》在Java开发中,我们经常会遇到将HTML内容转换为图片的需求,本文小编就来和大家详细讲讲如何使用FreeSpire.DocforJava库来实现这一功... 目录前言核心实现:html 转图片完整代码场景 1:转换本地 HTML 文件为图片场景 2:转换 H

Java实现在Word文档中添加文本水印和图片水印的操作指南

《Java实现在Word文档中添加文本水印和图片水印的操作指南》在当今数字时代,文档的自动化处理与安全防护变得尤为重要,无论是为了保护版权、推广品牌,还是为了在文档中加入特定的标识,为Word文档添加... 目录引言Spire.Doc for Java:高效Word文档处理的利器代码实战:使用Java为Wo

基于C#实现PDF转图片的详细教程

《基于C#实现PDF转图片的详细教程》在数字化办公场景中,PDF文件的可视化处理需求日益增长,本文将围绕Spire.PDFfor.NET这一工具,详解如何通过C#将PDF转换为JPG、PNG等主流图片... 目录引言一、组件部署二、快速入门:PDF 转图片的核心 C# 代码三、分辨率设置 - 清晰度的决定因

Python从Word文档中提取图片并生成PPT的操作代码

《Python从Word文档中提取图片并生成PPT的操作代码》在日常办公场景中,我们经常需要从Word文档中提取图片,并将这些图片整理到PowerPoint幻灯片中,手动完成这一任务既耗时又容易出错,... 目录引言背景与需求解决方案概述代码解析代码核心逻辑说明总结引言在日常办公场景中,我们经常需要从 W

使用Python实现无损放大图片功能

《使用Python实现无损放大图片功能》本文介绍了如何使用Python的Pillow库进行无损图片放大,区分了JPEG和PNG格式在放大过程中的特点,并给出了示例代码,JPEG格式可能受压缩影响,需先... 目录一、什么是无损放大?二、实现方法步骤1:读取图片步骤2:无损放大图片步骤3:保存图片三、示php

C#使用Spire.XLS快速生成多表格Excel文件

《C#使用Spire.XLS快速生成多表格Excel文件》在日常开发中,我们经常需要将业务数据导出为结构清晰的Excel文件,本文将手把手教你使用Spire.XLS这个强大的.NET组件,只需几行C#... 目录一、Spire.XLS核心优势清单1.1 性能碾压:从3秒到0.5秒的质变1.2 批量操作的优雅

Python使用python-pptx自动化操作和生成PPT

《Python使用python-pptx自动化操作和生成PPT》这篇文章主要为大家详细介绍了如何使用python-pptx库实现PPT自动化,并提供实用的代码示例和应用场景,感兴趣的小伙伴可以跟随小编... 目录使用python-pptx操作PPT文档安装python-pptx基础概念创建新的PPT文档查看

在ASP.NET项目中如何使用C#生成二维码

《在ASP.NET项目中如何使用C#生成二维码》二维码(QRCode)已广泛应用于网址分享,支付链接等场景,本文将以ASP.NET为示例,演示如何实现输入文本/URL,生成二维码,在线显示与下载的完整... 目录创建前端页面(Index.cshtml)后端二维码生成逻辑(Index.cshtml.cs)总结

Python实现数据可视化图表生成(适合新手入门)

《Python实现数据可视化图表生成(适合新手入门)》在数据科学和数据分析的新时代,高效、直观的数据可视化工具显得尤为重要,下面:本文主要介绍Python实现数据可视化图表生成的相关资料,文中通过... 目录前言为什么需要数据可视化准备工作基本图表绘制折线图柱状图散点图使用Seaborn创建高级图表箱线图热