深度卷积生成对抗网络 (DCGAN)

2024-01-12 22:52

本文主要是介绍深度卷积生成对抗网络 (DCGAN),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

深度卷积生成对抗网络 (DCGAN) 是一种生成模型,它使用深度卷积神经网络来生成新数据样本的任务。以下是有关 DCGAN 的一些要点:

建筑:

DCGAN由生成器和鉴别器网络组成。
生成器负责从随机噪声中生成真实的数据样本。
鉴别器试图区分真实数据样本和生成器生成的数据样本。
卷积层:

DCGAN 在生成器和鉴别器中使用卷积层来捕获数据中的空间层次结构和模式。
卷积层有助于学习局部特征,对于图像相关任务至关重要。
批量归一化:

批量归一化通常用于生成器和判别器中,以稳定和加速训练。
它将输入归一化到图层,有助于缓解渐变消失等问题。
激活功能:

通常,ReLU(整流线性单元)激活函数用于中间层的生成器中。
生成器的输出层通常使用 tanh 激活函数来生成介于 -1 和 1 之间的像素值。
发电机输入:

发生器的输入通常是随机噪声(通常从正态分布中采样)。
生成器学会将这种噪声转换为真实的数据样本。
鉴别器输出:

鉴别器的输出是一个概率,指示输入是真实数据样本的可能性。
sigmoid 激活函数通常用于鉴别器的输出层,以生成介于 0 和 1 之间的值。
损失函数:

