Fashion MNIST 图片重建与生成(VAE)

2023-10-30 16:50

本文主要是介绍Fashion MNIST 图片重建与生成(VAE),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前面只能利用AE来重建图片,不是生成图片。这里利用VAE模型完成图片的重建与生成。

一、数据集的加载以及预处理

# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

和AE一样这里只需要数据集的图片数据x,不需要标签y

二、网络模型的构建

输入为 Fashion MNIST 图片向量,经过 3 个全连接层后得到隐向量𝐳的均值与方差分别用两
个输出节点数为 20 的全连接层表示, FC2 的 20 个输出节点表示 20 个特征分布的均值向量
FC3 的 20 个输出节点表示 20 个特征分布的取log后的方差向量通过Reparameterization Trick 采样获得长度为 20 的隐向量𝐳,并通过 FC4 和 FC5 重建出样本图片

class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_var

Encoder 的输入先通过共享层 FC1,然后分别通过 FC2 与 FC3 网络,获得隐向量分布的均值向量与方差的log向量值Decoder 接受采样后的隐向量𝐳,并解码为图片输出。
 

在 VAE 的前向计算过程中,首先通过编码器获得输入的隐向量𝐳的分布,然后利用Reparameterization Trick 实现的 reparameterize 函数采样获得隐向量𝐳,Reparameterize 函数接受均值与方差参数,并从正态分布𝒩(0, 𝐼)中采样获得𝜀,通过z = 𝜇 + 𝜎 ⊙ 𝜀方式返回采样隐向量, 最后通过解码器即可恢复重建的图片向量。 

Reparameterization Trick原因:编码器输出正态分布的均值𝜇和方差𝜎2,解码器的输入采样自𝒩(𝜇, 𝜎2)。由于采样操作的存在,导致梯度传播是不连续的,无法通过梯度下降算法端到端式地训练 VAE 网络。

它通过z = u + \sigma \odot \varepsilon方式采样隐变量z,\frac{\partial z}{\partial u}\frac{\partial z}{\partial \sigma }是连续可导的,从而将梯度传播连接起来

三、网络装配与训练

网络模型建立以后,给网络选择一定的优化器,设置学习率,就可以进行模型训练。

model = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))

在VAE模型中代价函数:\pounds (\theta ,\phi ) = -\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) + \mathbb{E}_{z-q}[logp_{\theta }(x|z)]

 当𝑞 (z |𝑥)和𝑝(z )都假设为正态分布时:\mathbb{D_{KL}}(q_{\phi }(z|x)||p(z)) = log(\frac{\sigma _{2}}{\sigma _{1}}) + \frac{\sigma _{1}^{2} + (u_{1}-u_{2})^2} {2\sigma _{2}^{2}}-\frac{1}{2}

当𝑞 ( |𝑥)为正态分布𝒩(𝜇1, 𝜎1), 𝑝( )为正态分布𝒩(0,1)时,即𝜇2 = 0, 𝜎2 =1,此时
\mathbb{D}_{KL}(q_{\phi }(z|x)||p(z)) = -log\sigma_{1} + 0.5\sigma _{1}^{2} + 0.5u_{1}^{2} - 0.5

 而 max\mathbb{E}_{zq}[logp_{\theta }(x|z)],该项可以基于自编码器中的重建误差函数实现

所以,损失函数 = 自编码器重建误差函数 + KL散度

四、测试

图片生成只利用到解码器网络,首先从先验分布𝒩(0, 𝐼)中采样获得隐向量,再通过解码器获得图片向量,最后 Reshape 为图片矩阵。

      # 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

结果:

图片重建的效果是要略好于图片生成的,这也说明了图片生成是更为复杂的任务, VAE 模型虽然具有图片生成的能力,但是生成的效果仍然不够优秀,人眼还是能够较轻松地分辨出机器生成的和真实的图片样本

五、程序

# -*- codeing = utf-8 -*-
# @Time : 12:03
# @Author:Paranipd
# @File : VAE_test.py
# @Software:PyCharmimport os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras import datasets, Sequential, layers, metrics, optimizers, lossestf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2')def save_image(imgs, name):# 创建 280x280 大小图片阵列new_im = Image.new('L', (280, 280))index = 0for i in range(0, 280, 28):  # 10 行图片阵列for j in range(0, 280, 28):  # 10 列图片阵列im = imgs[index]im = Image.fromarray(im, mode='L')new_im.paste(im, (i, j))  # 写入对应位置index += 1# 保存图片阵列new_im.save(name)h_dim = 20
z_dim = 10
batches = 512# 加载Fashion MNIST数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# 归一化
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
# 只需要通过图片数据即可构建数据集对象,不需要标签
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batches * 5).batch(batches)
# 构建测试集对象
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batches)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)class VAE(keras.Model):def __init__(self):super(VAE, self).__init__()# Encodersself.fc1 = layers.Dense(128, activation=tf.nn.relu)self.fc2 = layers.Dense(z_dim)  # 均值self.fc3 = layers.Dense(z_dim)  # 方差# Decodersself.f4 = layers.Dense(128, activation=tf.nn.relu)self.f5 = layers.Dense(784)def encoder(self, x):h = self.fc1(x)# 均值mu = self.fc2(h)# 方差log_var = self.fc3(h)return mu, log_vardef decoder(self, z):out = self.f4(z)out = self.f5(out)return out# 参数化def reparameterize(self, mu, log_var):esp = tf.random.normal(log_var.shape)std = tf.exp(log_var * 0.5)z = mu + std * espreturn zdef call(self, inputs, training):# [b,784] -> [b, z_dim],[b,z_dim]mu, log_var = self.encoder(inputs)# reparameterization tickz = self.reparameterize(mu, log_var)# --> [b, 784]x_hat = self.decoder(z)return x_hat, mu, log_varmodel = VAE()
model.build(input_shape=(4, 784))
model.summary()optimizer = optimizers.Adam(lr=1e-3)for epoch in range(100):for step, x in enumerate(train_db):# [b,28,28] -> [b,784]x = tf.reshape(x, [-1, 784])# 构建梯度记录器with tf.GradientTape() as tape:# 前向计算获得重建的图片x_rec_logits, mu, log_var = model(x)   # call函数返回值# x 与 重建的 x :重建图片与输入之间的损失函数rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]# compute kl divergence散度  (mu, var) ~ N (0, 1) 并且p(z) ~ (0, 1)kl_div = -0.5*(log_var+1-mu**2-tf.exp(log_var))kl_div = tf.reduce_sum(kl_div) / x.shape[0]loss = rec_loss + 1.*kl_div  # 损失函数 = 自编码器重建误差函数 + KL散度grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print(epoch, step, 'kl_div:', float(kl_div),'rec loss:',float(rec_loss))# 评估# 生成图片,从正太分布随机采样zz = tf.random.normal((batches, z_dim))logits = model.decoder(z)x_hat = tf.sigmoid(logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)# 重建图片,从测试集中采用图片x = next(iter(test_db))x = tf.reshape(x, [-1, 784])x_hat_logits, _, _ = model(x)  # call返回值x_hat = tf.sigmoid(x_hat_logits)x_hat = tf.reshape(x_hat, [-1, 28, 28]).numpy() * 255.x_hat = x_hat.astype(np.uint8)save_image(x_hat, 'vae_images/rec_epoch%d.png' % epoch)

