基于CycleGAN的图像风格转换

2024-06-08 08:36
文章标签 图像 转换 风格 cyclegan

本文主要是介绍基于CycleGAN的图像风格转换,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

基于CycleGAN的图像风格转换

  • 1.导入所需要的包和库:
  • 2.将一个Tensor转换为图像:
  • 3.数据加载:
  • 4.图像变换:
  • 5.加载和预处理训练数据:
  • 6.定义了一个残差块:
  • 7.生成器:
  • 8.判断器:
  • 9.数据缓存器:
  • 10.执行生成器的训练步骤:
  • 11.训练判别器:
  • 12.损失打印,存储伪造图片:

1.导入所需要的包和库:

from random import randint
import numpy as np 
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools

2.将一个Tensor转换为图像:

def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)  out = out.view(-1, 3, 256, 256)  return out

3.数据加载:

data_path = os.path.abspath('D:/XUNLJ/data')
image_size = 256
batch_size = 1

4.图像变换:

  • 首先,图像会被调整到略大于原始大小,然后随机裁剪回原始大小,接着进行水平翻转,转换为张量格式,最后进行标准化处理
transform = transforms.Compose([transforms.Resize(int(image_size * 1.12), Image.BICUBIC), transforms.RandomCrop(image_size), transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

5.加载和预处理训练数据:

  • 文件夹中随机选择一批A类和B类图像,应用预定义的图像变换,并将它们转换为适合神经网络输入的张量格式
def _get_train_data(batch_size=1):train_a_filepath = data_path + '\\trainA\\'train_b_filepath = data_path + '\\trainB\\'train_a_list = os.listdir(train_a_filepath)train_b_list = os.listdir(train_b_filepath)train_a_result = []train_b_result = [] numlist = random.sample(range(0, len(train_a_list)), batch_size)for i in numlist:a_filename = train_a_list[i]a_img = Image.open(train_a_filepath + a_filename).convert('RGB')res_a_img = transform(a_img)train_a_result.append(torch.unsqueeze(res_a_img, 0))b_filename = train_b_list[i]b_img = Image.open(train_b_filepath + b_filename).convert('RGB')res_b_img = transform(b_img)train_b_result.append(torch.unsqueeze(res_b_img, 0))return torch.cat(train_a_result, dim=0), torch.cat(train_b_result, dim=0)

6.定义了一个残差块:

  • 定义了一个简单的残差块,它包含两个卷积层和实例归一化,以及ReLU激活函数
class ResidualBlock(nn.Module):def __init__(self, in_features):super(ResidualBlock, self).__init__()self.block_layer = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),nn.ReLU(inplace=True),nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features))def forward(self, x):return x + self.block_layer(x)

7.生成器:

  • 网络包含卷积层、下采样层、残差块和上采样层,用于将噪声输入转换为高质量的图像输出
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()model = [nn.ReflectionPad2d(3), nn.Conv2d(3, 64, 7), nn.InstanceNorm2d(64), nn.ReLU(inplace=True)]in_features = 64out_features = in_features * 2for _ in range(2):model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True)]in_features = out_featuresout_features = in_features*2for _ in range(9):model += [ResidualBlock(in_features)]out_features = in_features // 2for _ in range(2):model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True)]in_features = out_featuresout_features = in_features // 2model += [nn.ReflectionPad2d(3), nn.Conv2d(64, 3, 7), nn.Tanh()]self.gen = nn.Sequential( * model)def forward(self, x):x = self.gen(x)return x 

8.判断器:

  • 用于判断输入图像的真实性,含卷积层和LeakyReLU激活函数,用于从输入图像中提取特征,通过平均池化和重塑来生成一个与图像真实性相关的分数
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.dis = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.InstanceNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, 4, 2, 1, bias=False),nn.InstanceNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, 4, padding=1),nn.InstanceNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, 4, padding=1))        def forward(self, x):x = self.dis(x)return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

9.数据缓存器:

  • 用于存储和复用生成器生成的图像
