本文主要是介绍好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体
- 学习前言
- 什么是GAN
- 神经网络构建
- 1、Generator
- 2、Discriminator
- 训练思路
- 实现全部代码:
学习前言
我又死了我又死了我又死了!
什么是GAN
生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。
在GAN模型中,一般存在两个模块:
分别是生成模型(Generative Model)和判别模型(Discriminative Model);二者的互相博弈与学习将会产生相当好的输出。
原始 GAN 理论中,并不要求生成模型和判别模型都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为生成模型和判别模型 。
一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
其实简单来讲,一般情况下,GAN就是创建两个神经网络,一个是生成模型,一个是判别模型。
生成模型的输入是一行正态分布随机数,输出可以被认为是一张图片(或者其它需要被判定真伪的东西)。
判别模型的输入是一张图片(或者其它需要被判定真伪的东西),输出是输入进来的图片是否是真实的(0或者1)。
生成模型不断训练的目的是生成 让判别模型无法判断真伪的输出。
判别模型不断训练的的目的是判断出输入图片的真伪。
神经网络构建
1、Generator
生成网络的目标是输入一行正态分布随机数,生成mnist手写体图片,因此它的输入是一个长度为N的一维的向量,输出一个28,28,1维的图片。
def build_generator(self):# --------------------------------- ## 生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)
2、Discriminator
判别模型的目的是根据输入的图片判断出真伪。因此它的输入一个28,28,1维的图片,输出是0到1之间的数,1代表判断这个图片是真的,0代表判断这个图片是假的。
def build_discriminator(self):# ----------------------------------- ## 评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)
训练思路
GAN的训练分为如下几个步骤:
1、随机选取batch_size个真实的图片。
2、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。
4、将虚假图片的Discriminator预测结果与1的对比作为loss对Generator进行训练(与1对比的意思是,如果Discriminator将虚假图片判断为1,说明这个生成的图片很“真实”)。
实现全部代码:
from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adamimport matplotlib.pyplot as pltimport sys
import os
import numpy as npclass GAN():def __init__(self):# --------------------------------- ## 行28,列28,也就是mnist的shape# --------------------------------- #self.img_rows = 28self.img_cols = 28self.channels = 1# 28,28,1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100# adam优化器optimizer = Adam(0.0002, 0.5)self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])self.generator = self.build_generator()gan_input = Input(shape=(self.latent_dim,))img = self.generator(gan_input)# 在训练generate的时候不训练discriminatorself.discriminator.trainable = False# 对生成的假图片进行预测validity = self.discriminator(img)self.combined = Model(gan_input, validity)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):# --------------------------------- ## 生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):# ----------------------------------- ## 评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, sample_interval=50):# 获得数据(X_train, _), (_, _) = mnist.load_data()# 进行标准化X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# 创建标签valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# --------------------------- ## 随机选取batch_size个图片# 对discriminator进行训练# --------------------------- #idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# --------------------------- ## 训练generator# --------------------------- #noise = np.random.normal(0, 1, (batch_size, self.latent_dim))g_loss = self.combined.train_on_batch(noise, valid)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = GAN()gan.train(epochs=30000, batch_size=256, sample_interval=200)
实现效果为:
这篇关于好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!