这篇关于Fashion MNIST 图片重建与生成(VAE)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/309284

相关文章

基于Python实现一个图片拆分工具

《基于Python实现一个图片拆分工具》这篇文章主要为大家详细介绍了如何基于Python实现一个图片拆分工具,可以根据需要的行数和列数进行拆分,感兴趣的小伙伴可以跟随小编一起学习一下... 简单介绍先自己选择输入的图片,默认是输出到项目文件夹中,可以自己选择其他的文件夹,选择需要拆分的行数和列数,可以通过

利用Python脚本实现批量将图片转换为WebP格式

《利用Python脚本实现批量将图片转换为WebP格式》Python语言的简洁语法和库支持使其成为图像处理的理想选择,本文将介绍如何利用Python实现批量将图片转换为WebP格式的脚本,WebP作为... 目录简介1. python在图像处理中的应用2. WebP格式的原理和优势2.1 WebP格式与传统

基于 HTML5 Canvas 实现图片旋转与下载功能(完整代码展示)

《基于HTML5Canvas实现图片旋转与下载功能(完整代码展示)》本文将深入剖析一段基于HTML5Canvas的代码,该代码实现了图片的旋转(90度和180度)以及旋转后图片的下载... 目录一、引言二、html 结构分析三、css 样式分析四、JavaScript 功能实现一、引言在 Web 开发中,

Python如何去除图片干扰代码示例

《Python如何去除图片干扰代码示例》图片降噪是一个广泛应用于图像处理的技术,可以提高图像质量和相关应用的效果,:本文主要介绍Python如何去除图片干扰的相关资料,文中通过代码介绍的非常详细,... 目录一、噪声去除1. 高斯噪声(像素值正态分布扰动)2. 椒盐噪声(随机黑白像素点)3. 复杂噪声(如伪

Python中图片与PDF识别文本(OCR)的全面指南

《Python中图片与PDF识别文本(OCR)的全面指南》在数据爆炸时代,80%的企业数据以非结构化形式存在,其中PDF和图像是最主要的载体,本文将深入探索Python中OCR技术如何将这些数字纸张转... 目录一、OCR技术核心原理二、python图像识别四大工具库1. Pytesseract - 经典O

Python实现精准提取 PDF中的文本,表格与图片

《Python实现精准提取PDF中的文本,表格与图片》在实际的系统开发中,处理PDF文件不仅限于读取整页文本,还有提取文档中的表格数据,图片或特定区域的内容,下面我们来看看如何使用Python实... 目录安装 python 库提取 PDF 文本内容:获取整页文本与指定区域内容获取页面上的所有文本内容获取

Python基于微信OCR引擎实现高效图片文字识别

《Python基于微信OCR引擎实现高效图片文字识别》这篇文章主要为大家详细介绍了一款基于微信OCR引擎的图片文字识别桌面应用开发全过程,可以实现从图片拖拽识别到文字提取,感兴趣的小伙伴可以跟随小编一... 目录一、项目概述1.1 开发背景1.2 技术选型1.3 核心优势二、功能详解2.1 核心功能模块2.

Go语言如何判断两张图片的相似度

《Go语言如何判断两张图片的相似度》这篇文章主要为大家详细介绍了Go语言如何中实现判断两张图片的相似度的两种方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 在介绍技术细节前,我们先来看看图片对比在哪些场景下可以用得到:图片去重:自动删除重复图片,为存储空间"瘦身"。想象你是一个

使用Python实现base64字符串与图片互转的详细步骤

《使用Python实现base64字符串与图片互转的详细步骤》要将一个Base64编码的字符串转换为图片文件并保存下来,可以使用Python的base64模块来实现,这一过程包括解码Base64字符串... 目录1. 图片编码为 Base64 字符串2. Base64 字符串解码为图片文件3. 示例使用注意

Python实现自动化Word文档样式复制与内容生成

《Python实现自动化Word文档样式复制与内容生成》在办公自动化领域,高效处理Word文档的样式和内容复制是一个常见需求,本文将展示如何利用Python的python-docx库实现... 目录一、为什么需要自动化 Word 文档处理二、核心功能实现:样式与表格的深度复制1. 表格复制(含样式与内容)2