训练生成手写体数字 对抗神经网络

2024-01-02 01:44

本文主要是介绍训练生成手写体数字 对抗神经网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

下面是一个使用TensorFlow和Keras的生成对抗网络(GAN)的基本示例,用于生成手写体数字。这个示例基于MNIST数据集。

 

我没有包括所有可能的最佳实践,如模型保存、加载、超参数调整、日志记录等。

首先,确保你安装了所需的库,特别是TensorFlow:

pip install tensorflow

接下来是GAN的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0 * 2 - 1  # 将像素值缩放到[-1, 1]# GAN参数
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100# 生成器
def build_generator():model = Sequential()model.add(Dense(256, input_dim=latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(img_shape), activation='tanh'))model.add(Reshape(img_shape))return model# 判别器
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))model.add(Dense(1, activation='sigmoid'))return model# 编译判别器
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])# 编译生成器
generator = build_generator()# 生成器输入噪声并生成图像
z = tf.keras.Input(shape=(latent_dim,))
img = generator(z)# 对于组合模型,我们只训练生成器
discriminator.trainable = False# 判别器尝试区分真实和生成的图像
valid = discriminator(img)# 组合模型(叠加生成器和判别器)
combined = tf.keras.Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
def train(epochs, batch_size=128, save_interval=50):half_batch = batch_size // 2for epoch in range(epochs):# 随机选择一半的图像idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成一半的新图像noise = np.random.normal(0, 1, (half_batch, latent_dim))gen_imgs = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))valid_y = np.array([1] * batch_size)g_loss = combined.train_on_batch(noise, valid_y)# 打印进度print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")# 如果到了保存间隔,则保存生成的图像样本if epoch % save_interval == 0:save_imgs(epoch)# 保存图像
def save_imgs(epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, latent_dim))gen_imgs = generator.predict(noise)# 缩放图片从[-1, 1]到[0, 1]gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig(f"mnist_{epoch}.png")plt.close()# 训练GAN
train(epochs=10000, batch_size=32, save_interval=1000)

在这个代码中,首先定义了生成器和判别器的架构,然后将它们结合起来形成一个GAN网络。

train函数负责训练过程,它交替地训练判别器和生成器。save_imgs函数用于保存生成的图像,以便我们可以查看GAN在训练过程中的进步。

这是一个非常基础的GAN实现,对于实际应用,你可能需要进行很多调整和优化,包括更复杂的模型架构、更细致的训练过程控制、超参数调整等。

逐行注释的代码,解释了每一步的作用:

# 导入所需的库
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()# 数据预处理,将像素值归一化到[-1, 1]的范围
X_train = X_train / 127.5 - 1.0
# 增加一个维度,使图像有单通道,即形状为(batch_size, 28, 28, 1)
X_train = np.expand_dims(X_train, axis=-1)# 定义生成器模型
def build_generator():model = Sequential()model.add(Dense(256, input_shape=(100,)))  # 输入层,输入维度为100(噪声向量)model.add(LeakyReLU(alpha=0.2))  # 使用LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(1024))  # 第三层,1024个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(BatchNormalization(momentum=0.8))  # 批量归一化model.add(Dense(np.prod((28, 28, 1)), activation='tanh'))  # 输出层,输出与图像像素数相同的单元数model.add(Reshape((28, 28, 1)))  # 将输出重塑为28x28图像return model# 定义判别器模型
def build_discriminator():model = Sequential()model.add(Flatten(input_shape=(28, 28, 1)))  # 输入层,将28x28图像展平model.add(Dense(512))  # 第二层,512个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(256))  # 第三层,256个单元model.add(LeakyReLU(alpha=0.2))  # LeakyReLU激活函数model.add(Dense(1, activation='sigmoid'))  # 输出层,一个单元输出0到1之间的值return model# 编译判别器和生成器
discriminator = build_discriminator()
# 使用二元交叉熵作为损失函数,Adam优化器,以及准确度评估
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
generator = build_generator()# GAN模型组合
z = tf.keras.Input(shape=(100,))  # 输入层,100维噪声向量
img = generator(z)  # 生成器生成图像
discriminator.trainable = False  # 在训练生成器时冻结判别器的权重
valid = discriminator(img)  # 判别器对生成的图像进行评估
combined = tf.keras.Model(z, valid)  # 组合模型,输入是噪声,输出是判别器的评估结果
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))# 训练GAN
epochs = 10000  # 训练轮数
batch_size = 32  # 批量大小
save_interval = 1000  # 保存图片的间隔
noise_dim = 100  # 噪声向量的维度
half_batch = batch_size // 2  # 半批量大小
valid = np.ones((half_batch, 1))  # 真实图片标签
fake = np.zeros((half_batch, 1))  # 伪造图片标签for epoch in range(epochs):# 随机选择真实图片idx = np.random.randint(0, X_train.shape[0], half_batch)imgs = X_train[idx]# 生成噪声noise = np.random.normal(0, 1, (half_batch, noise_dim))# 使用噪声生成伪造图片gen_imgs = generator(noise, training=False)# 训练判别器d_loss_real = discriminator.train_on_batch(imgs, valid)d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 生成更多噪声noise = np.random.normal(0, 1, (batch_size, noise_dim))# 训练生成器g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))# 如果达到保存间隔,打印损失并保存生成的图片if epoch % save_interval == 0:print("Epoch {}/{} [D loss: {:.4f}, acc.: {:.2f}%] [G loss: {:.4f}]".format(epoch, epochs, d_loss[0], 100 * d_loss[1], g_loss))save_imgs(generator, epoch, noise_dim)# 定义函数以保存生成的手写数字图像
def save_imgs(generator, epoch, noise_dim):r, c = 5, 5  # 生成5x5网格的图片noise = np.random.normal(0, 1, (r * c, noise_dim))  # 生成噪声gen_imgs = generator(noise, training=False)  # 使用噪声生成图片gen_imgs = 0.5 * gen_imgs + 0.5  # 将图片的像素值从[-1, 1]缩放到[0, 1]fig, axs = plt.subplots(r, c)  # 创建子图cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')  # 显示生成的图片axs[i, j].axis('off')  # 关闭坐标轴cnt += 1fig.savefig("mnist_%d.png" % epoch)  # 保存生成的图片plt.close()  # 关闭图形显示窗口# 选择性地保存生成器模型
generator.save('mnist_generator.h5')

