PyTorch深度学习实战(37)——CycleGAN详解与实现

2024-02-22 12:12

本文主要是介绍PyTorch深度学习实战(37)——CycleGAN详解与实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

PyTorch深度学习实战(37)——CycleGAN详解与实现

    • 0. 前言
    • 1. CycleGAN 基本原理
    • 2. CycleGAN 模型分析
    • 3. 实现 CycleGAN
    • 小结
    • 系列链接

0. 前言

CycleGAN 是一种用于图像转换的生成对抗网络(Generative Adversarial Network, GAN),可以在不需要配对数据的情况下将一种风格的图像转换成另一种风格,而无需为每一对输入-输出图像配对训练数据。CycleGAN 的核心思想是利用两个生成器和两个判别器,它们共同学习两个域之间的映射关系。例如,将马的图像转换成斑马的图像,或者将夏天的风景转换成冬天的风景。在本节中,我们将学习 CycleGAN 的基本原理,并实现该模型用于将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像。

1. CycleGAN 基本原理

CycleGAN 是一种无需配对的图像转换技术,它可以将一个图像域中的图像转换为另一个图像域中的图像,而不需要匹配这两个域中的图像。它使用两个生成器和两个判别器,其中一个生成器将一个域中的图像转换为另一个域中的图像,而第二个生成器将其转换回来。这个过程被称为循环一致性,转换过程是可逆的。
CycleGAN 可以用于执行从一个类别到另一个类别的图像转换,而无需提供相匹配的输入-输出图像对来训练模型,只需要在两个不同的文件夹中提供这两个类别的图像。在本节中,我们将学习如何训练 CycleGAN 将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像,CycleGAN 中的 Cycle 是指将图像从一个类别转换到另一个类别,然后再转换回原始类别的过程。
CycleGAN 中,需要使用三种不同的损失值:

  • 鉴别器损失:用于区分真实图像和伪造图像
  • 循环一致性损失:由于 CycleGAN 使用了两个生成器,因此需要确保转换是可逆的,循环一致性损失通过将转换过的图像再次传递到原始的生成器中,并将生成的图像与原始图像进行比较来实现
  • 恒等损失 (Identity loss):确保生成器在不进行转换的情况下仍然能够生成与原始图像相似的图像,通过将原始图像传递到生成器中,并计算生成图像与原始图像之间的差异

2. CycleGAN 模型分析

CycleGAN 模型构建策略如下:

  1. 导入数据集并进行预处理
  2. 定义 UNet 架构用于构建生成器和判别器网络
  3. 定义两个生成器:
    • G_AB:将类别 A 图像转换为类别 B 图像的生成器
    • G_BA:将类别 B 图像转换为类别 A 图像的生成器
  4. 定义恒等损失:
    • 如果将一张橘子的图像输入到橙子生成器,理想情况下,如果生成器完全理解橙子的所有信息,它不应该改变图像,而应该“生成”完全相同的图像,据此,我们可以创建一个恒等变换
    • 当类别 A (real_A) 的图像通过 G_BA 并与 real_A 进行比较时,恒等损失应该是最小的
    • 当类别 B (real_B) 的图像通过 G_AB 并与 real_B 进行比较时,恒等损失应该是最小的
  5. 定义GAN损失:
    • real_Afake_A 的判别器和生成器损失(当 real_B 图像通过 G_BA 时得到 fake_A)
    • real_Bfake_B 的判别器和生成器损失(当 real_A 图像通过 G_AB 时得到 fake_B)
  6. 定义循环一致性损失:
    • 一张苹果图像需要通过橙子生成网络进行转换,生成伪造的橘子图像,然后再通过苹果生成网络将伪造的橙子图像转换回苹果图像
    • fake_Breal_A 通过 G_AB 时的输出,当 fake_B 通过 G_BA 时应该重新生成 real_A
    • fake_Areal_B 通过 G_BA 时的输出,当 fake_A 通过 G_AB 时应该重新生成 real_B
  7. 优化三个损失函数的加权和