class ReplayBuffer():def __init__(self, max_size=50):self.max_size = max_sizeself.data = []
  • 将新的数据推入缓存,并弹出旧的数据;如果缓存未满,则将数据推入缓存。如果缓存已满,则随机替换缓存中的一个数据。
   def push_and_pop(self, data):to_return = []for element in data.data:element = torch.unsqueeze(element, 0)if len(self.data) < self.max_size:self.data.append(element)to_return.append(element)else:if random.uniform(0,1) > 0.5:i = random.randint(0, self.max_size-1)to_return.append(self.data[i].clone())self.data[i] = elementelse:to_return.append(element)return Variable(torch.cat(to_return))
  • 实例化ReplayBuffer类,分别用于存储生成的A类和B类图像
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
  • 定义生成器网络,用于从A类图像生成B类图像
netG_A2B = Generator()
netG_B2A = Generator()
  • 定义判别器网络,用于判断A类和B类图像的真实性
netD_A = Discriminator()
netD_B = Discriminator()
  • 定义GAN损失函数和循环一致性损失函数
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
  • 定义身份损失函数
criterion_identity = torch.nn.L1Loss()
  • 定义优化器的参数
d_learning_rate = 3e-4  # 3e-4
  • 定义生成器和判别器的学习器
g_learning_rate = 3e-4
optim_betas = (0.5, 0.999)g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 
lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)
  • 定义训练的轮数
num_epochs = 1000 

10.执行生成器的训练步骤:

  • 计算多个损失函数的值,综合考虑了图像的身份、对抗和循环一致性,来生成更真实的图像
same_B = netG_A2B(real_b).float()loss_identity_B = criterion_identity(same_B, real_b) * 5.0   same_A = netG_B2A(real_a).float()loss_identity_A = criterion_identity(same_A, real_a) * 5.0fake_B = netG_A2B(real_a).float()pred_fake = netD_B(fake_B).float()loss_GAN_A2B = criterion_GAN(pred_fake, target_real)fake_A = netG_B2A(real_b).float()pred_fake = netD_A(fake_A).float()loss_GAN_B2A = criterion_GAN(pred_fake, target_real)recovered_A = netG_B2A(fake_B).float()loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0recovered_B = netG_A2B(fake_A).float()loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0  loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)loss_G.backward()    g_optimizer.step()

11.训练判别器:

  • 训练判别器A:通过计算真实图像和生成图像的对抗损失,来训练判别器以更准确地进行区分
da_optimizer.zero_grad()pred_real = netD_A(real_a).float()loss_D_real = criterion_GAN(pred_real, target_real)fake_A = fake_A_buffer.push_and_pop(fake_A)pred_fake = netD_A(fake_A.detach()).float()loss_D_fake = criterion_GAN(pred_fake, target_fake)loss_D_A = (loss_D_real + loss_D_fake) * 0.5loss_D_A.backward()da_optimizer.step()

训练判别器B:

db_optimizer.zero_grad()pred_real = netD_B(real_b)loss_D_real = criterion_GAN(pred_real, target_real)fake_B = fake_B_buffer.push_and_pop(fake_B)pred_fake = netD_B(fake_B.detach())loss_D_fake = criterion_GAN(pred_fake, target_fake)loss_D_B = (loss_D_real + loss_D_fake) * 0.5loss_D_B.backward()db_optimizer.step()

12.损失打印,存储伪造图片:

print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}'.format(epoch, loss_G.data.item(), loss_D_A.data.item(), loss_D_B.data.item()))if (epoch + 1) % 20 == 0 or epoch == 0:  b_fake = to_img(fake_B.data)a_fake = to_img(fake_A.data)a_real = to_img(real_a.data)b_real = to_img(real_b.data)save_image(a_fake, '../tmp/a_fake.png') save_image(b_fake, '../tmp/b_fake.png') save_image(a_real, '../tmp/a_real.png') save_image(b_real, '../tmp/b_real.png') 

这篇关于基于CycleGAN的图像风格转换的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1041679

相关文章

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

PDF 软件如何帮助您编辑、转换和保护文件。