生成器旨在最小化判别器进行正确分类的概率(最小化 log(1 - D(G(z))),其中 G(z) 是生成的样本)。
鉴别器旨在正确分类真实样本和生成的样本(最小化真实样本的 log(D(x)) 和生成的样本的 log(1 - D(G(z)))。
培训流程:

DCGAN 使用最小-最大博弈进行训练,其中生成器和鉴别器是迭代训练的。
训练过程涉及更新两个网络的权重以提高其性能。
可视化:

在训练过程中,DCGAN产生越来越逼真的数据样本,生成器学习生成多样化和高质量的输出。
应用:

DCGAN广泛用于图像生成任务,包括生成逼真的人脸、物体和场景。
它们还被应用于图像到图像转换和样式转换等任务。

import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import os
# 设置环境变量以避免 OpenMP 问题
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"import torch
from torchvision import datasets
from torchvision import transformstransform=transforms.ToTensor()svhn_train=datasets.SVHN(root='data/',split='train',download=True,transform=transform)batch_size=128
num_workers=0train_loader = torch.utils.data.DataLoader(dataset=svhn_train,batch_size=batch_size,shuffle=True,num_workers=num_workers)# 可视化数据
# Visualize data
dataiter = iter(train_loader)
images, labels = next(dataiter)fig = plt.figure(figsize=(25, 4))
plot_size = 20
for idx in np.arange(plot_size):ax = fig.add_subplot(2, plot_size // 2, idx + 1, xticks=[], yticks=[])ax.imshow(np.transpose(images[idx], (1, 2, 0)))ax.set_title(str(labels[idx].item()))img = images[0]print('Min:', img.min())
print('Max:', img.max())plt.show()# helper scale function
def scale(x, feature_range=(-1, 1)):min, max = feature_rangex = x * (max - min) + minreturn x# scaled range
scaled_img = scale(img)print('Scaled min: ', scaled_img.min())
print('Scaled max: ', scaled_img.max())# 定义模型import torch.nn as nn
import torch.nn.functional as F# helper conv function
def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):layers = []conv_layer = nn.Conv2d(in_channels, out_channels,kernel_size, stride, padding, bias=False)layers.append(conv_layer)if batch_norm:layers.append(nn.BatchNorm2d(out_channels))return nn.Sequential(*layers)class Discriminator(nn.Module):def __init__(self, conv_dim=32):super(Discriminator, self).__init__()self.conv_dim = conv_dimself.conv1 = conv(3, conv_dim, 4, batch_norm=False)self.conv2 = conv(conv_dim, conv_dim * 2, 4)self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)self.fc = nn.Linear(conv_dim * 4 * 4 * 4, 1)def forward(self, x):out = F.leaky_relu(self.conv1(x), 0.2)out = F.leaky_relu(self.conv2(out), 0.2)out = F.leaky_relu(self.conv3(out), 0.2)out = out.view(-1, self.conv_dim * 4 * 4 * 4)out = self.fc(out)return outdef deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):layers = []transpose_conv_layer = nn.ConvTranspose2d(in_channels, out_channels,kernel_size, stride, padding, bias=False)layers.append(transpose_conv_layer)if batch_norm:layers.append(nn.BatchNorm2d(out_channels))return nn.Sequential(*layers)class Generator(nn.Module):def __init__(self, z_size, conv_dim=32):super(Generator, self).__init__()self.conv_dim = conv_dimself.fc = nn.Linear(z_size, conv_dim * 4 * 4 * 4)self.t_conv1 = deconv(conv_dim * 4, conv_dim * 2, 4)self.t_conv2 = deconv(conv_dim * 2, conv_dim, 4)self.t_conv3 = deconv(conv_dim, 3, 4, batch_norm=False)def forward(self, x):out = self.fc(x)out = out.view(-1, self.conv_dim * 4, 4, 4)out = F.relu(self.t_conv1(out))out = F.relu(self.t_conv2(out))out = self.t_conv3(out)out = F.tanh(out)return outconv_dim = 32
z_size = 100D = Discriminator(conv_dim)
G = Generator(z_size=z_size, conv_dim=conv_dim)print(D)
print()
print(G)train_on_gpu = torch.cuda.is_available()if train_on_gpu:G.cuda()D.cuda()print('GPU available for training. Models moved to GPU')
else:print('Training on CPU.')def real_loss(D_out, smooth=False):batch_size = D_out.size(0)# label smoothingif smooth:# smooth, real labels = 0.9labels = torch.ones(batch_size)*0.9else:labels = torch.ones(batch_size) # real labels = 1# move labels to GPU if availableif train_on_gpu:labels = labels.cuda()# binary cross entropy with logits losscriterion = nn.BCEWithLogitsLoss()# calculate lossloss = criterion(D_out.squeeze(), labels)return lossdef fake_loss(D_out):batch_size = D_out.size(0)labels = torch.zeros(batch_size) # fake labels = 0if train_on_gpu:labels = labels.cuda()criterion = nn.BCEWithLogitsLoss()# calculate lossloss = criterion(D_out.squeeze(), labels)return lossimport torch.optim as optim# params
lr = 0.0002
beta1=0.5
beta2=0.999d_optimizer = optim.Adam(D.parameters(), lr, [beta1, beta2])
g_optimizer = optim.Adam(G.parameters(), lr, [beta1, beta2])import pickle as pklnum_epochs = 50samples = []
losses = []print_every = 300sample_size = 16
fixed_z = np.random.uniform(-1, 1, size=(sample_size, z_size))
fixed_z = torch.from_numpy(fixed_z).float()for epoch in range(num_epochs):for batch_i, (real_images, _) in enumerate(train_loader):batch_size = real_images.size(0)real_images = scale(real_images)d_optimizer.zero_grad()if train_on_gpu:real_images = real_images.cuda()D_real = D(real_images)d_real_loss = real_loss(D_real)z = np.random.uniform(-1, 1, size=(batch_size, z_size))z = torch.from_numpy(z).float()if train_on_gpu:z = z.cuda()fake_images = G(z)D_fake = D(fake_images)d_fake_loss = fake_loss(D_fake)d_loss = d_real_loss + d_fake_lossd_loss.backward()d_optimizer.step()g_optimizer.zero_grad()z = np.random.uniform(-1, 1, size=(batch_size, z_size))z = torch.from_numpy(z).float()if train_on_gpu:z = z.cuda()fake_images = G(z)D_fake = D(fake_images)g_loss = real_loss(D_fake)g_loss.backward()g_optimizer.step()if batch_i % print_every == 0:losses.append((d_loss.item(), g_loss.item()))print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(epoch + 1, num_epochs, d_loss.item(), g_loss.item()))G.eval()if train_on_gpu:fixed_z = fixed_z.cuda()samples_z = G(fixed_z)samples.append(samples_z)G.train()with open('train_samples.pkl', 'wb') as f:pkl.dump(samples, f)fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator', alpha=0.5)
plt.plot(losses.T[1], label='Generator', alpha=0.5)
plt.title("Training Losses")
plt.legend()
plt.show()
def view_samples(epoch, samples):fig, axes = plt.subplots(figsize=(16,4), nrows=2, ncols=8, sharey=True, sharex=True)for ax, img in zip(axes.flatten(), samples[epoch]):img = img.detach().cpu().numpy()img = np.transpose(img, (1, 2, 0))img = ((img +1)*255 / (2)).astype(np.uint8)ax.xaxis.set_visible(False)ax.yaxis.set_visible(False)im = ax.imshow(img.reshape((32,32,3)))
_ = view_samples(-1, samples)

