CycleGAN 论文阅读及代码实现

2023-12-01 21:20

本文主要是介绍CycleGAN 论文阅读及代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

介绍

CycleGAN是2018年发表于ICCV17的一篇论文,可以让2个图片相互转化,也就是风格迁移,如马变为斑马,斑马变为马。
在这里插入图片描述

网络结构

在这里插入图片描述CycleGAN总结构有4个网络,第一个为生成网络G:X—>Y;第二个网络为生成网络F:X—>Y。第三个网络为对抗网络命名为Dx,鉴别输入图像是否为X;第四个网络为对抗网络命名为Dy,鉴别输入图像是不是Y。如图,以马(X)和斑马(Y)为例,G网络将马的图像转化为斑马图像;F网络将斑马的图像转化为马的图像;Dx网络鉴别输入的图像是不是马;Dy网络鉴别输入图像是不是斑马。这4个网络仅有2个网络结构,即G和F都是生成网络,这两者的网络结构相同,Dx和Dy都是对抗性网络,这两者的网络结构相同。

Generator-生成网络

在这里插入图片描述以上网络主要有3种操作,卷积,反卷积和残差模块;卷积和反卷积后通常还有BN,激活函数等。

卷积

在这里插入图片描述

反卷积

在这里插入图片描述

残差模块

在这里插入图片描述残差网络最先是在ResNet中引出的可以有效的避免梯度消失,实现网络深度的提升。

Discriminator-对抗网络

在这里插入图片描述卷积后面通道都有BN层和激活函数,另外Discriminator的最终输出并不是0.0-1。0间的值,而是一个1616的矩阵,因此定义了这个1616矩阵的各个元素越接近0.9,则Loss越小,即是真值的概率越大。

Loss

G_loss

网络G的loss函数,由2部分组成,分别是cycle_loss和g_loss。

  • cycle_loss:G(x)生成了y’,F(G(x))即是生成的x’,则F(G(x))-x的绝对值的均值定为loss_x;F(y)生成x’,G(F(y))生成y’,则G(F(y))-y的绝对值的均值定为loss_y;cycle_loss=loss_x+loss_y;
    在这里插入图片描述- g_loss:G(x)是y’,则Dy(y’)每个元素减去0.9取平方,然后取平方均值定义为g_loss;
    在这里插入图片描述G_loss=cycle_loss+g_loss;

Dy_loss

网络Dy的loss函数,由2部分组成,分别是loss_real_y和loss_fake_y;

  • loss_real_y:Dy(y)是一个16x16矩阵,每个元素减去0.9后取平方,则各平方均值定义为loss_real_y;
  • loss_fake_y:G(x)生成一个y’,则Dy(G(x))相当于Dy(y’),也是一个16x16矩阵,矩阵每个元素取平方,则各平方均值定义为loss_fake_y;

在这里插入图片描述Dy_loss=loss_real_y+loss_fake_y;

F_loss

网络F的Loss函数,由2部分组成,分别是cycle_loss和f_loss;

  • cycle_loss同G_loss中的定义
  • f_loss:F(y)生成一个x’,Dx(x’)即Dx(F(y))是一个16x16的矩阵,每个元素减去0.9后取平方,各平方均值定义为f_loss;
    在这里插入图片描述F_loss=cycle_Loss+f_loss;

Dx_loss

网络Dx的loss函数,由两部分组成,分别是loss_real_x和loss_fake_x;

  • loss_real_x:Dx(x)是一个16x16矩阵,每个元素减去0.9后取平方,则各平方均值定义为loss_real_x;
  • loss_fake_x:F(y)生成一个x’,则Dx(F(y))相当于Dx(x’),也是一个16x16矩阵,矩阵每个元素取平方,则各平方均值定义为loss_fake_x;
    在这里插入图片描述Dx_loss= loss_real_x+ loss_fake_x;

训练

最小化[G_loss,Dy_loss,F_loss,Dx_loss]变量,实现网络优化训练。

代码实现

这里使用的损失函数和上面不是太一样,具体可以看这个工程:https://github.com/hardikbansal/CycleGAN 和这个博客:https://hardikbansal.github.io/CycleGANBlog/ 通过修改to_train和to_test参数控制训练和测试即可。