3. 实现 CycleGAN

接下来,我们使用 PyTorch 实现 CycleGAN 模型,用以将苹果图像转换为橙子图像,或反之将橙子图像转换为苹果图像。

(1) 导入相关数据集和库。

首先下载并解压数据集,可以自行构建数据集,也可以下载本文所用数据集,下载地址:https://pan.baidu.com/s/1iTOt2NsUQ1a3taUHjvkjfA,提取码:iuqf

可视化示例图像如下:
数据集

与 Pix2Pix 训练数据集不同,苹果和橙色图像之间不存在一一对应的关系。

导入所需的库:

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
import cv2
import random
from glob import glob
from PIL import Image
import itertools
from torchvision import transforms
device = "cuda" if torch.cuda.is_available() else "cpu"

(2) 定义图像转换管道 transform

IMAGE_SIZE = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transform = transforms.Compose([transforms.Resize(int(IMAGE_SIZE*1.33)),transforms.RandomCrop((IMAGE_SIZE,IMAGE_SIZE)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

(3) 定义数据集类 CycleGANDataset,以苹果图像 apple 和橙子图像 orange 文件夹为输入,提供批数据:

class CycleGANDataset(Dataset):def __init__(self, apples, oranges):self.apples = glob(apples)self.oranges = glob(oranges)def __getitem__(self, ix):apple = self.apples[ix % len(self.apples)]orange = random.choice(self.oranges)apple = Image.open(apple).convert('RGB')orange = Image.open(orange).convert('RGB')return apple, orangedef __len__(self):return max(len(self.apples), len(self.oranges))def choose(self):return self[random.randint(len(self))]def collate_fn(self, batch):srcs, trgs = list(zip(*batch))srcs = torch.cat([transform(img)[None] for img in srcs], 0).to(device).float()trgs = torch.cat([transform(img)[None] for img in trgs], 0).to(device).float()return srcs.to(device), trgs.to(device)

(4) 定义训练、验证数据集和数据加载器:

trn_ds = CycleGANDataset('apples_oranges/apples_train/*.jpg', 'apples_oranges/oranges_train/*.jpg')
val_ds = CycleGANDataset('apples_oranges/apples_test/*.jpg', 'apples_oranges/oranges_test/*.jpg')trn_dl = DataLoader(trn_ds, batch_size=1, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, batch_size=5, shuffle=True, collate_fn=val_ds.collate_fn)

(5) 定义网络的权重初始化方法 weights_init_normal

def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)if hasattr(m, "bias") and m.bias is not None:torch.nn.init.constant_(m.bias.data, 0.0)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)

(6) 定义残差块 ResidualBlock

class ResidualBlock(nn.Module):def __init__(self, in_features):super(ResidualBlock, self).__init__()self.block = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),nn.ReLU(inplace=True),nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),)def forward(self, x):return x + self.block(x)

(7) 定义生成器 GeneratorResNet

