本文主要是介绍G2 - 人脸图像生成(DCGAN),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目录
- 理论知识
- DCGAN原理
- 模型结构
- 逻辑结构
- 物理结构
- 模型实现
- 前期准备
- 1. 导入第三方库
- 2. 修改随机种子(相同的随机种子,第i次随机的结果是固定的)
- 3. 设置超参数
- 4. 导入数据
- 模型定义
- 1. 编写权重初始化函数
- 2. 定义生成器
- 3. 创建生成器
- 4. 定义判别器
- 5. 创建判别器
- 6. 定义训练参数
- 7. 模型训练
- 模型效果
- 打印训练过程图
- 打印训练过程保存的阶段图片
- 对比
- 总结与心得体会
理论知识
DCGAN(Deep Convolutional Generative Adversarial Networks,深度卷积生成对抗网络)是结合了卷积神经网络(Convlutional Nerual Networks,简称CNN)和生成对抗网络(Generative Adversarial Networks,简称GAN)的思想,用来生成图像。
DCGAN原理
DCGAN将卷积运算引入生成式模型,用来做无监督训练,可以使用卷积网络强大的特征提取能力来提高生成网络的学习效果。DCGAN具有以下特点:
- 判别器模型使用卷积步长取代空间池化,生成器模型中使用反卷积操作扩大数据维度
- 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其他层都使用了Batch Normalization(稍后的实验将证明它有多么的重要),Batch Normalization可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
- 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
- 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在判别器上全使用的LeakyReLU激活函数。
模型结构
逻辑结构
如图中所示,DCGAN模型主要包括了一个生成网络G和一个判别网络D
生成网络G负责生成图像,它接受一个随机的噪声z
,通过该噪声来生成图像,将生成的图像记为G(z)
。
判别网络D负责判断一张图像是否为真实的,它的输入为图像x
,输出为D(x)
表示x
为真实图像的概率。
实际上,判别网络D是对数据的来源进行判别:究竟这个数据是来自于真实的分布 P d ( x ) P_{d(x)} Pd(x)(判别为1),还是来自于一个生成网络G所产生的数据分布 P g ( z ) P_{g(z)} Pg(z)(判别为0)。所以在整个训练的过程中,生成网络G的目标是生成可以以假乱真的图像G(z)
,当判别网络D无法区分(也就是D(G(z)) = 0.5
时),便以为生成网络已经收敛,贴近了真实图像的分布。这时的生成网络G生成的图像,还可以用来生成图像的数据集的扩充数据集。
物理结构
在物理上生成网络有4个转置卷积层,对应的判别网络有4个卷积层。其中4*4*512
代表这一层共有512个大小为4*4
的特征图。BN和ReLU分别在卷积层之前和卷积层之后使用。Tanh和LeakyReLU分别表示双正切激活函数和弱修正线性激活函数。
模型实现
前期准备
环境:
Pytorch 2.2.0+cu121
GTX4090
Python 3.11
Jupyter lab
1. 导入第三方库
import torch, random, os
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
2. 修改随机种子(相同的随机种子,第i次随机的结果是固定的)
manual_seed = 999
print("Random seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)
torch.use_deterministic_algorithms(True)
输出
Random seed: 999
3. 设置超参数
# 数据的路径
data_root = 'GAN-Data/'
# 训练批次大小
batch_size = 128
# 图像尺寸
image_size = 64
# z的向量大小,生成器的输入尺寸
nz = 100
# 生成器的特征图大小
ngf = 100
# 判别器的特征图大小
ndf = 100
# 训练的总轮数
num_epochs = 50
# 学习率
lr = 0.0002
# Adam优化器超参数
beta1 = 0.5
4. 导入数据
# 创建数据集
dataset = datasets.ImageFolder(root=data_root,transform=transforms.Compose([transforms.Resize(image_size),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),])
)# 创建数据加载器
dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=5 # 使用多线程来处理数据
)# 选择要在哪个设备上运行代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('使用的设备是:', device)# 查看一些训练的图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title('Training Images')
plt.imshow(np.transpose(utils.make_grid(real_batch[0][:24], padding=2, normalize=True),(1, 2, 0)))
模型定义
1. 编写权重初始化函数
def weights_init(m):class_name = m.__class__.__name__if class_name.find('Conv') != -1:# 类名中有Conv,卷积或反卷积层,使用均值为0标准差为0.02的正态分布来初始化nn.init.normal_(m.weight.data, 0.0, 0.02)elif class_name.find('BatchNorm') != -1:# 类名中有BatchNorm 批归一化层,使用均值为1,标准差为0.02的正态分布初始化权重nn.init.normal_(m.weight.data, 1.0, 0.02)# 偏置初始化为0nn.init.constant_(m.bias.data, 0)
2. 定义生成器
class Generator(nn.Module):def __init__(self):super().__init__()self.main = nn.Sequential(# 转置卷积nn.ConvTranspose2d(nz, ngf*8, 4, 1, 0, bias=False),# 批归一化,用于加速收敛和稳定训练过程nn.BatchNorm2d(ngf*8),# 激活函数nn.ReLU(True),# 第一个单元处理完,模型输出(ngf*8)x4x4nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf*4),nn.ReLU(True),# 第二个单元处理完,模型输出(ngf*4)x8x8nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf*2),nn.ReLU(True),# 第三个单元处理完,模型输出(ngf*2)*16*16nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 第4个单元处理完,模型输出(ngf)*32*32nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),nn.Tanh()# 最后模型输出 3*64*64)def forward(self, x):return self.main(x)
3. 创建生成器
# 创建生成器模型
netG = Generator().to(device)
# 初始化生成器模型
netG.apply(weights_init)
# 打印生成器结构
print(netG)
4. 定义判别器
class Discriminator(nn.Module):def __init__(self):super().__init__()self.main = nn.Sequential(# 判别器的输入是生成器的输出 3x64x64-> ndfx32x32nn.Conv2d(3, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 第二层卷积 ndfx32x32 -> (ndf*2)x16x16nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf*2),nn.LeakyReLU(0.2, inplace=True),# 第三层卷积 (ndf*2)x16x16->(ndf*4)x8x8nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf*4),nn.LeakyReLU(0.2, inplace=True),# 第四层卷积 (ndf*4)x8x8 -> (ndf*8)x4x4nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf*8),nn.LeakyReLU(0.2, inplace=True),# 最后一层卷积 (ndf*8)x4x4 -> 1x1x1nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, inputs):return self.main(inputs)
5. 创建判别器
# 创建判别器模型
netD = Discriminator().to(device)
# 初始化判别器参数
netD.apply(weights_init)
# 打印判别器模型结构
print(netD)
6. 定义训练参数
# 定义损失函数
criterion = nn.BCELoss()
# 创建用于对比的固定随机输入
fixed_noise = torch.randn(8*8, nz, 1, 1, device=device)
# 标签
real_label = 1.
fake_label = 0.
# 创建优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
7. 模型训练
# 存储生成的图像
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training")
for epoch in range(num_epochs):for i, data in enumerate(dataloader):# 判别器部分# 数据中的图像和标签real_images = data[0].to(device)real_images_label = torch.full((real_images.size(0), ), real_label, dtype=torch.float, device=device)# 梯度清零optimizerD.zero_grad()optimizerG.zero_grad()# 正向计算real_D_output = netD(real_images).view(-1)real_D_loss = criterion(real_D_output, real_images_label)# 反向传播real_D_loss.backward()# 保存平均误差D_x = real_D_output.mean().item()# 生成图像和标签noise = torch.randn(real_images.size(0), nz, 1, 1, device=device)synthesis_images = netG(noise)synthesis_images_label = torch.full((real_images.size(0), ), fake_label, dtype=torch.float, device=device)# 正向计算synthesis_D_output = netD(synthesis_images).view(-1)synthesis_D_loss = criterion(synthesis_D_output, synthesis_images_label)# 反向传播synthesis_D_loss.backward(retain_graph=True)# 保存平均误差D_G_z1 = synthesis_D_output.mean().item()# 更新判别器参数optimizerD.step()# 判别器总误差errD = real_D_loss + synthesis_D_loss# 生成器部分# 期望标签expect_labels = torch.full((real_images.size(0),), real_label, dtype=torch.float, device=device)# 正向计算expect_output = netD(synthesis_images).view(-1)expect_loss = criterion(expect_output, expect_labels)# 反向传播expect_loss.backward()# 保存平均误差D_G_z2 = expect_output.mean().item()# 更新生成器参数optimizerG.step()# 生成器总误差errG = expect_lossif i % 400 == 0:print(f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]"f"\tLoss_D: {errD.item():.4f}\tLoss_G: {errG.item():.4f}"f"\tD(x): {D_x:.4f}\tD(G(z)): {D_G_z1:.4} / {D_G_z2:.4f}")G_losses.append(errG.item())D_losses.append(errD.item())if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):with torch.no_grad():fake = netG(fixed_noise).detach().cpu()img_list.append(utils.make_grid(fake[:64], padding=2, normalize=True))iters += 1
训练过程
Starting Training
[0/50][0/36] Loss_D: 1.7595 Loss_G: 4.2032 D(x): 0.4410 D(G(z)): 0.5005 / 0.0225
[1/50][0/36] Loss_D: 0.2791 Loss_G: 22.5919 D(x): 0.8837 D(G(z)): 0.0000 / 0.0000
[2/50][0/36] Loss_D: 0.2593 Loss_G: 10.4404 D(x): 0.8790 D(G(z)): 0.0030 / 0.0004
[3/50][0/36] Loss_D: 0.1284 Loss_G: 3.6274 D(x): 0.9703 D(G(z)): 0.0796 / 0.0575
[4/50][0/36] Loss_D: 0.5483 Loss_G: 6.2631 D(x): 0.9186 D(G(z)): 0.2988 / 0.0050
[5/50][0/36] Loss_D: 1.4979 Loss_G: 4.8728 D(x): 0.3821 D(G(z)): 0.0025 / 0.0144
[6/50][0/36] Loss_D: 0.8146 Loss_G: 3.3497 D(x): 0.6144 D(G(z)): 0.0235 / 0.0598
[7/50][0/36] Loss_D: 0.6614 Loss_G: 5.6657 D(x): 0.9099 D(G(z)): 0.3836 / 0.0059
[8/50][0/36] Loss_D: 0.1566 Loss_G: 5.6704 D(x): 0.8978 D(G(z)): 0.0379 / 0.0078
[9/50][0/36] Loss_D: 0.3516 Loss_G: 3.4337 D(x): 0.8366 D(G(z)): 0.1212 / 0.0602
[10/50][0/36] Loss_D: 0.5276 Loss_G: 4.6958 D(x): 0.8007 D(G(z)): 0.1975 / 0.0195
[11/50][0/36] Loss_D: 0.3213 Loss_G: 3.3171 D(x): 0.8546 D(G(z)): 0.1102 / 0.0554
[12/50][0/36] Loss_D: 0.4608 Loss_G: 3.9185 D(x): 0.8921 D(G(z)): 0.2136 / 0.0424
[13/50][0/36] Loss_D: 1.2687 Loss_G: 7.7746 D(x): 0.9671 D(G(z)): 0.5904 / 0.0022
[14/50][0/36] Loss_D: 0.3971 Loss_G: 4.1165 D(x): 0.7512 D(G(z)): 0.0442 / 0.0272
[15/50][0/36] Loss_D: 1.0332 Loss_G: 8.4574 D(x): 0.9400 D(G(z)): 0.5136 / 0.0010
[16/50][0/36] Loss_D: 0.3976 Loss_G: 3.7233 D(x): 0.8225 D(G(z)): 0.1058 / 0.0404
[17/50][0/36] Loss_D: 0.3433 Loss_G: 4.0558 D(x): 0.8422 D(G(z)): 0.1229 / 0.0302
[18/50][0/36] Loss_D: 0.3605 Loss_G: 3.4592 D(x): 0.8398 D(G(z)): 0.1089 / 0.0647
[19/50][0/36] Loss_D: 0.3316 Loss_G: 6.4387 D(x): 0.7779 D(G(z)): 0.0076 / 0.0085
[20/50][0/36] Loss_D: 0.3899 Loss_G: 3.7591 D(x): 0.8202 D(G(z)): 0.0973 / 0.0433
[21/50][0/36] Loss_D: 0.3571 Loss_G: 5.9245 D(x): 0.9634 D(G(z)): 0.2298 / 0.0065
[22/50][0/36] Loss_D: 0.2840 Loss_G: 4.5101 D(x): 0.8919 D(G(z)): 0.1315 / 0.0184
[23/50][0/36] Loss_D: 0.4924 Loss_G: 3.1086 D(x): 0.7023 D(G(z)): 0.0199 / 0.0716
[24/50][0/36] Loss_D: 0.2107 Loss_G: 4.2722 D(x): 0.9547 D(G(z)): 0.1357 / 0.0234
[25/50][0/36] Loss_D: 0.1484 Loss_G: 4.5043 D(x): 0.9319 D(G(z)): 0.0651 / 0.0250
[26/50][0/36] Loss_D: 1.2124 Loss_G: 3.0712 D(x): 0.4369 D(G(z)): 0.0214 / 0.0948
[27/50][0/36] Loss_D: 0.3019 Loss_G: 4.6950 D(x): 0.9525 D(G(z)): 0.1980 / 0.0156
[28/50][0/36] Loss_D: 0.7236 Loss_G: 7.5887 D(x): 0.9555 D(G(z)): 0.4410 / 0.0010
[29/50][0/36] Loss_D: 0.3441 Loss_G: 3.3680 D(x): 0.8283 D(G(z)): 0.0998 / 0.0613
[30/50][0/36] Loss_D: 0.2165 Loss_G: 5.1199 D(x): 0.8957 D(G(z)): 0.0762 / 0.0129
[31/50][0/36] Loss_D: 0.3176 Loss_G: 4.1078 D(x): 0.8947 D(G(z)): 0.1468 / 0.0326
[32/50][0/36] Loss_D: 0.3671 Loss_G: 3.9614 D(x): 0.7650 D(G(z)): 0.0373 / 0.0386
[33/50][0/36] Loss_D: 0.4350 Loss_G: 2.8999 D(x): 0.7714 D(G(z)): 0.0983 / 0.0861
[34/50][0/36] Loss_D: 0.7197 Loss_G: 2.4081 D(x): 0.6495 D(G(z)): 0.1329 / 0.1539
[35/50][0/36] Loss_D: 0.8072 Loss_G: 2.5399 D(x): 0.5736 D(G(z)): 0.0145 / 0.1650
[36/50][0/36] Loss_D: 0.5928 Loss_G: 3.1043 D(x): 0.6356 D(G(z)): 0.0245 / 0.0795
[37/50][0/36] Loss_D: 0.9726 Loss_G: 1.7994 D(x): 0.5331 D(G(z)): 0.0338 / 0.2744
[38/50][0/36] Loss_D: 0.2589 Loss_G: 3.7818 D(x): 0.8953 D(G(z)): 0.1114 / 0.0394
[39/50][0/36] Loss_D: 0.7775 Loss_G: 6.9969 D(x): 0.9280 D(G(z)): 0.4414 / 0.0028
[40/50][0/36] Loss_D: 0.3364 Loss_G: 3.8214 D(x): 0.8090 D(G(z)): 0.0860 / 0.0409
[41/50][0/36] Loss_D: 1.3642 Loss_G: 8.2162 D(x): 0.9447 D(G(z)): 0.6119 / 0.0010
[42/50][0/36] Loss_D: 0.7877 Loss_G: 3.8933 D(x): 0.8114 D(G(z)): 0.3348 / 0.0398
[43/50][0/36] Loss_D: 0.3706 Loss_G: 4.6370 D(x): 0.8970 D(G(z)): 0.1947 / 0.0169
[44/50][0/36] Loss_D: 0.7339 Loss_G: 6.4576 D(x): 0.9739 D(G(z)): 0.4491 / 0.0032
[45/50][0/36] Loss_D: 1.0858 Loss_G: 0.7336 D(x): 0.4564 D(G(z)): 0.0115 / 0.5793
[46/50][0/36] Loss_D: 0.2643 Loss_G: 4.0960 D(x): 0.8862 D(G(z)): 0.1110 / 0.0266
[47/50][0/36] Loss_D: 0.3914 Loss_G: 3.3387 D(x): 0.7496 D(G(z)): 0.0531 / 0.0558
[48/50][0/36] Loss_D: 0.5433 Loss_G: 4.0881 D(x): 0.7723 D(G(z)): 0.1693 / 0.0369
[49/50][0/36] Loss_D: 0.3666 Loss_G: 4.0757 D(x): 0.8962 D(G(z)): 0.2012 / 0.0263
模型效果
打印训练过程图
plt.figure(figsize=(10, 5))
plt.title('Generator and Discriminator Loss During Traning')
plt.plot(G_losses, label='G')
plt.plot(D_losses, label='D')
plt.xlabel('iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()
打印训练过程保存的阶段图片
这里用到了只能在jupyter notebook中使用的函数HTML
fig = plt.figure(figsize=(8, 8))
plt.axis('off')
imgs = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, imgs, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
未训练时
训练了1/4
训练了1/2
最终效果
与原数据集中的图像进行对比,在颜色和局部的细节方面,的确已经非常的相似,不过人眼还是很容易区分出的。
对比
# 获取一批真实图像
real_batch = next(iter(dataloader))plt.figure(figsize=(15, 15))
plt.subplot(121)
plt.axis('off')
plt.title('Real Images')
plt.imshow(np.transpose(utils.make_grid(real_batch[0][:64], padding=5, normalize=True), (1, 2, 0))plt.subplot(122)
plt.axis('off')
plt.title('Fake Images')
plt.imshow(np.transpose(img_list[-1], (1, 2, 0))
总结与心得体会
在模型实现的过程中,我一开始忽略了鉴别器卷积层之间的BatchNorm层,在其它参数不变的情况下,模型的训练效果非常差,没有学习到原始图像的颜色,只学习到了一些人脸的轮廓。
通过对结果的分析,我认为生成器和判别器中的模型层数太小,导致卷积只学习到了浅层的特征,缺乏深层的抽象特征,体现在结果中就是,放大看图像的部分区域没有问题,但是整体来看生成的人脸都有些问题,不同的部分之间缺少一致性,看起来不太协调。
这篇关于G2 - 人脸图像生成(DCGAN)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!