训练生成手写体数字 对抗神经网络

2024-01-02 01:44

本文主要是介绍训练生成手写体数字 对抗神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

下面是一个使用TensorFlow和Keras的生成对抗网络(GAN)的基本示例,用于生成手写体数字。这个示例基于MNIST数据集。

 

我没有包括所有可能的最佳实践,如模型保存、加载、超参数调整、日志记录等。

首先,确保你安装了所需的库,特别是TensorFlow:

pip install tensorflow

接下来是GAN的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0 * 2 - 1  # 将像素值缩放到[-1, 1]# GAN参数
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100# 生成器
def build_generator():model = Sequential()model.add(Dense(256, input_dim=latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(img_shape), activation='tanh'))model.add(Reshape(img_shape))return model# 判别器
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))return model# 编译判别器
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])# 编译生成器
generator = build_generator()# 生成器输入噪声并生成图像
z = tf.keras.Input(shape=(latent_dim,))
img = generator(z)# 对于组合模型,我们只训练生成器
discriminator.trainable = False# 判别器尝试区分真实和生成的图像
valid = discriminator(img)# 组合模型(叠加生成器和判别器)
combined = tf.keras.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
def train(epochs, batch_size=128, save_interval=50):half_batch = batch_size // 2for epoch in range(epochs):# 随机选择一半的图像idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成一半的新图像noise = np.random.normal(0, 1, (half_batch, latent_dim))gen_imgs = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))valid_y = np.array([1] * batch_size)g_loss = combined.train_on_batch(noise, valid_y)# 打印进度print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")# 如果到了保存间隔,则保存生成的图像样本if epoch % save_interval == 0:save_imgs(epoch)# 保存图像
def save_imgs(epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, latent_dim))gen_imgs = generator.predict(noise)# 缩放图片从[-1, 1]到[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig(f"mnist_{epoch}.png")plt.close()# 训练GAN
train(epochs=10000, batch_size=32, save_interval=1000)

在这个代码中,首先定义了生成器和判别器的架构,然后将它们结合起来形成一个GAN网络。

train函数负责训练过程,它交替地训练判别器和生成器。save_imgs函数用于保存生成的图像,以便我们可以查看GAN在训练过程中的进步。

这是一个非常基础的GAN实现,对于实际应用,你可能需要进行很多调整和优化,包括更复杂的模型架构、更细致的训练过程控制、超参数调整等。

逐行注释的代码,解释了每一步的作用:

# 导入所需的库
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()# 数据预处理,将像素值归一化到[-1, 1]的范围
X_train = X_train / 127.5 - 1.0
# 增加一个维度,使图像有单通道,即形状为(batch_size, 28, 28, 1)
X_train = np.expand_dims(X_train, axis=-1)# 定义生成器模型
def build_generator():model = Sequential()model.add(Dense(256, input_shape=(100,)))  # 输入层,输入维度为100(噪声向量)model.add(LeakyReLU(alpha=0.2))  # 使用LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(1024))  # 第三层,1024个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))  # 输出层,输出与图像像素数相同的单元数model.add(Reshape((28, 28, 1)))  # 将输出重塑为28x28图像return model# 定义判别器模型
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=(28, 28, 1)))  # 输入层,将28x28图像展平model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(256))  # 第三层,256个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(1, activation='sigmoid'))  # 输出层,一个单元输出0到1之间的值return model# 编译判别器和生成器
discriminator = build_discriminator()
# 使用二元交叉熵作为损失函数,Adam优化器,以及准确度评估
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
generator = build_generator()# GAN模型组合
z = tf.keras.Input(shape=(100,))  # 输入层,100维噪声向量
img = generator(z)  # 生成器生成图像
discriminator.trainable = False  # 在训练生成器时冻结判别器的权重
valid = discriminator(img)  # 判别器对生成的图像进行评估
combined = tf.keras.Model(z, valid)  # 组合模型,输入是噪声,输出是判别器的评估结果
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
epochs = 10000  # 训练轮数
batch_size = 32  # 批量大小
save_interval = 1000  # 保存图片的间隔
noise_dim = 100  # 噪声向量的维度
half_batch = batch_size // 2  # 半批量大小
valid = np.ones((half_batch, 1))  # 真实图片标签
fake = np.zeros((half_batch, 1))  # 伪造图片标签for epoch in range(epochs):# 随机选择真实图片idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成噪声noise = np.random.normal(0, 1, (half_batch, noise_dim))# 使用噪声生成伪造图片gen_imgs = generator(noise, training=False)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, valid)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 生成更多噪声noise = np.random.normal(0, 1, (batch_size, noise_dim))# 训练生成器g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))# 如果达到保存间隔,打印损失并保存生成的图片if epoch % save_interval == 0:print("Epoch {}/{} [D loss: {:.4f}, acc.: {:.2f}%] [G loss: {:.4f}]".format(epoch, epochs, d_loss[0], 100 * d_loss[1], g_loss))save_imgs(generator, epoch, noise_dim)# 定义函数以保存生成的手写数字图像
def save_imgs(generator, epoch, noise_dim):r, c = 5, 5  # 生成5x5网格的图片noise = np.random.normal(0, 1, (r * c, noise_dim))  # 生成噪声gen_imgs = generator(noise, training=False)  # 使用噪声生成图片gen_imgs = 0.5 * gen_imgs + 0.5  # 将图片的像素值从[-1, 1]缩放到[0, 1]fig, axs = plt.subplots(r, c)  # 创建子图cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')  # 显示生成的图片axs[i, j].axis('off')  # 关闭坐标轴cnt += 1fig.savefig("mnist_%d.png" % epoch)  # 保存生成的图片plt.close()  # 关闭图形显示窗口# 选择性地保存生成器模型
generator.save('mnist_generator.h5')