class GeneratorResNet(nn.Module):def __init__(self, num_residual_blocks=9):super(GeneratorResNet, self).__init__()out_features = 64channels = 3model = [nn.ReflectionPad2d(3),nn.Conv2d(channels, out_features, 7),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Downsamplingfor _ in range(2):out_features *= 2model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Residual blocksfor _ in range(num_residual_blocks):model += [ResidualBlock(out_features)]# Upsamplingfor _ in range(2):out_features //= 2model += [nn.Upsample(scale_factor=2),nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),nn.InstanceNorm2d(out_features),nn.ReLU(inplace=True),]in_features = out_features# Output layermodel += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]self.model = nn.Sequential(*model)self.apply(weights_init_normal)def forward(self, x):return self.model(x)

(8) 定义判别器 Discriminator

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()channels, height, width = 3, IMAGE_SIZE, IMAGE_SIZEdef discriminator_block(in_filters, out_filters, normalize=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]if normalize:layers.append(nn.InstanceNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(channels, 64, normalize=False),*discriminator_block(64, 128),*discriminator_block(128, 256),*discriminator_block(256, 512),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(512, 1, 4, padding=1))self.apply(weights_init_normal)def forward(self, img):return self.model(img)

(9) 定义图像样本生成函数 generate_sample

@torch.no_grad()
def generate_sample(G_AB, G_BA):data = next(iter(val_dl))G_AB.eval()G_BA.eval()real_A, real_B = datafake_B = G_AB(real_A)fake_A = G_BA(real_B)# Arange images along x-axisreal_A = make_grid(real_A, nrow=5, normalize=True)real_B = make_grid(real_B, nrow=5, normalize=True)fake_A = make_grid(fake_A, nrow=5, normalize=True)fake_B = make_grid(fake_B, nrow=5, normalize=True)# Arange images along y-axisimage_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)plt.imshow(image_grid.detach().cpu().permute(1,2,0).numpy())plt.show()

(10) 定义生成器训练函数 generator_train_step

该函数将两个生成器( G_ABG_BA)、优化器和两个类别的真实图像( real_Areal_B )作为输入:

def generator_train_step(Gs, optimizer, real_A, real_B, D_A, D_B, criterion_identity, criterion_cycle, criterion_GAN, lambda_cyc, lambda_id):

指定生成器:

    G_AB, G_BA = Gs

将优化器的梯度设置为零:

    optimizer.zero_grad()

计算类别 A (苹果)和类别 B (橙子)图像的恒等损失 (loss_identity):

    loss_id_A = criterion_identity(G_BA(real_A), real_A)loss_id_B = criterion_identity(G_AB(real_B), real_B)loss_identity = (loss_id_A + loss_id_B) / 2

计算图像通过生成器时的 GAN 损失,此时生成的图像应尽可能接近另一类别,使用 np.ones 作为训练生成器的判别网络目标输出,因为我们将生成的伪造图像传递给相同类别的判别器:

    fake_B = G_AB(real_A)loss_GAN_AB = criterion_GAN(D_B(fake_B), torch.Tensor(np.ones((len(real_A), 1, 16, 16))).to(device))fake_A = G_BA(real_B)loss_GAN_BA = criterion_GAN(D_A(fake_A), torch.Tensor(np.ones((len(real_A), 1, 16, 16))).to(device))loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

计算循环一致性损失。假设,一张苹果图像被橙子生成器转换为一张伪造橙子图像,然后伪造橙子图像通过苹果生成器转换回一张苹果图像,理想情况下,经过该过程后应该返回原始图像,即循环一致性损失应该为零:

    recov_A = G_BA(fake_B)loss_cycle_A = criterion_cycle(recov_A, real_A)recov_B = G_AB(fake_A)loss_cycle_B = criterion_cycle(recov_B, real_B)loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

计算总损失并执行反向传播:

    loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identityloss_G.backward()optimizer.step()return loss_G, loss_identity, loss_GAN, loss_cycle, loss_G, fake_A, fake_B

(11) 定义判别器训练函数 discriminator_train_step

def discriminator_train_step(D, real_data, fake_data, optimizer, criterion_GAN):optimizer.zero_grad()loss_real = criterion_GAN(D(real_data), torch.Tensor(np.ones((len(real_data), 1, 16, 16))).to(device))loss_fake = criterion_GAN(D(fake_data.detach()), torch.Tensor(np.zeros((len(real_data), 1, 16, 16))).to(device))loss_D = (loss_real + loss_fake) / 2loss_D.backward()optimizer.step()return loss_D

(12) 定义生成器、判别器对象、优化器和损失函数:

G_AB = GeneratorResNet().to(device)
G_BA = GeneratorResNet().to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))lambda_cyc, lambda_id = 10.0, 5.0

(13) 训练网络:

