PyTorch深度学习实战(37)——CycleGAN详解与实现

2024-02-22 12:12

本文主要是介绍PyTorch深度学习实战(37)——CycleGAN详解与实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyTorch深度学习实战(37)——CycleGAN详解与实现

    • 0. 前言
    • 1. CycleGAN 基本原理
    • 2. CycleGAN 模型分析
    • 3. 实现 CycleGAN
    • 小结
    • 系列链接

0. 前言

CycleGAN 是一种用于图像转换的生成对抗网络(Generative Adversarial Network, GAN),可以在不需要配对数据的情况下将一种风格的图像转换成另一种风格,而无需为每一对输入-输出图像配对训练数据。CycleGAN 的核心思想是利用两个生成器和两个判别器,它们共同学习两个域之间的映射关系。例如,将马的图像转换成斑马的图像,或者将夏天的风景转换成冬天的风景。在本节中,我们将学习 CycleGAN 的基本原理,并实现该模型用于将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像。

1. CycleGAN 基本原理

CycleGAN 是一种无需配对的图像转换技术,它可以将一个图像域中的图像转换为另一个图像域中的图像,而不需要匹配这两个域中的图像。它使用两个生成器和两个判别器,其中一个生成器将一个域中的图像转换为另一个域中的图像,而第二个生成器将其转换回来。这个过程被称为循环一致性,转换过程是可逆的。
CycleGAN 可以用于执行从一个类别到另一个类别的图像转换,而无需提供相匹配的输入-输出图像对来训练模型,只需要在两个不同的文件夹中提供这两个类别的图像。在本节中,我们将学习如何训练 CycleGAN 将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像,CycleGAN 中的 Cycle 是指将图像从一个类别转换到另一个类别,然后再转换回原始类别的过程。
CycleGAN 中,需要使用三种不同的损失值:

  • 鉴别器损失:用于区分真实图像和伪造图像
  • 循环一致性损失:由于 CycleGAN 使用了两个生成器,因此需要确保转换是可逆的,循环一致性损失通过将转换过的图像再次传递到原始的生成器中,并将生成的图像与原始图像进行比较来实现
  • 恒等损失 (Identity loss):确保生成器在不进行转换的情况下仍然能够生成与原始图像相似的图像,通过将原始图像传递到生成器中,并计算生成图像与原始图像之间的差异

2. CycleGAN 模型分析

CycleGAN 模型构建策略如下:

  1. 导入数据集并进行预处理
  2. 定义 UNet 架构用于构建生成器和判别器网络
  3. 定义两个生成器:
    • G_AB:将类别 A 图像转换为类别 B 图像的生成器
    • G_BA:将类别 B 图像转换为类别 A 图像的生成器
  4. 定义恒等损失:
    • 如果将一张橘子的图像输入到橙子生成器,理想情况下,如果生成器完全理解橙子的所有信息,它不应该改变图像,而应该“生成”完全相同的图像,据此,我们可以创建一个恒等变换
    • 当类别 A (real_A) 的图像通过 G_BA 并与 real_A 进行比较时,恒等损失应该是最小的
    • 当类别 B (real_B) 的图像通过 G_AB 并与 real_B 进行比较时,恒等损失应该是最小的
  5. 定义GAN损失:
    • real_Afake_A 的判别器和生成器损失(当 real_B 图像通过 G_BA 时得到 fake_A)
    • real_Bfake_B 的判别器和生成器损失(当 real_A 图像通过 G_AB 时得到 fake_B)
  6. 定义循环一致性损失:
    • 一张苹果图像需要通过橙子生成网络进行转换,生成伪造的橘子图像,然后再通过苹果生成网络将伪造的橙子图像转换回苹果图像
    • fake_Breal_A 通过 G_AB 时的输出,当 fake_B 通过 G_BA 时应该重新生成 real_A
    • fake_Areal_B 通过 G_BA 时的输出,当 fake_A 通过 G_AB 时应该重新生成 real_B
  7. 优化三个损失函数的加权和

3. 实现 CycleGAN

接下来,我们使用 PyTorch 实现 CycleGAN 模型,用以将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像。

(1) 导入相关数据集和库。

首先下载并解压数据集,可以自行构建数据集,也可以下载本文所用数据集,下载地址:https://pan.baidu.com/s/1iTOt2NsUQ1a3taUHjvkjfA,提取码:iuqf

可视化示例图像如下:
数据集

