CycleGAN 论文阅读及代码实现

2023-12-01 21:20

本文主要是介绍CycleGAN 论文阅读及代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

介绍

CycleGAN是2018年发表于ICCV17的一篇论文,可以让2个图片相互转化,也就是风格迁移,如马变为斑马,斑马变为马。
在这里插入图片描述

网络结构

在这里插入图片描述CycleGAN总结构有4个网络,第一个为生成网络G:X—>Y;第二个网络为生成网络F:X—>Y。第三个网络为对抗网络命名为Dx,鉴别输入图像是否为X;第四个网络为对抗网络命名为Dy,鉴别输入图像是不是Y。如图,以马(X)和斑马(Y)为例,G网络将马的图像转化为斑马图像;F网络将斑马的图像转化为马的图像;Dx网络鉴别输入的图像是不是马;Dy网络鉴别输入图像是不是斑马。这4个网络仅有2个网络结构,即G和F都是生成网络,这两者的网络结构相同,Dx和Dy都是对抗性网络,这两者的网络结构相同。

Generator-生成网络

在这里插入图片描述以上网络主要有3种操作,卷积,反卷积和残差模块;卷积和反卷积后通常还有BN,激活函数等。

卷积

在这里插入图片描述

反卷积

在这里插入图片描述

残差模块

在这里插入图片描述残差网络最先是在ResNet中引出的可以有效的避免梯度消失,实现网络深度的提升。

Discriminator-对抗网络

在这里插入图片描述卷积后面通道都有BN层和激活函数,另外Discriminator的最终输出并不是0.0-1。0间的值,而是一个1616的矩阵,因此定义了这个1616矩阵的各个元素越接近0.9,则Loss越小,即是真值的概率越大。

Loss

G_loss

网络G的loss函数,由2部分组成,分别是cycle_loss和g_loss。

  • cycle_loss:G(x)生成了y’,F(G(x))即是生成的x’,则F(G(x))-x的绝对值的均值定为loss_x;F(y)生成x’,G(F(y))生成y’,则G(F(y))-y的绝对值的均值定为loss_y;cycle_loss=loss_x+loss_y;
    在这里插入图片描述- g_loss:G(x)是y’,则Dy(y’)每个元素减去0.9取平方,然后取平方均值定义为g_loss;
    在这里插入图片描述G_loss=cycle_loss+g_loss;

Dy_loss

网络Dy的loss函数,由2部分组成,分别是loss_real_y和loss_fake_y;

  • loss_real_y:Dy(y)是一个16x16矩阵,每个元素减去0.9后取平方,则各平方均值定义为loss_real_y;
  • loss_fake_y:G(x)生成一个y’,则Dy(G(x))相当于Dy(y’),也是一个16x16矩阵,矩阵每个元素取平方,则各平方均值定义为loss_fake_y;

在这里插入图片描述Dy_loss=loss_real_y+loss_fake_y;

F_loss

网络F的Loss函数,由2部分组成,分别是cycle_loss和f_loss;

  • cycle_loss同G_loss中的定义
  • f_loss:F(y)生成一个x’,Dx(x’)即Dx(F(y))是一个16x16的矩阵,每个元素减去0.9后取平方,各平方均值定义为f_loss;
    在这里插入图片描述F_loss=cycle_Loss+f_loss;

Dx_loss

网络Dx的loss函数,由两部分组成,分别是loss_real_x和loss_fake_x;

  • loss_real_x:Dx(x)是一个16x16矩阵,每个元素减去0.9后取平方,则各平方均值定义为loss_real_x;
  • loss_fake_x:F(y)生成一个x’,则Dx(F(y))相当于Dx(x’),也是一个16x16矩阵,矩阵每个元素取平方,则各平方均值定义为loss_fake_x;
    在这里插入图片描述Dx_loss= loss_real_x+ loss_fake_x;

训练

最小化[G_loss,Dy_loss,F_loss,Dx_loss]变量,实现网络优化训练。

代码实现

这里使用的损失函数和上面不是太一样,具体可以看这个工程:https://github.com/hardikbansal/CycleGAN 和这个博客:https://hardikbansal.github.io/CycleGANBlog/ 通过修改to_train和to_test参数控制训练和测试即可。