n_epochs = 50
# log = Report(n_epochs)
loss_D_epochs = []
loss_G_epochs = []
loss_GAN_epochs = []
loss_cycle_epochs = []
loss_identity_epochs = []
for epoch in range(n_epochs):N = len(trn_dl)loss_D_items = []loss_G_items = []loss_GAN_items = []loss_cycle_items = []loss_identity_items = []for bx, batch in enumerate(trn_dl):real_A, real_B = batchloss_G, loss_identity, loss_GAN, loss_cycle, loss_G, fake_A, fake_B = generator_train_step((G_AB,G_BA), optimizer_G, real_A, real_B, D_A, D_B, criterion_identity, criterion_cycle, criterion_GAN, lambda_cyc, lambda_id)loss_D_A = discriminator_train_step(D_A, real_A, fake_A, optimizer_D_A, criterion_GAN)loss_D_B = discriminator_train_step(D_B, real_B, fake_B, optimizer_D_B, criterion_GAN)loss_D = (loss_D_A + loss_D_B) / 2loss_D_items.append(loss_D.item())loss_G_items.append(loss_G.item())loss_GAN_items.append(loss_GAN.item())loss_cycle_items.append(loss_cycle.item())loss_identity_items.append(loss_identity.item())loss_D_epochs.append(np.average(loss_D_items))loss_G_epochs.append(np.average(loss_G_items))loss_GAN_epochs.append(np.average(loss_GAN_items))loss_cycle_epochs.append(np.average(loss_cycle_items))loss_identity_epochs.append(np.average(loss_cycle_items))

(14) 训练模型后,测试模型生成图像:

generate_sample(G_AB, G_BA)

模型生成结果
从上图可以看出,CycleGAN 可以成功地将苹果转换为橙子(前两行),将橙子转换为苹果(后两行)。

小结

CycleGAN 是一种用于无监督图像转换的深度学习模型,它通过两个生成器和两个判别器的组合来学习两个不同域之间的映射关系。生成器负责将一个域的图像转换成另一个域的图像,而判别器则用于区分生成的图像和真实的图像。CycleGAN 引入循环一致性损失,确保图像转换是可逆的,从而提高生成图像的质量。通过对抗训练和循环一致性损失,CycleGAN 可以实现在没有配对标签的情况下进行图像域转换。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——从零开始实现SSD目标检测
PyTorch深度学习实战(24)——使用U-Net架构进行图像分割
PyTorch深度学习实战(25)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(26)——多对象实例分割
PyTorch深度学习实战(27)——自编码器(Autoencoder)
PyTorch深度学习实战(28)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(29)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(30)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(31)——神经风格迁移
PyTorch深度学习实战(32)——Deepfakes
PyTorch深度学习实战(33)——生成对抗网络(Generative Adversarial Network, GAN)
PyTorch深度学习实战(34)——DCGAN详解与实现
PyTorch深度学习实战(35)——条件生成对抗网络(Conditional Generative Adversarial Network, CGAN)
PyTorch深度学习实战(36)——Pix2Pix详解与实现

这篇关于PyTorch深度学习实战(37)——CycleGAN详解与实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/735227

相关文章

网页解析 lxml 库--实战

lxml库使用流程 lxml 是 Python 的第三方解析库,完全使用 Python 语言编写,它对 XPath表达式提供了良好的支 持,因此能够了高效地解析 HTML/XML 文档。本节讲解如何通过 lxml 库解析 HTML 文档。 pip install lxml lxm| 库提供了一个 etree 模块,该模块专门用来解析 HTML/XML 文档,下面来介绍一下 lxml 库

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

性能分析之MySQL索引实战案例

文章目录 一、前言二、准备三、MySQL索引优化四、MySQL 索引知识回顾五、总结 一、前言 在上一讲性能工具之 JProfiler 简单登录案例分析实战中已经发现SQL没有建立索引问题,本文将一起从代码层去分析为什么没有建立索引? 开源ERP项目地址:https://gitee.com/jishenghua/JSH_ERP 二、准备 打开IDEA找到登录请求资源路径位置

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象