简单易上手的生成对抗网络

2024-09-02 06:12

本文主要是介绍简单易上手的生成对抗网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

模型原理

生成对抗网络是指一类采用对抗训练方式进行学习的深度生成模型,包含的判别网络生成网络都可以根据不同的生成任务使用不同的网络结构。

生成器: 通过机器生成数据,最终目的是骗过判别器。
判别器: 判断这张图像是真实的还是机器生成的,目的是找出生成器做的假数据。

构建GAN模型的基本逻辑: 现实问题需求→建立实现功能的GAN框架(编程)→训练GAN(生成网络、对抗网络)→成熟的GAN模型→应用。

GAN训练过程:
生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。

GAN的发展历程

  1. GAN的基本思想起源于2014年,由伊恩·古德费洛等人首次提出。
  2. DCGAN,它在生成器和判别器中都使用了卷积层,取得了更好的图像生成效果。
  3. ConditionalGAN,通过引入条件信息指导生成器生成特定类型的数据。 Wasserstein
  4. WGAN使用Wasserstein距离作为损失函数,为GAN的训练提供了更稳定的优化方法,提高了生成样本的质量。

代码实现

DCGAN模型:

generator = Sequential()
generator.add(Dense(7 * 7 * 128, input_shape=[100]))
generator.add(Reshape([7, 7, 128]))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",activation="relu"))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",activation="tanh"))discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, padding="same",activation=LeakyReLU(0.3),input_shape=[28, 28, 1]))
discriminator.add(Dropout(0.5))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same",activation=LeakyReLU(0.3)))
discriminator.add(Dropout(0.5))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))

模型训练:

GAN =Sequential([generator,discriminator])
discriminator.compile(optimizer='adam',loss='binary_crossentropy')
discriminator.trainable = FalseGAN.compile(optimizer='adam',loss='binary_crossentropy')epochs = 150 
batch_size = 100
noise_shape=100with tf.device('/gpu:0'):for epoch in range(epochs):print(f"Currently on Epoch {epoch+1}")for i in range(X_train.shape[0]//batch_size):if (i+1)%50 == 0:print(f"\tCurrently on batch number {i+1} of {X_train.shape[0]//batch_size}")noise=np.random.normal(size=[batch_size,noise_shape])gen_image = generator.predict_on_batch(noise)train_dataset = X_train[i*batch_size:(i+1)*batch_size]train_label=np.ones(shape=(batch_size,1))discriminator.trainable = Trued_loss_real=discriminator.train_on_batch(train_dataset,train_label)train_label=np.zeros(shape=(batch_size,1))d_loss_fake=discriminator.train_on_batch(gen_image,train_label)noise=np.random.normal(size=[batch_size,noise_shape])train_label=np.ones(shape=(batch_size,1))discriminator.trainable = False #while training the generator as combined model,discriminator training should be turned offd_g_loss_batch =GAN.train_on_batch(noise, train_label)if epoch % 10 == 0:samples = 10x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, 100)))for k in range(samples):plt.subplot(2, 5, k+1)plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')plt.xticks([])plt.yticks([])plt.tight_layout()plt.show()print('Training is complete')

使用np.random.normal生成的噪声被作为输入给发生器:

noise=np.random.normal(loc=0, scale=1, size=(100,noise_shape))
gen_image = generator.predict(noise)
plt.imshow(noise)
plt.title('DCGAN Noise')

这篇关于简单易上手的生成对抗网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1129214

相关文章

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

基于Python实现一个简单的题库与在线考试系统

《基于Python实现一个简单的题库与在线考试系统》在当今信息化教育时代,在线学习与考试系统已成为教育技术领域的重要组成部分,本文就来介绍一下如何使用Python和PyQt5框架开发一个名为白泽题库系... 目录概述功能特点界面展示系统架构设计类结构图Excel题库填写格式模板题库题目填写格式表核心数据结构

Python实现自动化Word文档样式复制与内容生成

《Python实现自动化Word文档样式复制与内容生成》在办公自动化领域,高效处理Word文档的样式和内容复制是一个常见需求,本文将展示如何利用Python的python-docx库实现... 目录一、为什么需要自动化 Word 文档处理二、核心功能实现:样式与表格的深度复制1. 表格复制(含样式与内容)2

python如何生成指定文件大小

《python如何生成指定文件大小》:本文主要介绍python如何生成指定文件大小的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录python生成指定文件大小方法一(速度最快)方法二(中等速度)方法三(生成可读文本文件–较慢)方法四(使用内存映射高效生成

C/C++ chrono简单使用场景示例详解

《C/C++chrono简单使用场景示例详解》:本文主要介绍C/C++chrono简单使用场景示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友... 目录chrono使用场景举例1 输出格式化字符串chrono使用场景China编程举例1 输出格式化字符串示

Maven项目中集成数据库文档生成工具的操作步骤

《Maven项目中集成数据库文档生成工具的操作步骤》在Maven项目中,可以通过集成数据库文档生成工具来自动生成数据库文档,本文为大家整理了使用screw-maven-plugin(推荐)的完... 目录1. 添加插件配置到 pom.XML2. 配置数据库信息3. 执行生成命令4. 高级配置选项5. 注意事

MybatisX快速生成增删改查的方法示例

《MybatisX快速生成增删改查的方法示例》MybatisX是基于IDEA的MyBatis/MyBatis-Plus开发插件,本文主要介绍了MybatisX快速生成增删改查的方法示例,文中通过示例代... 目录1 安装2 基本功能2.1 XML跳转2.2 代码生成2.2.1 生成.xml中的sql语句头2

Linux网络配置之网桥和虚拟网络的配置指南

《Linux网络配置之网桥和虚拟网络的配置指南》这篇文章主要为大家详细介绍了Linux中配置网桥和虚拟网络的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 一、网桥的配置在linux系统中配置一个新的网桥主要涉及以下几个步骤:1.为yum仓库做准备,安装组件epel-re

windows和Linux安装Jmeter与简单使用方式

《windows和Linux安装Jmeter与简单使用方式》:本文主要介绍windows和Linux安装Jmeter与简单使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录Windows和linux安装Jmeter与简单使用一、下载安装包二、JDK安装1.windows设

python如何下载网络文件到本地指定文件夹

《python如何下载网络文件到本地指定文件夹》这篇文章主要为大家详细介绍了python如何实现下载网络文件到本地指定文件夹,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下...  在python中下载文件到本地指定文件夹可以通过以下步骤实现,使用requests库处理HTTP请求,并结合o