好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体

2023-11-02 02:11

本文主要是介绍好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体

  • 学习前言
  • 什么是GAN
  • 神经网络构建
    • 1、Generator
    • 2、Discriminator
  • 训练思路
  • 实现全部代码:

学习前言

我又死了我又死了我又死了!
在这里插入图片描述

什么是GAN

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

在GAN模型中,一般存在两个模块
分别是生成模型(Generative Model)和判别模型(Discriminative Model);二者的互相博弈与学习会产生相当好的输出

原始 GAN 理论中,并不要求生成模型和判别模型都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为生成模型和判别模型 。

一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想

其实简单来讲,一般情况下,GAN就是创建两个神经网络,一个是生成模型,一个是判别模型

生成模型的输入一行正态分布随机数,输出可以被认为是一张图片(或者其它需要被判定真伪的东西)。
判别模型的输入一张图片(或者其它需要被判定真伪的东西),输出是输入进来的图片是否是真实的(0或者1)。

生成模型不断训练的目的是生成 让判别模型无法判断真伪的输出。
判别模型不断训练的的目的是判断出输入图片的真伪
在这里插入图片描述

神经网络构建

1、Generator

生成网络的目标是输入一行正态分布随机数,生成mnist手写体图片,因此它的输入是一个长度为N的一维的向量,输出一个28,28,1维的图片。

def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)

2、Discriminator

判别模型的目的是根据输入的图片判断出真伪。因此它的输入一个28,28,1维的图片,输出是0到1之间的数,1代表判断这个图片是真的,0代表判断这个图片是假的。

def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

训练思路

GAN的训练分为如下几个步骤:
1、随机选取batch_size个真实的图片。
2、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。
4、将虚假图片的Discriminator预测结果与1的对比作为loss对Generator进行训练(与1对比的意思是,如果Discriminator将虚假图片判断为1,说明这个生成的图片很“真实”)。

实现全部代码:

from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adamimport matplotlib.pyplot as pltimport sys
import os
import numpy as npclass GAN():def __init__(self):# --------------------------------- ##   行28,列28,也就是mnist的shape# --------------------------------- #self.img_rows = 28self.img_cols = 28self.channels = 1# 28,28,1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100# adam优化器optimizer = Adam(0.0002, 0.5)self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])self.generator = self.build_generator()gan_input = Input(shape=(self.latent_dim,))img = self.generator(gan_input)# 在训练generate的时候不训练discriminatorself.discriminator.trainable = False# 对生成的假图片进行预测validity = self.discriminator(img)self.combined = Model(gan_input, validity)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, sample_interval=50):# 获得数据(X_train, _), (_, _) = mnist.load_data()# 进行标准化X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# 创建标签valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# --------------------------- ##   随机选取batch_size个图片#   对discriminator进行训练# --------------------------- #idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# --------------------------- ##  训练generator# --------------------------- #noise = np.random.normal(0, 1, (batch_size, self.latent_dim))g_loss = self.combined.train_on_batch(noise, valid)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = GAN()gan.train(epochs=30000, batch_size=256, sample_interval=200)

实现效果为:
在这里插入图片描述

这篇关于好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/327385

相关文章

基于Qt开发一个简单的OFD阅读器

《基于Qt开发一个简单的OFD阅读器》这篇文章主要为大家详细介绍了如何使用Qt框架开发一个功能强大且性能优异的OFD阅读器,文中的示例代码讲解详细,有需要的小伙伴可以参考一下... 目录摘要引言一、OFD文件格式解析二、文档结构解析三、页面渲染四、用户交互五、性能优化六、示例代码七、未来发展方向八、结论摘要

Mycat搭建分库分表方式

《Mycat搭建分库分表方式》文章介绍了如何使用分库分表架构来解决单表数据量过大带来的性能和存储容量限制的问题,通过在一对主从复制节点上配置数据源,并使用分片算法将数据分配到不同的数据库表中,可以有效... 目录分库分表解决的问题分库分表架构添加数据验证结果 总结分库分表解决的问题单表数据量过大带来的性能

Java汇编源码如何查看环境搭建

《Java汇编源码如何查看环境搭建》:本文主要介绍如何在IntelliJIDEA开发环境中搭建字节码和汇编环境,以便更好地进行代码调优和JVM学习,首先,介绍了如何配置IntelliJIDEA以方... 目录一、简介二、在IDEA开发环境中搭建汇编环境2.1 在IDEA中搭建字节码查看环境2.1.1 搭建步

Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)

《Python基于火山引擎豆包大模型搭建QQ机器人详细教程(2024年最新)》:本文主要介绍Python基于火山引擎豆包大模型搭建QQ机器人详细的相关资料,包括开通模型、配置APIKEY鉴权和SD... 目录豆包大模型概述开通模型付费安装 SDK 环境配置 API KEY 鉴权Ark 模型接口Prompt

详解Java中如何使用JFreeChart生成甘特图

《详解Java中如何使用JFreeChart生成甘特图》甘特图是一种流行的项目管理工具,用于显示项目的进度和任务分配,在Java开发中,JFreeChart是一个强大的开源图表库,能够生成各种类型的图... 目录引言一、JFreeChart简介二、准备工作三、创建甘特图1. 定义数据集2. 创建甘特图3.

MyBatis框架实现一个简单的数据查询操作

《MyBatis框架实现一个简单的数据查询操作》本文介绍了MyBatis框架下进行数据查询操作的详细步骤,括创建实体类、编写SQL标签、配置Mapper、开启驼峰命名映射以及执行SQL语句等,感兴趣的... 基于在前面几章我们已经学习了对MyBATis进行环境配置,并利用SqlSessionFactory核

鸿蒙开发搭建flutter适配的开发环境

《鸿蒙开发搭建flutter适配的开发环境》文章详细介绍了在Windows系统上如何创建和运行鸿蒙Flutter项目,包括使用flutterdoctor检测环境、创建项目、编译HAP包以及在真机上运... 目录环境搭建创建运行项目打包项目总结环境搭建1.安装 DevEco Studio NEXT IDE

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

hdu2289(简单二分)

虽说是简单二分,但是我还是wa死了  题意:已知圆台的体积,求高度 首先要知道圆台体积怎么求:设上下底的半径分别为r1,r2,高为h,V = PI*(r1*r1+r1*r2+r2*r2)*h/3 然后以h进行二分 代码如下: #include<iostream>#include<algorithm>#include<cstring>#include<stack>#includ