好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体

2023-11-02 02:11

本文主要是介绍好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体

  • 学习前言
  • 什么是GAN
  • 神经网络构建
    • 1、Generator
    • 2、Discriminator
  • 训练思路
  • 实现全部代码:

学习前言

我又死了我又死了我又死了!
在这里插入图片描述

什么是GAN

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

在GAN模型中,一般存在两个模块
分别是生成模型(Generative Model)和判别模型(Discriminative Model);二者的互相博弈与学习会产生相当好的输出

原始 GAN 理论中,并不要求生成模型和判别模型都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为生成模型和判别模型 。

一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想

其实简单来讲,一般情况下,GAN就是创建两个神经网络,一个是生成模型,一个是判别模型

生成模型的输入一行正态分布随机数,输出可以被认为是一张图片(或者其它需要被判定真伪的东西)。
判别模型的输入一张图片(或者其它需要被判定真伪的东西),输出是输入进来的图片是否是真实的(0或者1)。

生成模型不断训练的目的是生成 让判别模型无法判断真伪的输出。
判别模型不断训练的的目的是判断出输入图片的真伪
在这里插入图片描述

神经网络构建

1、Generator

生成网络的目标是输入一行正态分布随机数,生成mnist手写体图片,因此它的输入是一个长度为N的一维的向量,输出一个28,28,1维的图片。

def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)

2、Discriminator

判别模型的目的是根据输入的图片判断出真伪。因此它的输入一个28,28,1维的图片,输出是0到1之间的数,1代表判断这个图片是真的,0代表判断这个图片是假的。

def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)

训练思路

GAN的训练分为如下几个步骤:
1、随机选取batch_size个真实的图片。
2、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。
4、将虚假图片的Discriminator预测结果与1的对比作为loss对Generator进行训练(与1对比的意思是,如果Discriminator将虚假图片判断为1,说明这个生成的图片很“真实”)。

实现全部代码:

from __future__ import print_function, divisionfrom keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adamimport matplotlib.pyplot as pltimport sys
import os
import numpy as npclass GAN():def __init__(self):# --------------------------------- ##   行28,列28,也就是mnist的shape# --------------------------------- #self.img_rows = 28self.img_cols = 28self.channels = 1# 28,28,1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100# adam优化器optimizer = Adam(0.0002, 0.5)self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])self.generator = self.build_generator()gan_input = Input(shape=(self.latent_dim,))img = self.generator(gan_input)# 在训练generate的时候不训练discriminatorself.discriminator.trainable = False# 对生成的假图片进行预测validity = self.discriminator(img)self.combined = Model(gan_input, validity)self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1, activation='sigmoid'))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, sample_interval=50):# 获得数据(X_train, _), (_, _) = mnist.load_data()# 进行标准化X_train = X_train / 127.5 - 1.X_train = np.expand_dims(X_train, axis=3)# 创建标签valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# --------------------------- ##   随机选取batch_size个图片#   对discriminator进行训练# --------------------------- #idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# --------------------------- ##  训练generator# --------------------------- #noise = np.random.normal(0, 1, (batch_size, self.latent_dim))g_loss = self.combined.train_on_batch(noise, valid)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = GAN()gan.train(epochs=30000, batch_size=256, sample_interval=200)

实现效果为:
在这里插入图片描述

这篇关于好像还挺好玩的GAN1——Keras搭建简单GAN生成MNIST手写体的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/327385

相关文章

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

hdu2289(简单二分)

虽说是简单二分,但是我还是wa死了  题意:已知圆台的体积,求高度 首先要知道圆台体积怎么求:设上下底的半径分别为r1,r2,高为h,V = PI*(r1*r1+r1*r2+r2*r2)*h/3 然后以h进行二分 代码如下: #include<iostream>#include<algorithm>#include<cstring>#include<stack>#includ

usaco 1.3 Prime Cryptarithm(简单哈希表暴搜剪枝)

思路: 1. 用一个 hash[ ] 数组存放输入的数字,令 hash[ tmp ]=1 。 2. 一个自定义函数 check( ) ,检查各位是否为输入的数字。 3. 暴搜。第一行数从 100到999,第二行数从 10到99。 4. 剪枝。 代码: /*ID: who jayLANG: C++TASK: crypt1*/#include<stdio.h>bool h

pdfmake生成pdf的使用

实际项目中有时会有根据填写的表单数据或者其他格式的数据,将数据自动填充到pdf文件中根据固定模板生成pdf文件的需求 文章目录 利用pdfmake生成pdf文件1.下载安装pdfmake第三方包2.封装生成pdf文件的共用配置3.生成pdf文件的文件模板内容4.调用方法生成pdf 利用pdfmake生成pdf文件 1.下载安装pdfmake第三方包 npm i pdfma

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

poj 1287 Networking(prim or kruscal最小生成树)

题意给你点与点间距离,求最小生成树。 注意点是,两点之间可能有不同的路,输入的时候选择最小的,和之前有道最短路WA的题目类似。 prim代码: #include<stdio.h>const int MaxN = 51;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int P;int prim(){bool vis[MaxN];

poj 2349 Arctic Network uva 10369(prim or kruscal最小生成树)

题目很麻烦,因为不熟悉最小生成树的算法调试了好久。 感觉网上的题目解释都没说得很清楚,不适合新手。自己写一个。 题意:给你点的坐标,然后两点间可以有两种方式来通信:第一种是卫星通信,第二种是无线电通信。 卫星通信:任何两个有卫星频道的点间都可以直接建立连接,与点间的距离无关; 无线电通信:两个点之间的距离不能超过D,无线电收发器的功率越大,D越大,越昂贵。 计算无线电收发器D

搭建Kafka+zookeeper集群调度

前言 硬件环境 172.18.0.5        kafkazk1        Kafka+zookeeper                Kafka Broker集群 172.18.0.6        kafkazk2        Kafka+zookeeper                Kafka Broker集群 172.18.0.7        kafkazk3

uva 10387 Billiard(简单几何)

题意是一个球从矩形的中点出发,告诉你小球与矩形两条边的碰撞次数与小球回到原点的时间,求小球出发时的角度和小球的速度。 简单的几何问题,小球每与竖边碰撞一次,向右扩展一个相同的矩形;每与横边碰撞一次,向上扩展一个相同的矩形。 可以发现,扩展矩形的路径和在当前矩形中的每一段路径相同,当小球回到出发点时,一条直线的路径刚好经过最后一个扩展矩形的中心点。 最后扩展的路径和横边竖边恰好组成一个直