本文主要是介绍【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
关于
近年来,基于卷积网络(CNN)的监督学习已经 在计算机视觉应用中得到了广泛的采用。相比之下,无监督 使用 CNN 进行学习受到的关注较少。在这项工作中,我们希望能有所帮助 缩小了 CNN 在监督学习和无监督学习方面的成功之间的差距。我们介绍一类称为深度卷积生成的 CNN 对抗性网络(DCGAN),具有一定的架构限制,以及 证明他们是无监督学习的有力候选人。训练 在各种图像数据集上,我们展示了令人信服的证据,表明我们的深度卷积对抗对学习了从对象部分到 生成器和鉴别器中的场景。此外,我们使用学到的 新任务的特征 - 证明它们作为一般图像表示的适用性。(https://arxiv.org/pdf/1511.06434.pdf)
工具
数据集
方法实现
加载必要的库函数和自定义函数
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as Ffrom torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
def get_sample_image(G, n_noise):"""save sample 100 images"""z = torch.randn(100, n_noise).to(DEVICE)y_hat = G(z).view(100, 28, 28) # (100, 28, 28)result = y_hat.cpu().data.numpy()img = np.zeros([280, 280])for j in range(10):img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)return img
定义判别模型
class Discriminator(nn.Module):"""Convolutional Discriminator for MNIST"""def __init__(self, in_channel=1, num_classes=1):super(Discriminator, self).__init__()self.conv = nn.Sequential(# 28 -> 14nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2),# 14 -> 7nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2),# 7 -> 4nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.LeakyReLU(0.2),nn.AvgPool2d(4),)self.fc = nn.Sequential(# reshape input, 128 -> 1nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x, y=None):y_ = self.conv(x)y_ = y_.view(y_.size(0), -1)y_ = self.fc(y_)return y_
定义生成模型
class Generator(nn.Module):"""Convolutional Generator for MNIST"""def __init__(self, input_size=100, num_classes=784):super(Generator, self).__init__()self.fc = nn.Sequential(nn.Linear(input_size, 4*4*512),nn.ReLU(),)self.conv = nn.Sequential(# input: 4 by 4, output: 7 by 7nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(),# input: 7 by 7, output: 14 by 14nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),nn.BatchNorm2d(128),nn.ReLU(),# input: 14 by 14, output: 28 by 28nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),nn.Tanh(),)def forward(self, x, y=None):x = x.view(x.size(0), -1)y_ = self.fc(x)y_ = y_.view(y_.size(0), 512, 4, 4)y_ = self.conv(y_)return y_
模型超参数定义配置
batch_size = 64criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))max_epoch = 30 # need more than 20 epochs for training generator
step = 0
n_critic = 1 # for training more k steps about Discriminator
n_noise = 100D_labels = torch.ones([batch_size, 1]).to(DEVICE) # Discriminator Label to real
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Discriminator Label to fake
模型训练
for epoch in range(max_epoch):for idx, (images, labels) in enumerate(data_loader):# Training Discriminatorx = images.to(DEVICE)x_outputs = D(x)D_x_loss = criterion(x_outputs, D_labels)z = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))D_z_loss = criterion(z_outputs, D_fakes)D_loss = D_x_loss + D_z_lossD.zero_grad()D_loss.backward()D_opt.step()if step % n_critic == 0:# Training Generatorz = torch.randn(batch_size, n_noise).to(DEVICE)z_outputs = D(G(z))G_loss = criterion(z_outputs, D_labels)D.zero_grad()G.zero_grad()G_loss.backward()G_opt.step()if step % 500 == 0:print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))if step % 1000 == 0:G.eval()img = get_sample_image(G, n_noise)imsave('./{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')G.train()step += 1
测试生成效果
# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')
模型和状态参量保存
def save_checkpoint(state, file_name='checkpoint.pth.tar'):torch.save(state, file_name)# Saving params.
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')
应用
DCGAN作为一个成熟的生成模型,在自然图像,医学图像,医学电生理信号数据分析中,都可以用来实现数据的合成,达到数据增强的目的,同时,如何减少增强数据对于后端任务的不利干扰,也是一个需要关注的方面。
这篇关于【图像合成】基于DCGAN典型网络的MNIST字符生成(pytorch)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!