这样的注释有助于理解代码的每一步,特别是对于初学者来说,可以更好地理解GAN的工作原理和实现细节。

版权所有 © 2023 王一帆。除非另有说明,本作品采用[知识共享 署名-非衍生作品 4.0 国际许可协议](https://creativecommons.org/licenses/by-nd/4.0/)进行许可。

这篇关于训练生成手写体数字 对抗神经网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/561028

相关文章

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

在java中如何将inputStream对象转换为File对象(不生成本地文件)

《在java中如何将inputStream对象转换为File对象(不生成本地文件)》:本文主要介绍在java中如何将inputStream对象转换为File对象(不生成本地文件),具有很好的参考价... 目录需求说明问题解决总结需求说明在后端中通过POI生成Excel文件流,将输出流(outputStre

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

java字符串数字补齐位数详解

《java字符串数字补齐位数详解》:本文主要介绍java字符串数字补齐位数,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Java字符串数字补齐位数一、使用String.format()方法二、Apache Commons Lang库方法三、Java 11+的St

C/C++随机数生成的五种方法

《C/C++随机数生成的五种方法》C++作为一种古老的编程语言,其随机数生成的方法已经经历了多次的变革,早期的C++版本使用的是rand()函数和RAND_MAX常量,这种方法虽然简单,但并不总是提供... 目录C/C++ 随机数生成方法1. 使用 rand() 和 srand()2. 使用 <random

Flask 验证码自动生成的实现示例

《Flask验证码自动生成的实现示例》本文主要介绍了Flask验证码自动生成的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习... 目录生成图片以及结果处理验证码蓝图html页面展示想必验证码大家都有所了解,但是可以自己定义图片验证码

Python如何在Word中生成多种不同类型的图表

《Python如何在Word中生成多种不同类型的图表》Word文档中插入图表不仅能直观呈现数据,还能提升文档的可读性和专业性,本文将介绍如何使用Python在Word文档中创建和自定义各种图表,需要的... 目录在Word中创建柱形图在Word中创建条形图在Word中创建折线图在Word中创建饼图在Word

nginx生成自签名SSL证书配置HTTPS的实现

《nginx生成自签名SSL证书配置HTTPS的实现》本文主要介绍在Nginx中生成自签名SSL证书并配置HTTPS,包括安装Nginx、创建证书、配置证书以及测试访问,具有一定的参考价值,感兴趣的可... 目录一、安装nginx二、创建证书三、配置证书并验证四、测试一、安装nginxnginx必须有"-

Java实战之利用POI生成Excel图表

《Java实战之利用POI生成Excel图表》ApachePOI是Java生态中处理Office文档的核心工具,这篇文章主要为大家详细介绍了如何在Excel中创建折线图,柱状图,饼图等常见图表,需要的... 目录一、环境配置与依赖管理二、数据源准备与工作表构建三、图表生成核心步骤1. 折线图(Line Ch

浅析如何使用Swagger生成带权限控制的API文档

《浅析如何使用Swagger生成带权限控制的API文档》当涉及到权限控制时,如何生成既安全又详细的API文档就成了一个关键问题,所以这篇文章小编就来和大家好好聊聊如何用Swagger来生成带有... 目录准备工作配置 Swagger权限控制给 API 加上权限注解查看文档注意事项在咱们的开发工作里,API