如何找到最好的 PDF 编辑器。 无论您是在为您的企业寻找更高效的 PDF 解决方案,还是尝试组织和编辑主文档,PDF 编辑器都可以在一个地方提供您需要的所有工具。市面上有很多 PDF 编辑器 — 在决定哪个最适合您时,请考虑这些因素。 1. 确定您的 PDF 文档软件需求。 不同的 PDF 文档软件程序可以具有不同的功能,因此在决定哪个是最适合您的 PDF 软件之前,请花点时间评估您的

C# double[] 和Matlab数组MWArray[]转换

C# double[] 转换成MWArray[], 直接赋值就行             MWNumericArray[] ma = new MWNumericArray[4];             double[] dT = new double[] { 0 };             double[] dT1 = new double[] { 0,2 };

Verybot之OpenCV应用一:安装与图像采集测试

在Verybot上安装OpenCV是很简单的,只需要执行:         sudo apt-get update         sudo apt-get install libopencv-dev         sudo apt-get install python-opencv         下面就对安装好的OpenCV进行一下测试,编写一个通过USB摄像头采

【python计算机视觉编程——7.图像搜索】

python计算机视觉编程——7.图像搜索 7.图像搜索7.1 基于内容的图像检索(CBIR)从文本挖掘中获取灵感——矢量空间模型(BOW表示模型)7.2 视觉单词**思想****特征提取**: 创建词汇7.3 图像索引7.3.1 建立数据库7.3.2 添加图像 7.4 在数据库中搜索图像7.4.1 利用索引获取获选图像7.4.2 用一幅图像进行查询7.4.3 确定对比基准并绘制结果 7.

数据流与Bitmap之间相互转换

把获得的数据流转换成一副图片(Bitmap) 其原理就是把获得倒的数据流序列化到内存中,然后经过加工,在把数据从内存中反序列化出来就行了。 难点就是在如何实现加工。因为Bitmap有一个专有的格式,我们常称这个格式为数据头。加工的过程就是要把这个数据头与我们之前获得的数据流合并起来。(也就是要把这个头加入到我们之前获得的数据流的前面)      那么这个头是

【python计算机视觉编程——8.图像内容分类】

python计算机视觉编程——8.图像内容分类 8.图像内容分类8.1 K邻近分类法(KNN)8.1.1 一个简单的二维示例8.1.2 用稠密SIFT作为图像特征8.1.3 图像分类:手势识别 8.2贝叶斯分类器用PCA降维 8.3 支持向量机8.3.2 再论手势识别 8.4 光学字符识别8.4.2 选取特征8.4.3 多类支持向量机8.4.4 提取单元格并识别字符8.4.5 图像校正

在 Qt Creator 中,输入 /** 并按下Enter可以自动生成 Doxygen 风格的注释

在 Qt Creator 中,当你输入 /** 时,确实会自动补全标准的 Doxygen 风格注释。这是因为 Qt Creator 支持 Doxygen 以及类似的文档注释风格,并且提供了代码自动补全功能。 以下是如何在 Qt Creator 中使用和显示这些注释标记的步骤: 1. 自动补全 Doxygen 风格注释 在 Qt Creator 中,你可以这样操作: 在你的代码中,将光标放在

一个图形引擎的画面风格是由那些因素(技术)决定的?

可能很多人第一直覺會認為shader決定了視覺風格,但我認為可以從多個方面去考慮。 1. 幾何模型 一個畫面由多個成分組成,最基本的應該是其結構,在圖形學中通常稱為幾何模型。 一些引擎,如Quake/UE,有比較強的Brush建模功能(或應稱作CSG),製作建築比較方便。而CE則有較強的大型地表、植被、水體等功能,做室外自然環境十分出色。而另一些遊戲類型專用的引擎,例

高斯平面直角坐标讲解,以及地理坐标转换高斯平面直角坐标

高斯平面直角坐标系(Gauss-Krüger 坐标系)是基于 高斯-克吕格投影 的一种常见的平面坐标系统,主要用于地理信息系统 (GIS)、测绘和工程等领域。该坐标系将地球表面的经纬度(地理坐标)通过一种投影方式转换为平面直角坐标,以便在二维平面中进行距离、面积和角度的计算。 一 投影原理 高斯平面直角坐标系使用的是 高斯-克吕格投影(Gauss-Krüger Projection),这是 横