#coding=utf-8
import tensorflow as tf
import numpy as np
from scipy.misc import imsave #将数组保存到图像中
import matplotlib.pyplot as plt
import os #文件夹操作
import time
import random#函数功能:实现leakyrelu
def lrelu(x, leak=0.2, name = "lrelu"):with tf.variable_scope(name):return tf.maximum(x, leak*x)#函数功能:实现BN
def instance_norm(x):with tf.variable_scope("instance_norm"):epsilon = 1e-5mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)scale = tf.get_variable('scale', [x.get_shape()[-1]], initializer=tf.truncated_normal_initializer(mean=1.0, stddev=0.02))offset = tf.get_variable('offset', [x.get_shape()[-1]], initializer=tf.constant_initializer(0.0))out = scale*tf.div(x-mean, tf.sqrt(var + epsilon)) + offsetreturn out#函数功能:实现卷积
def general_conv2d(input, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="conv2d", do_norm=True, do_relu=True, relufactor=0):with tf.variable_scope(name):conv = tf.contrib.layers.conv2d(input, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))if do_norm:conv = instance_norm(conv)if do_relu:if relufactor == 0:conv = tf.nn.relu(conv, "relu")else:conv = lrelu(conv, relufactor, "lrelu")return conv#函数功能:实现反卷积
def general_deconv2d(input, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02, padding="VALID", name="deconv2d", do_norm=True, do_relu=True, relufactor=0):with tf.variable_scope(name):conv = tf.contrib.layers.conv2d_transpose(input, o_d, [f_h, f_w], [s_h, s_w], padding, activation_fn=None, weights_initializer=tf.truncated_normal_initializer(stddev=stddev),biases_initializer=tf.constant_initializer(0.0))if do_norm:conv = instance_norm(conv)if do_relu:if relufactor == 0:conv = tf.nn.relu(conv, "relu")else:conv = lrelu(conv, relufactor, "lrelu")return conv#Building the generator->1.Encoder 2.Transformer 3.Decoderngf = 32 #生成器的第一层的filtes的个数
ndf = 64 #判别器的第一层的filtes的个数
batch_size = 1 #每次处理的图片个数
pool_size = 50 #保存最近的pool_size个图片,并随机用一张计算D_loss
img_width = 256
img_height = 256
img_depth = 3 #RGB
img_size = img_height * img_width
to_train = True
to_test = False
to_restore = True
output_path = "./output"
check_dir = "./output/checkpoints/"
max_epoch = 1000
max_images = 100
save_training_images = True#函数功能:构造残差模块
def build_resnet_block(input, dim, name="resnet"):with tf.variable_scope(name):out_res = tf.pad(input, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")out_res = tf.pad(out_res, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")out_res = general_conv2d(out_res, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)return tf.nn.relu(out_res + input)#函数功能:构造包含6个参差模块作为转换器的生成网络
def build_generator_resnet_6blocks(input, name="generator"):with tf.variable_scope(name):f = 7ks = 3pad_input = tf.pad(input, [[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")o_c1 = general_conv2d(pad_input, ngf, ks, ks, 1, 1, 0.02, name="c1")o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02, "SAME", name="c2")o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02, "SAME", name="c3")o_r1 = build_resnet_block(o_c3, ngf*4, "r1")o_r2 = build_resnet_block(o_r1, ngf*4, "r2")o_r3 = build_resnet_block(o_r2, ngf*4, "r3")o_r4 = build_resnet_block(o_r3, ngf*4, "r4")o_r5 = build_resnet_block(o_r4, ngf*4, "r5")o_r6 = build_resnet_block(o_r5, ngf*4, "r6")o_c4 = general_deconv2d(o_r6, [batch_size, 64, 64, ngf*2], ngf*2, ks, ks, 2, 2, 0.02, "SAME", "c4")o_c5 = general_deconv2d(o_c4, [batch_size, 128, 128, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5")o_c5_pad = tf.pad(o_c5, [[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")o_c6 = general_conv2d(o_c5_pad, img_depth, f, f, 1, 1, 0.02, "VALID", "c6", do_relu=False)#Adding the tanh layerout_gen = tf.nn.tanh(o_c6, "t1")return out_gen#函数功能:构造包含6个参差模块作为转换器的生成网络
def build_generator_resnet_9blocks(input, name="generator"):with tf.variable_scope(name):f = 7ks = 3pad_input = tf.pad(input, [[0, 0], [ks, ks], [ks, ks], [0, 0]], "REFLECT")o_c1 = general_conv2d(input, ngf, ks, ks, 1, 1, 0.02, name="c1")o_c2 = general_conv2d(o_c1, ngf*2, ks, ks, 2, 2, 0.02, "SAME", name="c2")o_c3 = general_conv2d(o_c2, ngf*4, ks, ks, 2, 2, 0.02, "SAME", name="c3")o_r1 = build_resnet_block(o_c3, ngf*4, "r1")o_r2 = build_resnet_block(o_r1, ngf*4, "r2")o_r3 = build_resnet_block(o_r2, ngf*4, "r3")o_r4 = build_resnet_block(o_r3, ngf*4, "r4")o_r5 = build_resnet_block(o_r4, ngf*4, "r5")o_r6 = build_resnet_block(o_r5, ngf*4, "r6")o_r7 = build_resnet_block(o_r6, ngf*4, "r7")o_r8 = build_resnet_block(o_r7, ngf*4, "r8")o_r9 = build_resnet_block(o_r8, ngf*4, "r9")o_c4 = general_deconv2d(o_r9, [batch_size, 128, 128, ngf*2], ngf*2, ks, ks, 2, 2, 0.02, "SAME", "c4")o_c5 = general_deconv2d(o_c4, [batch_size, 256, 256, ngf], ngf, ks, ks, 2, 2, 0.02, "SAME", "c5")o_c6 = general_conv2d(o_c5, img_depth, f, f, 1, 1, 0.02, "SAME", "c6", do_relu=False)#Adding the tanh layerout_gen = tf.nn.tanh(o_c6, "t1")return out_gen#函数功能: 构造Discriminator_A->B
def build_gen_discriminator(input, name="discriminator"):with tf.variable_scope(name):f = 4o_c1 = general_conv2d(input, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2)o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2) #do_norm=Trueo_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)o_c4 = general_conv2d(o_c3, ndf*8, f, f, 1, 1, 0.02, "SAME", "c4", relufactor=0.2)o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False, do_relu=False)return o_c5#函数功能: 部分裁剪的Discriminator
def patch_discriminator(input, name="discriminator"):with tf.variable_scope(name):f = 4patch_input = tf.random_crop(input, [1,70,70,3])o_c1 = general_conv2d(patch_input, ndf, f, f, 2, 2, 0.02, "SAME", "c1", do_norm=False, relufactor=0.2)o_c2 = general_conv2d(o_c1, ndf*2, f, f, 2, 2, 0.02, "SAME", "c2", relufactor=0.2) #do_norm=Trueo_c3 = general_conv2d(o_c2, ndf*4, f, f, 2, 2, 0.02, "SAME", "c3", relufactor=0.2)o_c4 = general_conv2d(o_c3, ndf*8, f, f, 1, 1, 0.02, "SAME", "c4", relufactor=0.2)o_c5 = general_conv2d(o_c4, 1, f, f, 1, 1, 0.02, "SAME", "c5", do_norm=False, do_relu=False)return o_c5class CycleGAN():def input_setup(self):'''函数功能能:为输入数据设置变量filenames_A/filenames_B -> takes the list of all training imagesself.images_A/self.images_B -> Input image with each values ranging from [-1,1]:return:'''#获取文件列表filenames_A = tf.train.match_filenames_once("zxy2lsx/trainA/*.jpg")print(filenames_A)self.queue_length_A = tf.size(filenames_A)print(self.queue_length_A)filenames_B = tf.train.match_filenames_once("zxy2lsx/trainB/*.jpg")print(filenames_B)self.queue_length_B = tf.size(filenames_B)print(self.queue_length_B)filename_queue_A = tf.train.string_input_producer(filenames_A) #输出字符串到一个输入管道队列filename_queue_B = tf.train.string_input_producer(filenames_B)image_reader = tf.WholeFileReader() #一个阅读器,读取整个文件,返回文件名称key,以及文件中所有的内容value_, image_file_A = image_reader.read(filename_queue_A)_, image_file_B = image_reader.read(filename_queue_B)# 将输入图像resize为[256, 256]# [N, C, W, H] 在第一个维度减去均值127.5self.image_A = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_A), [256, 256]), 127.5), 1)self.image_B = tf.subtract(tf.div(tf.image.resize_images(tf.image.decode_jpeg(image_file_B), [256, 256]), 127.5), 1)def input_read(self, sess):'''函数功能:从图像文件夹中读取输入信息:param sess::return:'''#开启一个协调器coord = tf.train.Coordinator()#QueueRunner类用来启动tensor的入队线程,可以用来启动多个工作线程threads = tf.train.start_queue_runners(coord=coord)num_files_A = sess.run(self.queue_length_A)num_files_B = sess.run(self.queue_length_B)self.fake_images_A = np.zeros((pool_size, 1, img_height, img_width, img_depth))self.fake_images_B = np.zeros((pool_size, 1, img_height, img_width, img_depth))self.A_input = np.zeros((max_images, batch_size, img_height, img_width, img_depth))self.B_input = np.zeros((max_images, batch_size, img_height, img_width, img_depth))for i in range(max_images):image_tensor = sess.run(self.image_A)if(image_tensor.size == img_size*batch_size*img_depth):self.A_input[i] = image_tensor.reshape((batch_size, img_height, img_width, img_depth))for i in range(max_images):image_tensor = sess.run(self.image_B)if(image_tensor.size == img_size*batch_size*img_depth):self.B_input[i] = image_tensor.reshape((batch_size, img_height, img_width, img_depth))#协调器coord发出所有线程终止信号coord.request_stop()#把开启的线程加入主线程,等待threads结束coord.join(threads)def model_setup(self):'''函数功能:为训练建立模型self.input_A/self.input_B -> Set of training images.self.fake_A/self.fake_B -> Generated images by corresponding generator of input_A and input_Bself.lr -> Learning rate variableself.cyc_A / self.cyc_B -> Images generated after feeding self.fake_A/self.fake_B to corresponding generator. This is use to calculate cyclic loss.:return:'''# 输入数据A和B的占位符self.input_A = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_depth], name="input_A")self.input_B = tf.placeholder(tf.float32, [batch_size, img_width, img_height, img_depth], name="input_B")# 用来计算损失函数self.fake_pool_A = tf.placeholder(tf.float32, [None, img_width, img_height, img_depth], name="fake_pool_A")self.fake_pool_B = tf.placeholder(tf.float32, [None, img_width, img_height, img_depth], name="fake_pool_B")self.global_step = tf.Variable(0, name="global_step", trainable=False)self.num_fake_inputs = 0self.lr = tf.placeholder(tf.float32, shape=[], name="lr")# A为马,B为斑马with tf.variable_scope("Model") as scope:self.fake_B = build_generator_resnet_9blocks(self.input_A, name="g_A") #转换成的斑马self.fake_A = build_generator_resnet_9blocks(self.input_B, name="g_B") #转换成的马self.rec_A = build_gen_discriminator(self.input_A, "d_A") # 鉴别器输出真实的马为真的概率(越接近1越好)self.rec_B = build_gen_discriminator(self.input_B, "d_B") # 鉴别器输出真实的斑马为真的概率(越接近1越好)scope.reuse_variables()self.fake_rec_A = build_gen_discriminator(self.fake_A, "d_A") # 鉴别器输出马转换为斑马再转换为马为真的概率(越接近0的概率越好)self.fake_rec_B = build_gen_discriminator(self.fake_B, "d_B") # 鉴别器输出斑马转换为马再转换为斑马为真的概率(越接近0的概率越好)self.cyc_A = build_generator_resnet_9blocks(self.fake_B, "g_B") # 马转换为斑马再转换为马self.cyc_B = build_generator_resnet_9blocks(self.fake_A, "g_A") # 斑马转换为马再转换为马scope.reuse_variables()self.fake_pool_rec_A = build_gen_discriminator(self.fake_pool_A, "d_A") #self.fake_pool_rec_B = build_gen_discriminator(self.fake_pool_B, "d_B")def loss_calc(self):'''函数功能:损失函数计算d_loss_A/d_loss_B -> loss of discriminator A/Bg_loss_A/g_loss_B -> loss of generator A/B:return:'''# Cycle损失,需要最小化输入图像向量和经过一个Cycle后转回来图像向量cyc_loss = tf.reduce_mean(tf.abs(self.input_A - self.cyc_A)) + tf.reduce_mean(tf.abs(self.input_B - self.cyc_B))# 鉴别器损失,需要将经过一个Cycle操作出来图像认为越真越好disc_loss_A = tf.reduce_mean(tf.squared_difference(self.fake_rec_A, 1))disc_loss_B = tf.reduce_mean(tf.squared_difference(self.fake_rec_B, 1))g_loss_A = cyc_loss * 10 + disc_loss_Bg_loss_B = cyc_loss * 10 + disc_loss_Ad_loss_A = (tf.reduce_mean(tf.square(self.fake_pool_rec_A)) + tf.reduce_mean(tf.squared_difference(self.rec_A, 1))) / 2.0d_loss_B = (tf.reduce_mean(tf.square(self.fake_pool_rec_B)) + tf.reduce_mean(tf.squared_difference(self.rec_B, 1))) / 2.0optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5)self.model_vars = tf.trainable_variables()d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)for var in self.model_vars:print(var.name)#为tensorboard汇总变量#tf.summary.scalar用来显示标量信息,在画loss和accuracy曲线时需要self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)def save_training_images(self, sess, epoch):if not os.path.exists("./output/imgs"):os.makedirs("./output/imgs")for i in range(0, 10):fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = sess.run([self.fake_A, self.fake_B, self.cyc_A, self.cyc_B],feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})imsave("./output/imgs/fakeB_" + str(epoch) + "_" + str(i) + ".jpg", ((fake_A_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/fakeA_" + str(epoch) + "_" + str(i) + ".jpg", ((fake_B_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/cycA_" + str(epoch) + "_" + str(i) + ".jpg", ((cyc_A_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/cycB_" + str(epoch) + "_" + str(i) + ".jpg", ((cyc_B_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/inputA_" + str(epoch) + "_" + str(i) + ".jpg", ((self.A_input[i][0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/inputB_" + str(epoch) + "_" + str(i) + ".jpg", ((self.B_input[i][0] + 1) * 127.5).astype(np.uint8))def fake_image_pool(self, num_fakes, fake, fake_pool):'''函数功能:计算每一张产生的图片的discriminator loss总和代价是十分昂贵的,为了加速训练使用了fake_pool保存之前生成的固定个数的fake_image并且随机使用其中一个计算loss'''if num_fakes < pool_size:fake_pool[num_fakes] = fakereturn fakeelse:p = random.random()if p > 0.5:random_id = random.randint(0, pool_size-1)temp = fake_pool[random_id]fake_pool[random_id] = fakereturn tempelse:return fakedef train(self):'''函数功能:训练:return:'''# 加载数据self.input_setup()# 建立网络self.model_setup()# 计算损失函数self.loss_calc()# 初始化变量init = tf.global_variables_initializer()init2 = tf.local_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)sess.run(init2)#将input读入到numpy数组self.input_read(sess)#从最近的一次checkpoint继续训练if to_restore:chkpt_frame = tf.train.latest_checkpoint(check_dir)saver.restore(sess, chkpt_frame)writer = tf.summary.FileWriter("./output/2") #记录tensorflow的默认图if not os.path.exists(check_dir):os.makedirs(check_dir)#训练循环start_time = time.time()for epoch in range(sess.run(self.global_step), max_epoch):print("In the epoch ", epoch)saver.save(sess, os.path.join(check_dir, "cyclegan"), global_step=epoch)#调整学习率if epoch < 100:curr_lr = 0.0002else:curr_lr = 0.0002 - 0.0002 *(epoch - 100) / 100if save_training_images:self.save_training_images(sess, epoch)for ptr in range(0, max_images):print("In the iteration ", ptr)#Optimizing the G_A network_, fake_B_temp, summary_str = sess.run([self.g_A_trainer, self.fake_B, self.g_A_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr})writer.add_summary(summary_str, epoch*max_images + ptr)fake_B_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_B_temp, self.fake_images_B)#Optimizing the D_B network_, summary_str = sess.run([self.d_B_trainer, self.d_B_loss_summ], feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr],self.lr:curr_lr, self.fake_pool_B:fake_B_temp1})writer.add_summary(summary_str, epoch*max_images + ptr)#Optimizing the G_B network_, fake_A_temp, summary_str = sess.run([self.g_B_trainer, self.fake_A, self.g_B_loss_summ],feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr], self.lr:curr_lr})writer.add_summary(summary_str, epoch*max_images + ptr)fake_A_temp1 = self.fake_image_pool(self.num_fake_inputs, fake_A_temp, self.fake_images_A)print(fake_A_temp1.shape)#Optimizing the D_A network_, summary_str = sess.run([self.d_A_trainer, self.d_A_loss_summ], feed_dict={self.input_A:self.A_input[ptr], self.input_B:self.B_input[ptr],self.lr:curr_lr, self.fake_pool_A:fake_A_temp1})writer.add_summary(summary_str, epoch*max_images + ptr)hour = int((time.time() - start_time) / 3600)min = int(((time.time() - start_time) - 3600 * hour) / 60)sec = int((time.time() - start_time) - 3600 * hour - 60 * min)print("Time: ", hour, "h: ", min, "min", sec, "sec")self.num_fake_inputs += 1sess.run(tf.assign(self.global_step, epoch + 1))writer.add_graph(sess.graph)def test(self):'''函数功能:测试:return:'''print("Testing the results")self.input_setup()self.model_setup()saver = tf.train.Saver()init = tf.global_variables_initializer()init2 =  tf.local_variables_initializer()with tf.Session() as sess:sess.run(init)sess.run(init2)self.input_read(sess)chkpt_frame = tf.train.latest_checkpoint(check_dir)saver.restore(sess, chkpt_frame)if not os.path.exists("./output/imgs/test/"):os.makedirs("./output/imgs/test/")for i in range(0, 100):fake_A_temp, fake_B_temp = sess.run([self.fake_A, self.fake_B], feed_dict={self.input_A:self.A_input[i], self.input_B:self.B_input[i]})imsave("./output/imgs/test/fakeB_" + str(i) + ".jpg", ((fake_A_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/test/fakeA_" + str(i) + ".jpg", ((fake_B_temp[0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/test/inputA_" + "_" + str(i) + ".jpg", ((self.A_input[i][0] + 1) * 127.5).astype(np.uint8))imsave("./output/imgs/test/inputB_"  + "_" + str(i) + ".jpg", ((self.B_input[i][0] + 1) * 127.5).astype(np.uint8))if __name__ == '__main__':model = CycleGAN()if to_train:model.train()elif to_test:model.test()

效果图

  • 马和斑马的转换
    在这里插入图片描述

这篇关于CycleGAN 论文阅读及代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/442737

相关文章

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python脚本实现自动删除C盘临时文件夹

《Python脚本实现自动删除C盘临时文件夹》在日常使用电脑的过程中,临时文件夹往往会积累大量的无用数据,占用宝贵的磁盘空间,下面我们就来看看Python如何通过脚本实现自动删除C盘临时文件夹吧... 目录一、准备工作二、python脚本编写三、脚本解析四、运行脚本五、案例演示六、注意事项七、总结在日常使用

Java实现Excel与HTML互转

《Java实现Excel与HTML互转》Excel是一种电子表格格式,而HTM则是一种用于创建网页的标记语言,虽然两者在用途上存在差异,但有时我们需要将数据从一种格式转换为另一种格式,下面我们就来看看... Excel是一种电子表格格式,广泛用于数据处理和分析,而HTM则是一种用于创建网页的标记语言。虽然两

Java中Springboot集成Kafka实现消息发送和接收功能

《Java中Springboot集成Kafka实现消息发送和接收功能》Kafka是一个高吞吐量的分布式发布-订阅消息系统,主要用于处理大规模数据流,它由生产者、消费者、主题、分区和代理等组件构成,Ka... 目录一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者一、Kaf

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超

windos server2022里的DFS配置的实现

《windosserver2022里的DFS配置的实现》DFS是WindowsServer操作系统提供的一种功能,用于在多台服务器上集中管理共享文件夹和文件的分布式存储解决方案,本文就来介绍一下wi... 目录什么是DFS?优势:应用场景:DFS配置步骤什么是DFS?DFS指的是分布式文件系统(Distr

NFS实现多服务器文件的共享的方法步骤

《NFS实现多服务器文件的共享的方法步骤》NFS允许网络中的计算机之间共享资源,客户端可以透明地读写远端NFS服务器上的文件,本文就来介绍一下NFS实现多服务器文件的共享的方法步骤,感兴趣的可以了解一... 目录一、简介二、部署1、准备1、服务端和客户端:安装nfs-utils2、服务端:创建共享目录3、服

C#使用yield关键字实现提升迭代性能与效率

《C#使用yield关键字实现提升迭代性能与效率》yield关键字在C#中简化了数据迭代的方式,实现了按需生成数据,自动维护迭代状态,本文主要来聊聊如何使用yield关键字实现提升迭代性能与效率,感兴... 目录前言传统迭代和yield迭代方式对比yield延迟加载按需获取数据yield break显式示迭

Python实现高效地读写大型文件

《Python实现高效地读写大型文件》Python如何读写的是大型文件,有没有什么方法来提高效率呢,这篇文章就来和大家聊聊如何在Python中高效地读写大型文件,需要的可以了解下... 目录一、逐行读取大型文件二、分块读取大型文件三、使用 mmap 模块进行内存映射文件操作(适用于大文件)四、使用 pand

python实现pdf转word和excel的示例代码

《python实现pdf转word和excel的示例代码》本文主要介绍了python实现pdf转word和excel的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、引言二、python编程1,PDF转Word2,PDF转Excel三、前端页面效果展示总结一