CycleGan总结及简易复现

2024-01-28 07:08
文章标签 总结 复现 简易 cyclegan

本文主要是介绍CycleGan总结及简易复现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

CycleGan总结及代码简易复现

    • 简介
    • 拓展: 回归损失函数的对比:L1 loss, L2 loss(MSE)以及Smooth L1 Loss的对比
    • CycleGan网络结构

CycleGan论文地址: https://arxiv.org/abs/1703.10593

简介

2017年以前的GAN都是通过配对好的一组图片去训练模型的,但是想要获得大量的成对图片比较难,而CycleGan是无监督生成对抗网络,其实是做的是一个domain adaption的工作,可以通过一些不配对的两组图片利用生成器-判别器模型和计算它的循环损失实现领域的自适应。即把原始图像(如马)导入生成器G1(马→斑马)生成目标图像(斑马),再把目标图像当作F(斑马→马)的输入,计算生成新的图像(马~)与最初的原始图像(马)的差别,即损失,让该损失尽可能地小即能确保生成器不会生成与原始图像无关的图片。如下图所示:
在这里插入图片描述
所以总的损失函数就是L = 两个生成器的损失(G1_loss + F_loss)+两个循环损失(cycle1_loss + cycle2_loss)+ 两个identity损失(即往G1输入斑马的图片,计算生成后的斑马图片与输入的真实斑马图片的差距,同理往F输入马的图片,且此项有时可以省去来提高计算效率)
生成器的损失用MSE,循环损失与identity损失用L1函数。

拓展: 回归损失函数的对比:L1 loss, L2 loss(MSE)以及Smooth L1 Loss的对比

L1 loss函数:指的是模型预测值f(x)和真实值y之间距离的均值,公式为:在这里插入图片描述
图像:在这里插入图片描述
由图像可知:
①当损失很小时,其梯度比较大,不利于模型的训练和收敛
②无论对于什么样的输入值,其梯度都是固定的,所以不会产生梯度爆炸的问题,也就是对偏离真实样本的比较大的值不怎么敏感,有利于模型的稳定。
③在y-f(x)= 0 处不可导,可能影响收敛

L2 loss函数:模型预测值f(x) 与真实样本值y 之间差值平方的均值。
公式:在这里插入图片描述
图像:在这里插入图片描述
由图可知:
①函数在所有输入范围内都是连续的
②随着损失的减小,梯度也在减小,这有利于模型的快速收敛
③对离群点比较敏感,受其影响比较大

Smooth L1 loss函数:
在Faster-Rcnn和SSD中都用到了该函数。
公式:

x为真实值与预测值的差值

图像:在这里插入图片描述
可以看出Smooth loss函数为前两者的结合,取其精华去其糟粕。

Smooth L1的优点;
①相比于L1损失函数,可以收敛得更快。
②相比于L2损失函数,对离群点、异常值不敏感,梯度变化相对更小,训练时不容易跑飞。

CycleGan网络结构

在这里插入图片描述
生成器的网络可简化为:

一个卷积块
两个下采样块
九个残差模块
2个上采样模块
一个卷积块(output_channel = 3)
经过tanh模块(将特征图的值归为-1至1之间)

代码如下:

import torch
import torch.nn as nnclass ConvBlock(nn.Module):def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)if downelse nn.ConvTranspose2d(in_channels, out_channels, **kwargs),nn.InstanceNorm2d(out_channels),nn.ReLU(inplace=True) if use_act else nn.Identity())def forward(self, x):return self.conv(x)class ResidualBlock(nn.Module):def __init__(self, channels):super().__init__()self.block = nn.Sequential(ConvBlock(channels, channels, kernel_size=3, padding=1),ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),)def forward(self, x):return x + self.block(x)class Generator(nn.Module):def __init__(self, img_channels, num_features = 64, num_residuals=9):super().__init__()self.initial = nn.Sequential(nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),nn.InstanceNorm2d(num_features),nn.ReLU(inplace=True),)self.down_blocks = nn.ModuleList([ConvBlock(num_features, num_features*2, kernel_size=3, stride=2, padding=1),ConvBlock(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),])self.res_blocks = nn.Sequential(*[ResidualBlock(num_features*4) for _ in range(num_residuals)])self.up_blocks = nn.ModuleList([ConvBlock(num_features*4, num_features*2, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),ConvBlock(num_features*2, num_features*1, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),])self.last = nn.Conv2d(num_features*1, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")def forward(self, x):x = self.initial(x)for layer in self.down_blocks:x = layer(x)x = self.res_blocks(x)for layer in self.up_blocks:x = layer(x)return torch.tanh(self.last(x))

判别器的网络
同理可得:总共5层卷积层,目标是生成特征图里面的值为0-1之间,方便待会跟生成器网络生成的图进行损失计算。代码如下:

import torch
import torch.nn as nnclass Block(nn.Module):def __init__(self, in_channels, out_channels, stride):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode="reflect"),nn.InstanceNorm2d(out_channels),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x):return self.conv(x)class Discriminator(nn.Module):def __init__(self, in_channels=3, features=[64, 128, 256, 512]):super().__init__()self.initial = nn.Sequential(nn.Conv2d(in_channels,features[0],kernel_size=4,stride=2,padding=1,padding_mode="reflect",),nn.LeakyReLU(0.2, inplace=True),)layers = []in_channels = features[0]for feature in features[1:]:layers.append(Block(in_channels, feature, stride=1 if feature==features[-1] else 2))in_channels = featurelayers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))self.model = nn.Sequential(*layers)def forward(self, x):x = self.initial(x)return torch.sigmoid(self.model(x))