与 Pix2Pix 训练数据集不同,苹果和橙色图像之间不存在一一对应的关系。

导入所需的库:

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
import cv2
import random
from glob import glob
from PIL import Image
import itertools
from torchvision import transforms
device = "cuda" if torch.cuda.is_available() else "cpu"

(2) 定义图像转换管道 transform

IMAGE_SIZE = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transform = transforms.Compose([transforms.Resize(int(IMAGE_SIZE*1.33)),transforms.RandomCrop((IMAGE_SIZE,IMAGE_SIZE)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

(3) 定义数据集类 CycleGANDataset,以苹果图像 apple 和橙子图像 orange 文件夹为输入,提供批数据:

class CycleGANDataset(Dataset):def __init__(self, apples, oranges):self.apples = glob(apples)self.oranges = glob(oranges)def __getitem__(self, ix):apple = self.apples[ix % len(self.apples)]orange = random.choice(self.oranges)apple = Image.open(apple).convert('RGB')orange = Image.open(orange).convert('RGB')return apple, orangedef __len__(self):return max(len(self.apples), len(self.oranges))def choose(self):return self[random.randint(len(self))]def collate_fn(self, batch):srcs, trgs = list(zip(*batch))srcs = torch.cat([transform(img)[None] for img in srcs], 0).to(device).float()trgs = torch.cat([transform(img)[None] for img in trgs], 0).to(device).float()return srcs.to(device), trgs.to(device)

(4) 定义训练、验证数据集和数据加载器:

trn_ds = CycleGANDataset('apples_oranges/apples_train/*.jpg', 'apples_oranges/oranges_train/*.jpg')
val_ds = CycleGANDataset('apples_oranges/apples_test/*.jpg', 'apples_oranges/oranges_test/*.jpg')trn_dl = DataLoader(trn_ds, batch_size=1, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, batch_size=5, shuffle=True, collate_fn=val_ds.collate_fn)

(5) 定义网络的权重初始化方法 weights_init_normal

def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)if hasattr(m, "bias") and m.bias is not None:torch.nn.init.constant_(m.bias.data, 0.0)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)

(6) 定义残差块 ResidualBlock

class ResidualBlock(nn.Module):def __init__(self, in_features):super(ResidualBlock, self).__init__()self.block = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),nn.ReLU(inplace=True),nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),)def forward(self, x):return x + self.block(x)

(7) 定义生成器 GeneratorResNet

