【信号处理】基于变分自编码器(VAE)的图片典型增强方法实现

2024-04-03 21:44

本文主要是介绍【信号处理】基于变分自编码器(VAE)的图片典型增强方法实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

关于

深度学习中,经常面临图片数据量较小的问题,此时,对数据进行增强,显得比较重要。传统的图片增强方法包括剪切,增加噪声,改变对比度等等方法,但是,对于后端任务的性能提升有限。所以,变分自编码器被用来实现深度数据增强。

变分自编码器的主要缺点在于生成图像过于平滑和模糊,图像细节重建不足。

常见的图像增强方法:https://www.tensorflow.org/tutorials/images/data_augmentation

工具

数据集下载地址: CIFAR-10 and CIFAR-100 datasets

方法实现

加载数据和必要的库函数
import tensorflow.compat.v1.keras.backend as K
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
import tensorflow_datasets as tfds
import keras
from keras.models import Model
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshapextrain , ytrain = tfds.as_numpy(tfds.load('cifar10',split='train',batch_size=-1,as_supervised=True,))
xtest , ytest = tfds.as_numpy(tfds.load('cifar10',split='test',batch_size=-1,as_supervised=True,))
xtrain = (xtrain.astype('float32'))/255
xtest = (xtest.astype('float32'))/255height=32
width=32
channels=3
print(f"Train Shape: {xtrain.shape},Test Shape: {xtest.shape}")
plt.imshow(xtrain[0])