这篇关于深度卷积生成对抗网络 (DCGAN)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/599424

相关文章

MybatisGenerator文件生成不出对应文件的问题

《MybatisGenerator文件生成不出对应文件的问题》本文介绍了使用MybatisGenerator生成文件时遇到的问题及解决方法,主要步骤包括检查目标表是否存在、是否能连接到数据库、配置生成... 目录MyBATisGenerator 文件生成不出对应文件先在项目结构里引入“targetProje

Python使用qrcode库实现生成二维码的操作指南

《Python使用qrcode库实现生成二维码的操作指南》二维码是一种广泛使用的二维条码,因其高效的数据存储能力和易于扫描的特点,广泛应用于支付、身份验证、营销推广等领域,Pythonqrcode库是... 目录一、安装 python qrcode 库二、基本使用方法1. 生成简单二维码2. 生成带 Log

五大特性引领创新! 深度操作系统 deepin 25 Preview预览版发布

《五大特性引领创新!深度操作系统deepin25Preview预览版发布》今日,深度操作系统正式推出deepin25Preview版本,该版本集成了五大核心特性:磐石系统、全新DDE、Tr... 深度操作系统今日发布了 deepin 25 Preview,新版本囊括五大特性:磐石系统、全新 DDE、Tree

SSID究竟是什么? WiFi网络名称及工作方式解析

《SSID究竟是什么?WiFi网络名称及工作方式解析》SID可以看作是无线网络的名称,类似于有线网络中的网络名称或者路由器的名称,在无线网络中,设备通过SSID来识别和连接到特定的无线网络... 当提到 Wi-Fi 网络时,就避不开「SSID」这个术语。简单来说,SSID 就是 Wi-Fi 网络的名称。比如

Python使用Pandas库将Excel数据叠加生成新DataFrame的操作指南

《Python使用Pandas库将Excel数据叠加生成新DataFrame的操作指南》在日常数据处理工作中,我们经常需要将不同Excel文档中的数据整合到一个新的DataFrame中,以便进行进一步... 目录一、准备工作二、读取Excel文件三、数据叠加四、处理重复数据(可选)五、保存新DataFram

SpringBoot生成和操作PDF的代码详解

《SpringBoot生成和操作PDF的代码详解》本文主要介绍了在SpringBoot项目下,通过代码和操作步骤,详细的介绍了如何操作PDF,希望可以帮助到准备通过JAVA操作PDF的你,项目框架用的... 目录本文简介PDF文件简介代码实现PDF操作基于PDF模板生成,并下载完全基于代码生成,并保存合并P

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

Node.js 中 http 模块的深度剖析与实战应用小结

《Node.js中http模块的深度剖析与实战应用小结》本文详细介绍了Node.js中的http模块,从创建HTTP服务器、处理请求与响应,到获取请求参数,每个环节都通过代码示例进行解析,旨在帮... 目录Node.js 中 http 模块的深度剖析与实战应用一、引言二、创建 HTTP 服务器:基石搭建(一

详解Java中如何使用JFreeChart生成甘特图

《详解Java中如何使用JFreeChart生成甘特图》甘特图是一种流行的项目管理工具,用于显示项目的进度和任务分配,在Java开发中,JFreeChart是一个强大的开源图表库,能够生成各种类型的图... 目录引言一、JFreeChart简介二、准备工作三、创建甘特图1. 定义数据集2. 创建甘特图3.

AI一键生成 PPT

AI一键生成 PPT 操作步骤 作为一名打工人,是不是经常需要制作各种PPT来分享我的生活和想法。但是,你们知道,有时候灵感来了,时间却不够用了!😩直到我发现了Kimi AI——一个能够自动生成PPT的神奇助手!🌟 什么是Kimi? 一款月之暗面科技有限公司开发的AI办公工具,帮助用户快速生成高质量的演示文稿。 无论你是职场人士、学生还是教师,Kimi都能够为你的办公文