#coding=utf-8
import tensorflow as tf
import numpy as np
from scipy.misc import imsave #将数组保存到图像中
import matplotlib.pyplot as plt
import os #文件夹操作
import time
import random#函数功能:实现leakyrelu
def lrelu(x, leak=0.2, name = "lrelu"):with tf.variable_scope(name):return tf.maximum(x, leak*x)#函数功能:实现BN
def instance_norm(x):with tf.variable_scope("instance_norm"):epsilon = 1e-5mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)scale = tf.get_variable('scale', [x.get_shape()[-1]], initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))offset = tf.get_variable('offset', [x.get_shape()[-1]], initializer=tf.constant_initializer(0.0))out = scale*tf.div(x-mean, tf.sqrt(var + epsilon)) + offsetreturn out#函数功能:实现卷积
def general_conv2d(input, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="conv2d", do_norm=True, do_relu=True, relufactor=0):with tf.variable_scope(name):conv = tf.contrib.layers.conv2d(input, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))if do_norm:conv = instance_norm(conv)if do_relu:if relufactor == 0:conv = tf.nn.relu(conv, "relu")else:conv = lrelu(conv, relufactor, "lrelu")return conv#函数功能:实现反卷积
def general_deconv2d(input, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0):with tf.variable_scope(name):conv = tf.contrib.layers.conv2d_transpose(input, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))if do_norm:conv = instance_norm(conv)if do_relu:if relufactor == 0:conv = tf.nn.relu(conv, "relu")else:conv = lrelu(conv, relufactor, "lrelu")return conv#Building the generator->1.Encoder 2.Transformer 3.Decoderngf = 32 #生成器的第一层的filtes的个数
ndf = 64 #判别器的第一层的filtes的个数
batch_size = 1 #每次处理的图片个数
pool_size = 50 #保存最近的pool_size个图片,并随机用一张计算D_loss
img_width = 256
img_height = 256
img_depth = 3 #RGB
img_size = img_height * img_width
to_train = True
to_test = False
to_restore = True
output_path = "./output"
check_dir = "./output/checkpoints/"
max_epoch = 1000
max_images = 100
save_training_images = True#函数功能:构造残差模块
def build_resnet_block(input, dim, name="resnet"):with tf.variable_scope(name):out_res = tf.pad(input, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)return tf.nn.relu(out_res + input)#函数功能:构造包含6个参差模块作为转换器的生成网络
def build_generator_resnet_6blocks(input, name="generator"):with tf.variable_scope(name):f = 7ks = 3pad_input = tf.pad(input, [[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")o_c1 = general_conv2d(pad_input, ngf, ks, ks, 1, 1, 0.02, name="c1")o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02, "SAME", name="c2")o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02, "SAME", name="c3")o_r1 = build_resnet_block(o_c3, ngf*4, "r1")o_r2 = build_resnet_block(o_r1, ngf*4, "r2")o_r3 = build_resnet_block(o_r2, ngf*4, "r3")o_r4 = build_resnet_block(o_r3, ngf*4, "r4")o_r5 = build_resnet_block(o_r4, ngf*4, "r5")o_r6 = build_resnet_block(o_r5, ngf*4, "r6")o_c4 = general_deconv2d(o_r6, [batch_size, 64, 64, ngf*2], ngf*2, ks, ks, 2, 2, 0.02, "SAME", "c4")o_c5 = general_deconv2d(o_c4, [batch_size, 128, 128, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5")o_c5_pad = tf.pad(o_c5, [[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")o_c6 = general_conv2d(o_c5_pad, img_depth, f, f, 1, 1, 0.02, "VALID", "c6", do_relu=False)#Adding the tanh layerout_gen = tf.nn.tanh(o_c6, "t1")return out_gen#函数功能:构造包含6个参差模块作为转换器的生成网络
def build_generator_resnet_9blocks(input, name="generator"):with tf.variable_scope(name):f = 7ks = 3pad_input = tf.pad(input, [[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")o_c1 = general_conv2d(input, ngf, ks, ks, 1, 1, 0.02, name="c1")o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02, "SAME", name="c2")o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02, "SAME", name="c3")o_r1 = build_resnet_block(o_c3, ngf*4, "r1")o_r2 = build_resnet_block(o_r1, ngf*4, "r2")o_r3 = build_resnet_block(o_r2, ngf*4, "r3")o_r4 = build_resnet_block(o_r3, ngf*4, "r4")o_r5 = build_resnet_block(o_r4, ngf*4, "r5")o_r6 = build_resnet_block(o_r5, ngf*4, "r6")o_r7 = build_resnet_block(o_r6, ngf*4, "r7")o_r8 = build_resnet_block(o_r7, ngf*4, "r8")o_r9 = build_resnet_block(o_r8, ngf*4, "r9")o_c4 = general_deconv2d(o_r9, [batch_size, 128, 128, ngf*2], ngf*2, ks, ks, 2, 2, 0.02, "SAME", "c4")o_c5 = general_deconv2d(o_c4, [batch_size, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5")o_c6 = general_conv2d(o_c5, img_depth, f, f, 1, 1, 0.02, "SAME", "c6", do_relu=False)#Adding the tanh layerout_gen = tf.nn.tanh(o_c6, "t1")return out_gen#函数功能: 构造Discriminator_A->B
def build_gen_discriminator(input, name="discriminator"):with tf.variable_scope(name):f = 4o_c1 = general_conv2d(input, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2)o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2) #do_norm=Trueo_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)o_c4 = general_conv2d(o_c3, ndf*8, f, f, 1, 1, 0.02, "SAME", "c4", relufactor=0.2)o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False, do_relu=False)return o_c5#函数功能: 部分裁剪的Discriminator
def patch_discriminator(input, name="discriminator"):with tf.variable_scope(name):f = 4patch_input = tf.random_crop(input, [1,70,70,3])o_c1 = general_conv2d(patch_input, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2)o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2) #do_norm=Trueo_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)o_c4 = general_conv2d(o_c3, ndf*8, f, f, 1, 1, 0.02, "SAME", "c4", relufactor=0.2)o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False, do_relu=False)return o_c5class CycleGAN():def input_setup(self):'''函数功能能:为输入数据设置变量filenames_A/filenames_B -> takes the list of all training imagesself.images_A/self.images_B -> Input image with each values ranging from [-1,1]:return:'''#获取文件列表filenames_A = tf.train.match_filenames_once("zxy2lsx/trainA/*.jpg")print(filenames_A)self.queue_length_A = tf.size(filenames_A)print(self.queue_length_A)filenames_B = tf.train.match_filenames_once("zxy2lsx/trainB/*.jpg")print(filenames_B)self.queue_length_B = tf.size(filenames_B)print(self.queue_length_B)filename_queue_A = tf.train.string_input_producer(filenames_A) #输出字符串到一个输入管道队列filename_queue_B = tf.train.string_input_producer(filenames_B)image_reader = tf.WholeFileReader() #一个阅读器,读取整个文件,返回文件名称key,以及文件中所有的内容value_, image_file_A = image_reader.read(filename_queue_A)_, image_file_B = image_reader.read(filename_queue_B)# 将输入图像resize为[256, 256]# [N, C, W, H] 在第一个维度减去均值127.5self.image_A = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_A), [256, 256]), 127.5), 1)self.image_B = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_B), [256, 256]), 127.5), 1)def input_read(self, sess):'''函数功能:从图像文件夹中读取输入信息:param sess::return:'''#开启一个协调器coord = tf.train.Coordinator()#QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程threads = tf.train.start_queue_runners(coord=coord)num_files_A = sess.run(self.queue_length_A)num_files_B = sess.run(self.queue_length_B)self.fake_images_A = np.zeros((pool_size, 1, img_height, img_width, img_depth))self.fake_images_B = np.zeros((pool_size, 1, img_height, img_width, img_depth))self.A_input = np.zeros((max_images, batch_size, img_height, img_width, img_depth))self.B_input = np.zeros((max_images, batch_size, img_height, img_width, img_depth))for i in range(max_images):image_tensor = sess.run(self.image_A)if(image_tensor.size == img_size*batch_size*img_depth):self.A_input[i] = image_tensor.reshape((batch_size, img_height, img_width, img_depth))for i in range(max_images):image_tensor = sess.run(self.image_B)if(image_tensor.size == img_size*batch_size*img_depth):self.B_input[i] = image_tensor.reshape((batch_size, img_height, img_width, img_depth))#协调器coord发出所有线程终止信号coord.request_stop()#把开启的线程加入主线程,等待threads结束coord.join(threads)def model_setup(self):'''函数功能:为训练建立模型self.input_A/self.input_B -> Set of training images.self.fake_A/self.fake_B -> Generated images by corresponding generator of input_A and input_Bself.lr -> Learning rate variableself.cyc_A / self.cyc_B -> Images generated after feeding self.fake_A/self.fake_B to corresponding generator. This is use to calculate cyclic loss.:return:'''# 输入数据A和B的占位符self.input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_depth], name="input_A")self.input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_depth], name="input_B")# 用来计算损失函数self.fake_pool_A = tf.placeholder(tf.float32, [None, img_width, img_height, img_depth], name="fake_pool_A")self.fake_pool_B = tf.placeholder(tf.float32, [None, img_width, img_height, img_depth], name="fake_pool_B")self.global_step = tf.Variable(0, name="global_step", trainable=False)self.num_fake_inputs = 0self.lr = tf.placeholder(tf.float32, shape=[], name="lr")# A为马,B为斑马with tf.variable_scope("Model") as scope:self.fake_B = build_generator_resnet_9blocks(self.input_A, name="g_A") #转换成的斑马self.fake_A = build_generator_resnet_9blocks(self.input_B, name="g_B") #转换成的马self.rec_A = build_gen_discriminator(self.input_A, "d_A") # 鉴别器输出真实的马为真的概率(越接近1越好)self.rec_B = build_gen_discriminator(self.input_B, "d_B") # 鉴别器输出真实的斑马为真的概率(越接近1越好)scope.reuse_variables()self.fake_rec_A = build_gen_discriminator(self.fake_A, "d_A") # 鉴别器输出马转换为斑马再转换为马为真的概率(越接近0的概率越好)self.fake_rec_B = build_gen_discriminator(self.fake_B, "d_B") # 鉴别器输出斑马转换为马再转换为斑马为真的概率(越接近0的概率越好)self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B") # 马转换为斑马再转换为马self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A") # 斑马转换为马再转换为马scope.reuse_variables()self.fake_pool_rec_A = build_gen_discriminator(self.fake_pool_A, "d_A") #self.fake_pool_rec_B = build_gen_discriminator(self.fake_pool_B, "d_B")def loss_calc(self):'''函数功能:损失函数计算d_loss_A/d_loss_B -> loss of discriminator A/Bg_loss_A/g_loss_B -> loss of generator A/B:return:'''# Cycle损失,需要最小化输入图像向量和经过一个Cycle后转回来图像向量cyc_loss = tf.reduce_mean(tf.abs(self.input_A - self.cyc_A)) + tf.reduce_mean(tf.abs(self.input_B - self.cyc_B))# 鉴别器损失,需要将经过一个Cycle操作出来图像认为越真越好disc_loss_A = tf.reduce_mean(tf.squared_difference(self.fake_rec_A, 1))disc_loss_B = tf.reduce_mean(tf.squared_difference(self.fake_rec_B, 1))g_loss_A = cyc_loss * 10 + disc_loss_Bg_loss_B = cyc_loss * 10 + disc_loss_Ad_loss_A = (tf.reduce_mean(tf.square(self.fake_pool_rec_A)) + tf.reduce_mean(tf.squared_difference(self.rec_A, 1))) / 2.0d_loss_B = (tf.reduce_mean(tf.square(self.fake_pool_rec_B)) + tf.reduce_mean(tf.squared_difference(self.rec_B, 1))) / 2.0optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5)self.model_vars = tf.trainable_variables()d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)for var in self.model_vars:print(var.name)#为tensorboard汇总变量#tf.summary.scalar用来显示标量信息,在画loss和accuracy曲线时需要self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)def save_training_images(self, sess, epoch):if not os.path.exists("./output/imgs"):os.makedirs("./output/imgs")for i in range(0, 10):fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run([self.fake_A, self.fake_B, self.cyc_A, self.cyc_B],feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})imsave("./output/imgs/fakeB_" + str(epoch) + "_" + str(i) + ".jpg", ((fake_A_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/fakeA_" + str(epoch) + "_" + str(i) + ".jpg", ((fake_B_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/cycA_" + str(epoch) + "_" + str(i) + ".jpg", ((cyc_A_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/cycB_" + str(epoch) + "_" + str(i) + ".jpg", ((cyc_B_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/inputA_" + str(epoch) + "_" + str(i) + ".jpg", ((self.A_input[i][0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/inputB_" + str(epoch) + "_" + str(i) + ".jpg", ((self.B_input[i][0] + 1) * 127.5).astype(np.uint8))def fake_image_pool(self, num_fakes, fake, fake_pool):'''函数功能:计算每一张产生的图片的discriminator loss总和代价是十分昂贵的,为了加速训练使用了fake_pool保存之前生成的固定个数的fake_image并且随机使用其中一个计算loss'''if num_fakes < pool_size:fake_pool[num_fakes] = fakereturn fakeelse:p = random.random()if p > 0.5:random_id = random.randint(0, pool_size-1)temp = fake_pool[random_id]fake_pool[random_id] = fakereturn tempelse:return fakedef train(self):'''函数功能:训练:return:'''# 加载数据self.input_setup()# 建立网络self.model_setup()# 计算损失函数self.loss_calc()# 初始化变量init = tf.global_variables_initializer()init2 = tf.local_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)sess.run(init2)#将input读入到numpy数组self.input_read(sess)#从最近的一次checkpoint继续训练if to_restore:chkpt_frame = tf.train.latest_checkpoint(check_dir)saver.restore(sess, chkpt_frame)writer = tf.summary.FileWriter("./output/2") #记录tensorflow的默认图if not os.path.exists(check_dir):os.makedirs(check_dir)#训练循环start_time = time.time()for epoch in range(sess.run(self.global_step), max_epoch):print("In the epoch ", epoch)saver.save(sess, os.path.join(check_dir, "cyclegan"), global_step=epoch)#调整学习率if epoch < 100:curr_lr = 0.0002else:curr_lr = 0.0002 - 0.0002 *(epoch - 100) / 100if save_training_images:self.save_training_images(sess, epoch)for ptr in range(0, max_images):print("In the iteration ", ptr)#Optimizing the G_A network_, fake_B_temp, summary_str = sess.run([self.g_A_trainer, self.fake_B, self.g_A_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr})writer.add_summary(summary_str, epoch*max_images + ptr)fake_B_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_B_temp, self.fake_images_B)#Optimizing the D_B network_, summary_str = sess.run([self.d_B_trainer, self.d_B_loss_summ], feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr],self.lr:curr_lr, self.fake_pool_B:fake_B_temp1})writer.add_summary(summary_str, epoch*max_images + ptr)#Optimizing the G_B network_, fake_A_temp, summary_str = sess.run([self.g_B_trainer, self.fake_A, self.g_B_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr})writer.add_summary(summary_str, epoch*max_images + ptr)fake_A_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_A_temp, self.fake_images_A)print(fake_A_temp1.shape)#Optimizing the D_A network_, summary_str = sess.run([self.d_A_trainer, self.d_A_loss_summ], feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr],self.lr:curr_lr, self.fake_pool_A:fake_A_temp1})writer.add_summary(summary_str, epoch*max_images + ptr)hour = int((time.time() - start_time) / 3600)min = int(((time.time() - start_time) - 3600 * hour) / 60)sec = int((time.time() - start_time) - 3600 * hour - 60 * min)print("Time: ", hour, "h: ", min, "min", sec, "sec")self.num_fake_inputs += 1sess.run(tf.assign(self.global_step, epoch + 1))writer.add_graph(sess.graph)def test(self):'''函数功能:测试:return:'''print("Testing the results")self.input_setup()self.model_setup()saver = tf.train.Saver()init = tf.global_variables_initializer()init2 =  tf.local_variables_initializer()with tf.Session() as sess:sess.run(init)sess.run(init2)self.input_read(sess)chkpt_frame = tf.train.latest_checkpoint(check_dir)saver.restore(sess, chkpt_frame)if not os.path.exists("./output/imgs/test/"):os.makedirs("./output/imgs/test/")for i in range(0, 100):fake_A_temp, fake_B_temp = sess.run([self.fake_A, self.fake_B], feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})imsave("./output/imgs/test/fakeB_" + str(i) + ".jpg", ((fake_A_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/test/fakeA_" + str(i) + ".jpg", ((fake_B_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/test/inputA_" + "_" + str(i) + ".jpg", ((self.A_input[i][0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/test/inputB_"  + "_" + str(i) + ".jpg", ((self.B_input[i][0] + 1) * 127.5).astype(np.uint8))if __name__ == '__main__':model = CycleGAN()if to_train:model.train()elif to_test:model.test()

效果图

  • 马和斑马的转换
    在这里插入图片描述

这篇关于CycleGAN 论文阅读及代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/442737

相关文章

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟&nbsp;开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚&nbsp;第一站:海量资源,应有尽有 走进“智听

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略 1. 特权模式限制2. 宿主机资源隔离3. 用户和组管理4. 权限提升控制5. SELinux配置 💖The Begin💖点点关注,收藏不迷路💖 Kubernetes的PodSecurityPolicy(PSP)是一个关键的安全特性,它在Pod创建之前实施安全策略,确保P