编码器模型搭建
input_shape=(height,width,channels)
latent_dims=3072input_img= Input(shape=input_shape, name='encoder_input')
x=Conv2D(128, 4, padding='same', activation='relu',strides=2)(input_img)
x=Conv2D(256, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(512, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(1024, 4, padding='same', activation='relu',strides=2)(x)
conv_shape = K.int_shape(x)
x=Flatten()(x)
x=Dense(3072, activation='relu')(x)
z_mean=Dense(latent_dims, name='latent_mean')(x)
z_sigma=Dense(latent_dims, name='latent_sigma')(x)def sampler(args):z_mean, z_sigma = argseps = K.random_normal(shape=(K.shape(z_mean)[0], K.int_shape(z_mean)[1]))return z_mean + K.exp(z_sigma / 2) * epsz = Lambda(sampler, output_shape=(latent_dims, ), name='z')([z_mean, z_sigma])encoder = Model(input_img, [z_mean, z_sigma, z], name='encoder')
print(encoder.summary())

 解码器模型构建
decoder_input = Input(shape=(latent_dims, ), name='decoder_input')
x = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
x = Conv2DTranspose(256, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(128, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(64, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(3, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)
decoder = Model(decoder_input, x, name='decoder')
decoder.summary()
z_decoded = decoder(z)class CustomLayer(keras.layers.Layer):def vae_loss(self, x, z_decoded):x = K.flatten(x)z_decoded = K.flatten(z_decoded)# Reconstruction loss (as we used sigmoid activation we can use binarycrossentropy)recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)# KL divergencekl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mean) - K.exp(z_sigma), axis=-1)return K.mean(recon_loss + kl_loss)# add custom loss to the classdef call(self, inputs):x = inputs[0]z_decoded = inputs[1]loss = self.vae_loss(x, z_decoded)self.add_loss(loss, inputs=inputs)return x

 

整体模型构建
y = CustomLayer()([input_img, z_decoded])vae = Model(input_img, y, name='vae')
vae.compile(optimizer='adam', loss=None)
vae.summary()

 

模型训练

history=vae.fit(xtrain, verbose=2, epochs = 100, batch_size = 64, validation_split = 0.2)
 训练可视化
f = plt.figure(figsize=(10,7))
f.add_subplot()
#Adding Subplot
plt.plot(history.epoch, history.history['loss'], label = "loss") # Loss curve for training set
plt.plot(history.epoch, history.history['val_loss'], label = "val_loss") # Loss curve for validation setplt.title("Loss Curve",fontsize=18)
plt.xlabel("Epochs",fontsize=15)
plt.ylabel("Loss",fontsize=15)
plt.grid(alpha=0.3)
plt.legend()
plt.savefig("VAE_Loss_Trial5.png")
plt.show()

 中间编码特征可视化
mu, _, _ = encoder.predict(xtest)
#Plot dim1 and dim2 for mu
plt.figure(figsize=(10, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=ytest, cmap='brg')
plt.xlabel('dim 1')
plt.ylabel('dim 2')
plt.colorbar()
plt.show()
plt.savefig("VAE_Colourbar_Trial5.png")

 

数据增强生成
#RANDOM GENERATION
def generate():n=20figure = np.zeros((width *2 , height * 10, channels))#Create a Grid of latent variables, to be provided as inputs to decoder.predict
#Creating vectors within range -5 to 5 as that seems to be the range in latent spacefor k in range(2):for l in range(10):z_sample =random.rand(3072)z_out=np.array([z_sample])x_decoded = decoder.predict(z_out)digit = x_decoded[0].reshape(width, height, channels)figure[k * width: (k + 1) * width,l * height: (l + 1) * height] = digitplt.figure(figsize=(10, 10))
#Reshape for visualizationfig_shape = np.shape(figure)figure = figure.reshape((fig_shape[0], fig_shape[1],3))plt.imshow(figure, cmap='gnuplot2')plt.show()  plt.savefig("VAE_imagesgen_Trial5.png")

解码器图像重建
#IMAGE RECONSTRUCT USING TEST SET IMGS
def reconstruct():num_imgs = 6rand = np.random.randint(1, xtest.shape[0]-6) xtestsample = xtest[rand:rand+num_imgs]x_encoded = np.array(encoder.predict(xtestsample))latent_xtest=x_encoded[2]x_decoded = decoder.predict(latent_xtest)rows = 2 # defining no. of rows in figurecols = 3 # defining no. of colums in figurecell_size = 1.5f = plt.figure(figsize=(cell_size*cols,cell_size*rows*2)) # defining a figure f.tight_layout()for i in range(rows):for j in range(cols): f.add_subplot(rows*2,cols, (2*i*cols)+(j+1)) # adding sub plot to figure on each iterationplt.imshow(xtestsample[i*cols + j]) plt.axis("off")for j in range(cols): f.add_subplot(rows*2,cols,((2*i+1)*cols)+(j+1)) # adding sub plot to figure on each iterationplt.imshow(x_decoded[i*cols + j]) plt.axis("off")f.suptitle("Autoencoder Results - Cifar10",fontsize=18)plt.savefig("VAE_imagesrecons_Trial5.png")plt.show()

 

代码获取

已经附在文章底部,自行拿取。

项目开发,相关问题咨询,欢迎交流沟通。

这篇关于【信号处理】基于变分自编码器(VAE)的图片典型增强方法实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/874052

相关文章

C++对象布局及多态实现探索之内存布局(整理的很多链接)

本文通过观察对象的内存布局,跟踪函数调用的汇编代码。分析了C++对象内存的布局情况,虚函数的执行方式,以及虚继承,等等 文章链接:http://dev.yesky.com/254/2191254.shtml      论C/C++函数间动态内存的传递 (2005-07-30)   当你涉及到C/C++的核心编程的时候,你会无止境地与内存管理打交道。 文章链接:http://dev.yesky

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

电脑不小心删除的文件怎么恢复?4个必备恢复方法!

“刚刚在对电脑里的某些垃圾文件进行清理时,我一不小心误删了比较重要的数据。这些误删的数据还有机会恢复吗?希望大家帮帮我,非常感谢!” 在这个数字化飞速发展的时代,电脑早已成为我们日常生活和工作中不可或缺的一部分。然而,就像生活中的小插曲一样,有时我们可能会在不经意间犯下一些小错误,比如不小心删除了重要的文件。 当那份文件消失在眼前,仿佛被时间吞噬,我们不禁会心生焦虑。但别担心,就像每个问题

JAVA读取MongoDB中的二进制图片并显示在页面上

1:Jsp页面: <td><img src="${ctx}/mongoImg/show"></td> 2:xml配置: <?xml version="1.0" encoding="UTF-8"?><beans xmlns="http://www.springframework.org/schema/beans"xmlns:xsi="http://www.w3.org/2001

通过SSH隧道实现通过远程服务器上外网

搭建隧道 autossh -M 0 -f -D 1080 -C -N user1@remotehost##验证隧道是否生效,查看1080端口是否启动netstat -tuln | grep 1080## 测试ssh 隧道是否生效curl -x socks5h://127.0.0.1:1080 -I http://www.github.com 将autossh 设置为服务,隧道开机启动

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测

时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测 目录 时序预测 | MATLAB实现LSTM时间序列未来多步预测-递归预测基本介绍程序设计参考资料 基本介绍 MATLAB实现LSTM时间序列未来多步预测-递归预测。LSTM是一种含有LSTM区块(blocks)或其他的一种类神经网络,文献或其他资料中LSTM区块可能被描述成智能网络单元,因为

vue项目集成CanvasEditor实现Word在线编辑器

CanvasEditor实现Word在线编辑器 官网文档:https://hufe.club/canvas-editor-docs/guide/schema.html 源码地址:https://github.com/Hufe921/canvas-editor 前提声明: 由于CanvasEditor目前不支持vue、react 等框架开箱即用版,所以需要我们去Git下载源码,拿到其中两个主

android一键分享功能部分实现

为什么叫做部分实现呢,其实是我只实现一部分的分享。如新浪微博,那还有没去实现的是微信分享。还有一部分奇怪的问题:我QQ分享跟QQ空间的分享功能,我都没配置key那些都是原本集成就有的key也可以实现分享,谁清楚的麻烦详解下。 实现分享功能我们可以去www.mob.com这个网站集成。免费的,而且还有短信验证功能。等这分享研究完后就研究下短信验证功能。 开始实现步骤(新浪分享,以下是本人自己实现