class GeneratorResNet(nn.Module):def __init__(self, num_residual_blocks=9):super(GeneratorResNet, self).__init__()out_features = 64channels = 3model = [nn.ReflectionPad2d(3),nn.Conv2d(channels, out_features, 7),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Downsamplingfor _ in range(2):out_features *= 2model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Residual blocksfor _ in range(num_residual_blocks):model += [ResidualBlock(out_features)]# Upsamplingfor _ in range(2):out_features //= 2model += [nn.Upsample(scale_factor=2),nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Output layermodel += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]self.model = nn.Sequential(*model)self.apply(weights_init_normal)def forward(self, x):return self.model(x)

(8) 定义判别器 Discriminator

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()channels, height, width = 3, IMAGE_SIZE, IMAGE_SIZEdef discriminator_block(in_filters, out_filters, normalize=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]if normalize:layers.append(nn.InstanceNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(channels, 64, normalize=False),*discriminator_block(64, 128),*discriminator_block(128, 256),*discriminator_block(256, 512),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(512, 1, 4, padding=1))self.apply(weights_init_normal)def forward(self, img):return self.model(img)

(9) 定义图像样本生成函数 generate_sample

@torch.no_grad()
def generate_sample(G_AB, G_BA):data = next(iter(val_dl))G_AB.eval()G_BA.eval()real_A, real_B = datafake_B = G_AB(real_A)fake_A = G_BA(real_B)# Arange images along x-axisreal_A = make_grid(real_A, nrow=5, normalize=True)real_B = make_grid(real_B, nrow=5, normalize=True)fake_A = make_grid(fake_A, nrow=5, normalize=True)fake_B = make_grid(fake_B, nrow=5, normalize=True)# Arange images along y-axisimage_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)plt.imshow(image_grid.detach().cpu().permute(1,2,0).numpy())plt.show()

(10) 定义生成器训练函数 generator_train_step

该函数将两个生成器( G_ABG_BA)、优化器和两个类别的真实图像( real_Areal_B )作为输入:

def generator_train_step(Gs, optimizer, real_A, real_B, D_A, D_B, criterion_identity, criterion_cycle, criterion_GAN, lambda_cyc, lambda_id):

指定生成器:

    G_AB, G_BA = Gs

将优化器的梯度设置为零:

    optimizer.zero_grad()

计算类别 A (苹果)和类别 B (橙子)图像的恒等损失 (loss_identity):

    loss_id_A = criterion_identity(G_BA(real_A), real_A)loss_id_B = criterion_identity(G_AB(real_B), real_B)loss_identity = (loss_id_A + loss_id_B) / 2

计算图像通过生成器时的 GAN 损失,此时生成的图像应尽可能接近另一类别,使用 np.ones 作为训练生成器的判别网络目标输出,因为我们将生成的伪造图像传递给相同类别的判别器:

    fake_B = G_AB(real_A)loss_GAN_AB = criterion_GAN(D_B(fake_B), torch.Tensor(np.ones((len(real_A), 1, 16, 16))).to(device))fake_A = G_BA(real_B)loss_GAN_BA = criterion_GAN(D_A(fake_A), torch.Tensor(np.ones((len(real_A), 1, 16, 16))).to(device))loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

计算循环一致性损失。假设,一张苹果图像被橙子生成器转换为一张伪造橙子图像,然后伪造橙子图像通过苹果生成器转换回一张苹果图像,理想情况下,经过该过程后应该返回原始图像,即循环一致性损失应该为零:

    recov_A = G_BA(fake_B)loss_cycle_A = criterion_cycle(recov_A, real_A)recov_B = G_AB(fake_A)loss_cycle_B = criterion_cycle(recov_B, real_B)loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

计算总损失并执行反向传播:

    loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identityloss_G.backward()optimizer.step()return loss_G, loss_identity, loss_GAN, loss_cycle, loss_G, fake_A, fake_B

(11) 定义判别器训练函数 discriminator_train_step

def discriminator_train_step(D, real_data, fake_data, optimizer, criterion_GAN):optimizer.zero_grad()loss_real = criterion_GAN(D(real_data), torch.Tensor(np.ones((len(real_data), 1, 16, 16))).to(device))loss_fake = criterion_GAN(D(fake_data.detach()), torch.Tensor(np.zeros((len(real_data), 1, 16, 16))).to(device))loss_D = (loss_real + loss_fake) / 2loss_D.backward()optimizer.step()return loss_D

(12) 定义生成器、判别器对象、优化器和损失函数:

G_AB = GeneratorResNet().to(device)
G_BA = GeneratorResNet().to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))lambda_cyc, lambda_id = 10.0, 5.0

(13) 训练网络:

n_epochs = 50
# log = Report(n_epochs)
loss_D_epochs = []
loss_G_epochs = []
loss_GAN_epochs = []
loss_cycle_epochs = []
loss_identity_epochs = []
for epoch in range(n_epochs):N = len(trn_dl)loss_D_items = []loss_G_items = []loss_GAN_items = []loss_cycle_items = []loss_identity_items = []for bx, batch in enumerate(trn_dl):real_A, real_B = batchloss_G, loss_identity, loss_GAN, loss_cycle, loss_G, fake_A, fake_B = generator_train_step((G_AB,G_BA), optimizer_G, real_A, real_B, D_A, D_B, criterion_identity, criterion_cycle, criterion_GAN, lambda_cyc, lambda_id)loss_D_A = discriminator_train_step(D_A, real_A, fake_A, optimizer_D_A, criterion_GAN)loss_D_B = discriminator_train_step(D_B, real_B, fake_B, optimizer_D_B, criterion_GAN)loss_D = (loss_D_A + loss_D_B) / 2loss_D_items.append(loss_D.item())loss_G_items.append(loss_G.item())loss_GAN_items.append(loss_GAN.item())loss_cycle_items.append(loss_cycle.item())loss_identity_items.append(loss_identity.item())loss_D_epochs.append(np.average(loss_D_items))loss_G_epochs.append(np.average(loss_G_items))loss_GAN_epochs.append(np.average(loss_GAN_items))loss_cycle_epochs.append(np.average(loss_cycle_items))loss_identity_epochs.append(np.average(loss_cycle_items))

(14) 训练模型后,测试模型生成图像:

generate_sample(G_AB, G_BA)