这样的注释有助于理解代码的每一步,特别是对于初学者来说,可以更好地理解GAN的工作原理和实现细节。

版权所有 © 2023 王一帆。除非另有说明,本作品采用[知识共享 署名-非衍生作品 4.0 国际许可协议](https://creativecommons.org/licenses/by-nd/4.0/)进行许可。

这篇关于训练生成手写体数字 对抗神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561028

相关文章

Python实现特殊字符判断并去掉非字母和数字的特殊字符

《Python实现特殊字符判断并去掉非字母和数字的特殊字符》在Python中,可以通过多种方法来判断字符串中是否包含非字母、数字的特殊字符,并将这些特殊字符去掉,本文为大家整理了一些常用的,希望对大家... 目录1. 使用正则表达式判断字符串中是否包含特殊字符去掉字符串中的特殊字符2. 使用 str.isa

IDEA自动生成注释模板的配置教程

《IDEA自动生成注释模板的配置教程》本文介绍了如何在IntelliJIDEA中配置类和方法的注释模板,包括自动生成项目名称、包名、日期和时间等内容,以及如何定制参数和返回值的注释格式,需要的朋友可以... 目录项目场景配置方法类注释模板定义类开头的注释步骤类注释效果方法注释模板定义方法开头的注释步骤方法注

Python如何自动生成环境依赖包requirements

《Python如何自动生成环境依赖包requirements》:本文主要介绍Python如何自动生成环境依赖包requirements问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录生成当前 python 环境 安装的所有依赖包1、命令2、常见问题只生成当前 项目 的所有依赖包1、

MySQL中动态生成SQL语句去掉所有字段的空格的操作方法

《MySQL中动态生成SQL语句去掉所有字段的空格的操作方法》在数据库管理过程中,我们常常会遇到需要对表中字段进行清洗和整理的情况,本文将详细介绍如何在MySQL中动态生成SQL语句来去掉所有字段的空... 目录在mysql中动态生成SQL语句去掉所有字段的空格准备工作原理分析动态生成SQL语句在MySQL

Java利用docx4j+Freemarker生成word文档

《Java利用docx4j+Freemarker生成word文档》这篇文章主要为大家详细介绍了Java如何利用docx4j+Freemarker生成word文档,文中的示例代码讲解详细,感兴趣的小伙伴... 目录技术方案maven依赖创建模板文件实现代码技术方案Java 1.8 + docx4j + Fr

Java编译生成多个.class文件的原理和作用

《Java编译生成多个.class文件的原理和作用》作为一名经验丰富的开发者,在Java项目中执行编译后,可能会发现一个.java源文件有时会产生多个.class文件,从技术实现层面详细剖析这一现象... 目录一、内部类机制与.class文件生成成员内部类(常规内部类)局部内部类(方法内部类)匿名内部类二、

使用Jackson进行JSON生成与解析的新手指南

《使用Jackson进行JSON生成与解析的新手指南》这篇文章主要为大家详细介绍了如何使用Jackson进行JSON生成与解析处理,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1. 核心依赖2. 基础用法2.1 对象转 jsON(序列化)2.2 JSON 转对象(反序列化)3.

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

在java中如何将inputStream对象转换为File对象(不生成本地文件)

《在java中如何将inputStream对象转换为File对象(不生成本地文件)》:本文主要介绍在java中如何将inputStream对象转换为File对象(不生成本地文件),具有很好的参考价... 目录需求说明问题解决总结需求说明在后端中通过POI生成Excel文件流,将输出流(outputStre

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的