训练模块和载入数据集的模块可以仿照原论文进行编写。

这篇关于CycleGan总结及简易复现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/652863

相关文章

C# List.Sort四种重载总结

《C#List.Sort四种重载总结》本文详细分析了C#中List.Sort()方法的四种重载形式及其实现原理,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友... 目录1. Sort方法的四种重载2. 具体使用- List.Sort();- IComparable

SpringBoot项目整合Netty启动失败的常见错误总结

《SpringBoot项目整合Netty启动失败的常见错误总结》本文总结了SpringBoot集成Netty时常见的8类问题及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一、端口冲突问题1. Tomcat与Netty端口冲突二、主线程被阻塞问题1. Netty启动阻

SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)

《SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)》本文总结了SpringBoot项目整合Kafka启动失败的常见错误,包括Kafka服务器连接问题、序列化配置错误、依赖配置问题、... 目录一、Kafka服务器连接问题1. Kafka服务器无法连接2. 开发环境与生产环境网络不通二、序

python3中正则表达式处理函数用法总结

《python3中正则表达式处理函数用法总结》Python中的正则表达式是一个强大的文本处理工具,用于匹配、查找、替换等操作,在Python中正则表达式的操作主要通过内置的re模块来实现,这篇文章主要... 目录前言re.match函数re.search方法re.match 与 re.search的区别检索

C++实现一个简易线程池的使用小结

《C++实现一个简易线程池的使用小结》在现代软件开发中,多线程编程已经成为提升程序性能的常见手段,本文主要介绍了C++实现一个简易线程池的使用小结,感兴趣的可以了解一下... 在现代软件开发中,多线程编程已经成为提升程序性能的常见手段。无论是处理大量 I/O 请求的服务器,还是进行 CPU 密集型计算的应用

Python版本与package版本兼容性检查方法总结

《Python版本与package版本兼容性检查方法总结》:本文主要介绍Python版本与package版本兼容性检查方法的相关资料,文中提供四种检查方法,分别是pip查询、conda管理、PyP... 目录引言为什么会出现兼容性问题方法一:用 pip 官方命令查询可用版本方法二:conda 管理包环境方法

pycharm跑python项目易出错的问题总结

《pycharm跑python项目易出错的问题总结》:本文主要介绍pycharm跑python项目易出错问题的相关资料,当你在PyCharm中运行Python程序时遇到报错,可以按照以下步骤进行排... 1. 一定不要在pycharm终端里面创建环境安装别人的项目子模块等,有可能出现的问题就是你不报错都安装

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

Spring 依赖注入与循环依赖总结

《Spring依赖注入与循环依赖总结》这篇文章给大家介绍Spring依赖注入与循环依赖总结篇,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. Spring 三级缓存解决循环依赖1. 创建UserService原始对象2. 将原始对象包装成工

MySQL中查询和展示LONGBLOB类型数据的技巧总结

《MySQL中查询和展示LONGBLOB类型数据的技巧总结》在MySQL中LONGBLOB是一种二进制大对象(BLOB)数据类型,用于存储大量的二进制数据,:本文主要介绍MySQL中查询和展示LO... 目录前言1. 查询 LONGBLOB 数据的大小2. 查询并展示 LONGBLOB 数据2.1 转换为十