模型生成结果
从上图可以看出,CycleGAN 可以成功地将苹果转换为橙子(前两行),将橙子转换为苹果(后两行)。

小结

CycleGAN 是一种用于无监督图像转换的深度学习模型,它通过两个生成器和两个判别器的组合来学习两个不同域之间的映射关系。生成器负责将一个域的图像转换成另一个域的图像,而判别器则用于区分生成的图像和真实的图像。CycleGAN 引入循环一致性损失,确保图像转换是可逆的,从而提高生成图像的质量。通过对抗训练和循环一致性损失,CycleGAN 可以实现在没有配对标签的情况下进行图像域转换。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——从零开始实现SSD目标检测
PyTorch深度学习实战(24)——使用U-Net架构进行图像分割
PyTorch深度学习实战(25)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(26)——多对象实例分割
PyTorch深度学习实战(27)——自编码器(Autoencoder)
PyTorch深度学习实战(28)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(29)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(30)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(31)——神经风格迁移
PyTorch深度学习实战(32)——Deepfakes
PyTorch深度学习实战(33)——生成对抗网络(Generative Adversarial Network, GAN)
PyTorch深度学习实战(34)——DCGAN详解与实现
PyTorch深度学习实战(35)——条件生成对抗网络(Conditional Generative Adversarial Network, CGAN)
PyTorch深度学习实战(36)——Pix2Pix详解与实现

这篇关于PyTorch深度学习实战(37)——CycleGAN详解与实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/735227

相关文章

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python脚本实现自动删除C盘临时文件夹

《Python脚本实现自动删除C盘临时文件夹》在日常使用电脑的过程中,临时文件夹往往会积累大量的无用数据,占用宝贵的磁盘空间,下面我们就来看看Python如何通过脚本实现自动删除C盘临时文件夹吧... 目录一、准备工作二、python脚本编写三、脚本解析四、运行脚本五、案例演示六、注意事项七、总结在日常使用

Java实现Excel与HTML互转

《Java实现Excel与HTML互转》Excel是一种电子表格格式,而HTM则是一种用于创建网页的标记语言,虽然两者在用途上存在差异,但有时我们需要将数据从一种格式转换为另一种格式,下面我们就来看看... Excel是一种电子表格格式,广泛用于数据处理和分析,而HTM则是一种用于创建网页的标记语言。虽然两

java图像识别工具类(ImageRecognitionUtils)使用实例详解

《java图像识别工具类(ImageRecognitionUtils)使用实例详解》:本文主要介绍如何在Java中使用OpenCV进行图像识别,包括图像加载、预处理、分类、人脸检测和特征提取等步骤... 目录前言1. 图像识别的背景与作用2. 设计目标3. 项目依赖4. 设计与实现 ImageRecogni

Java中Springboot集成Kafka实现消息发送和接收功能

《Java中Springboot集成Kafka实现消息发送和接收功能》Kafka是一个高吞吐量的分布式发布-订阅消息系统,主要用于处理大规模数据流,它由生产者、消费者、主题、分区和代理等组件构成,Ka... 目录一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者一、Kaf

Java访问修饰符public、private、protected及默认访问权限详解

《Java访问修饰符public、private、protected及默认访问权限详解》:本文主要介绍Java访问修饰符public、private、protected及默认访问权限的相关资料,每... 目录前言1. public 访问修饰符特点:示例:适用场景:2. private 访问修饰符特点:示例:

python管理工具之conda安装部署及使用详解

《python管理工具之conda安装部署及使用详解》这篇文章详细介绍了如何安装和使用conda来管理Python环境,它涵盖了从安装部署、镜像源配置到具体的conda使用方法,包括创建、激活、安装包... 目录pytpshheraerUhon管理工具:conda部署+使用一、安装部署1、 下载2、 安装3

详解Java如何向http/https接口发出请求

《详解Java如何向http/https接口发出请求》这篇文章主要为大家详细介绍了Java如何实现向http/https接口发出请求,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 用Java发送web请求所用到的包都在java.net下,在具体使用时可以用如下代码,你可以把它封装成一

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超

windos server2022里的DFS配置的实现

《windosserver2022里的DFS配置的实现》DFS是WindowsServer操作系统提供的一种功能,用于在多台服务器上集中管理共享文件夹和文件的分布式存储解决方案,本文就来介绍一下wi... 目录什么是DFS?优势:应用场景:DFS配置步骤什么是DFS?DFS指的是分布式文件系统(Distr