KAGGLE · GETTING STARTED CODE COMPETITION 图像风格迁移 示例代码阅读

本文主要是介绍KAGGLE · GETTING STARTED CODE COMPETITION 图像风格迁移 示例代码阅读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本博文阅读的代码来自于I’m Something of a Painter Myself | Kaggle倾情推荐:

Monet CycleGAN Tutorial | Kaggle

数据集说明

I’m Something of a Painter Myself | Kaggle

Files

  • monet_jpg - 300 Monet paintings sized 256x256 in JPEG format
  • monet_tfrec - 300 Monet paintings sized 256x256 in TFRecord format
  • photo_jpg - 7028 photos sized 256x256 in JPEG format
  • photo_tfrec - 7028 photos sized 256x256 in TFRecord format

简单介绍一下,就是有两种类型的数据提供使用,一个是JPEG格式,一个是TFRecord,训练集的size是300,测试集的size是7028。并且每张图片的大小都是256×256。

代码阅读

首先是一些说明,这个代码使用的是TensorFlow,所以也就大概看看,后面会搬家到PyTorch写写看。使用的是CycleGAN,这个很合理,因为这里是无监督学习,不过GAN的种类有超多哎,也许会有更好的GAN可以选择呢?anyway,CycleGAN也是很经典的方法啦。

先放一张PPT:

加载数据

MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
print('Monet TFRecord Files:', len(MONET_FILENAMES))PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))
tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))

tf.io.gfile.glob(pattern): Returns a list of files that match the given pattern(s).

查找匹配pattern的文件并以列表的形式返回,pattern可以是一个具体的文件名,也可以是包含通配符的正则表达式。

参考链接:

tf.io.gfile.glob  |  TensorFlow v2.15.0.post1 (google.cn)
TensorFlow函数教程:tf.io.gfile.glob_w3cschool
Tensorflow 2.0 gfile 文件操作 - 知乎 (zhihu.com)
tf.io.gfile.glob 遍历文件-CSDN博客
TensorFlow函数:tf.io.gfile.glob_tf.io.gfile.remove函数参数-CSDN博客

def decode_image(image):image = tf.image.decode_jpeg(image, channels=3)image = (tf.cast(image, tf.float32) / 127.5) - 1image = tf.reshape(image, [*IMAGE_SIZE, 3])return image
tf.image.decode_jpeg(image, channels=3)

将JPEG编码的图像解码为unit8的Tensor,channels取3表示返回的是RGB图像,channels取1表示返回的是灰度图像,channels取0表示使用JPEG编码图像中的通道数量。

参考链接:

加载和预处理图像  |  TensorFlow Core (google.cn)
TensorFlow函数:tf.image.decode_jpeg_w3cschool
tf.image.decode_jpeg函数与tf.image.encode_jpeg函数用法-CSDN博客

tf.cast(image, tf.float32)

将前面得到的Tensor数据类型从unit8改到float32

参考链接:

Tensorflow中 tf.cast()的用法_tensorflow.cast-CSDN博客
tensorflow——tf.cast()详解_tensorflow的cast-CSDN博客
tf.cast - TensorFlow Python - W3cubDocs

tf.reshape(image, [*IMAGE_SIZE, 3])

对image(也就是前面处理的Tensor)进行维度的调整。比如:

参考链接:

tf.reshape函数用法&理解-CSDN博客
【tensorflow】tf.reshape函数说明:重塑张量_tensorflow reshape 变大-CSDN博客
TensorFlow:使用tf.reshape函数重塑张量_w3cschool
tf.reshape(x, [-1, 28, 28, 1])_reshape((-1, 28, 28, 1)-CSDN博客
Python的reshape的用法:reshape(1,-1)-CSDN博客

def read_tfrecord(example):tfrecord_format = {"image_name": tf.io.FixedLenFeature([], tf.string),"image": tf.io.FixedLenFeature([], tf.string),"target": tf.io.FixedLenFeature([], tf.string)}example = tf.io.parse_single_example(example, tfrecord_format)image = decode_image(example['image'])return image
tf.io.FixedLenFeature([], tf.string)

解析每个输入样本的每一列数据

参考链接:

TFRecord 中 FixedLenFeature、VarLenFeature、FixedLenSequenceFeature 说明_tf.fixedlenfeature-CSDN博客
Tensorflow2.0之TFRecord文件的写入与读取_tensorflow2 使用tfrecord-CSDN博客
TensorFlow函数教程:tf.io.FixedLenFeature_w3cschool

tf.io.parse_single_example(example, tfrecord_format)

输入一个Tensor,输出一个dict

使用tf.parse_single_example() 按照schema解析dataset中每个样本;

schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。

参考链接:

tensorflow2.0 环境下的tfrecord读写及tf.io.parse_example和tf.io.parse_single_example的区别-CSDN博客
TensorFlow2.0 TFrecord数据集的写入、读取和训练示例详解_tensorflow将图片数据写入tfrecord-CSDN博客
TensorFlow函数教程:tf.io.parse_single_example_w3cschool
Tensorflow之TFRecord的原理和使用心得 - 知乎 (zhihu.com)

def load_dataset(filenames, labeled=True, ordered=False):dataset = tf.data.TFRecordDataset(filenames)dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)return dataset
tf.data.TFRecordDataset(filenames)

该数据集以字节形式从文件中加载 TFRecord,与写入时完全相同。 TFRecordDataset 本身不进行任何解析或解码。可以通过在 TFRecordDataset 之后应用 Dataset.map 转换来完成解析和解码。

参考链接:

tensorfow学习(一) ——tf.data.TFRecordDataset的使用-CSDN博客
TensorFlow2.0 TFrecord数据集的写入、读取和训练示例详解_tensorflow将图片数据写入tfrecord-CSDN博客
tensorflow入门:tfrecord 和tf.data.TFRecordDataset-CSDN博客
TFRecord + Dataset 进行数据的写入和读取 - 知乎 (zhihu.com)TensorFlow - tf.data.TFRecordDataset (runebook.dev)

创建generator

为了创建我们的generator,首先定义downsample和upsample方法。downsample,顾名思义,通过步幅减少图像的2D维度。upsample与downsample相反,增加图像的尺寸。Conv2DTranspose基本上与Conv2D层相反。

initializer = tf.random_normal_initializer(0., 0.02)

生成一组符合标准正态分布的Tensor的初始化器,类似的也可以初始化成别的形式(按照逻辑来说可能normal distribution并不是一个最好的选择,but anyway,既然GAN都能train起来,我更愿意相信……这玩意儿就是炼丹)

参考链接:

tensorflow和pytorch中的参数初始化调用方法-CSDN博客
tf.random_normal_initializer:TensorFlow初始化器_w3cschool
Tensorflow API——tf.random_normal_initializer_python中tf.random_normal_initializer什么意思-CSDN博客

gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

初始化定义了设置 Keras 各层权重随机初始值的方法。

按照正态分布生成随机张量的初始化器。

参数

  • mean: 一个 Python 标量或者一个标量张量。要生成的随机值的平均数。
  • stddev: 一个 Python 标量或者一个标量张量。要生成的随机值的标准差。
  • seed: 一个 Python 整数。用于设置随机数种子。

参考链接:

初始化 Initializers - Keras 中文文档 (kldivergence.github.io)
Layer weight initializers (keras.io)
Keras教学(6):Keras的初始化Initializers,看这一篇就够了_bias_initializer": {"module": "keras.initializers"-CSDN博客

result = keras.Sequential()
result.add(layers.Conv2DTranspose(filters, size, strides=2,padding='same',kernel_initializer=initializer,use_bias=False))

将图像恢复到原来的尺寸(上采样)->实现图像从小分辨率到大分辨率映射的操作。(有很多上采样的方法,反卷积只是其中的一种方法)

参考链接:

深度卷积生成对抗网络  |  TensorFlow Core (google.cn)
Conv2DTranspose layer (keras.io)
反卷积操作Conv2DTranspose-CSDN博客

有了upsampling和downsampling方法之后就可以创建generator了。在这里采用了skip的方法,提到这样是为了缓解梯度消失(一些resnet开始在脑海里旋转)

layers.Concatenate()([x, skip])

Concatenates a list of inputs.

It takes as input a list of tensors, all of the same shape except for the concatenation axis, and returns a single tensor that is the concatenation of all inputs.

简要来说就是把两个Tensor拼起来(所以也不算resnet死灰复燃×)按照axis取值的不同决定拼接的方式。如果没有指定axis的值,default=-1,就是从倒数第一个维度进行拼接。

创建discriminator

discriminator接收输入图像并将其分类为真实或虚假(生成)。鉴别器不是输出单个节点,而是输出一个较小的2D图像,其中像素值较高表示真实分类,像素值较低表示虚假分类。

with strategy.scope():monet_generator = Generator() # transforms photos to Monet-esque paintingsphoto_generator = Generator() # transforms Monet paintings to be more like photosmonet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintingsphoto_discriminator = Discriminator() # differentiates real photos and generated photos

从最开始的cycleGAN的图片可以看出,对于generator和discriminator,是双向的,所以两个方向都需要定义。

因为此刻generator尚未训练,所以自然生成出来的图片只能说是皇帝的新图×

创建CycleGAN

在训练步骤中,模型将照片转换为莫奈的画作,然后再转换回照片。原始照片和经过两次变换的照片之间的区别是循环一致性损失。我们希望原始照片和经过两次变换的照片彼此相似。

简要来说,为了好理解,就是auto-encoder

具体请参考文章开头处的链接

定义损失函数

with strategy.scope():def discriminator_loss(real, generated):real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)generated_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.zeros_like(generated), generated)total_disc_loss = real_loss + generated_lossreturn total_disc_loss * 0.5

将真实图像比作1的矩阵,将假图像比作0的矩阵。完美的discriminator将为真实图像输出所有的1,为假图像输出所有的0。discriminator损耗输出实际损耗和生成损耗的平均值。

tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)

计算真实标签和预测标签之间的交叉熵损失。将这种交叉熵损失用于二值(0或1)分类应用程序。

参考链接:

Probabilistic losses (keras.io)
TF2.0—tf.keras.losses.BinaryCrossentropy-CSDN博客
tf.keras.losses.BinaryCrossentropy函数-CSDN博客

这里延伸一下:

为什么分类要用交叉熵而不用MSE

比较直觉的解释可以是,假设我们有100个类别,假设类别1被分类成类别2,这与类别1被分成类别100其实是一样的,都是分错了,但是MSE就会觉得分成类别100错的更离谱,这是显然不合理的。

非直觉的理由(不好算)

  1. MSE作为分类的损失函数会有梯度消失的问题。
  2. MSE是非凸的,存在很多局部极小值点。

请参考链接:

分类为什么用CE而不是MSE - 知乎 (zhihu.com)

为什么分类问题不能使用mse损失函数_为什么分类不用mse-CSDN博客

训练和可视化

——————————————————————————

后续就是会用PyTorch自己写写,看情况放不放链接吧

这篇关于KAGGLE · GETTING STARTED CODE COMPETITION 图像风格迁移 示例代码阅读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/634446

相关文章

Java中StopWatch的使用示例详解

《Java中StopWatch的使用示例详解》stopWatch是org.springframework.util包下的一个工具类,使用它可直观的输出代码执行耗时,以及执行时间百分比,这篇文章主要介绍... 目录stopWatch 是org.springframework.util 包下的一个工具类,使用它

Spring Boot 3.4.3 基于 Spring WebFlux 实现 SSE 功能(代码示例)

《SpringBoot3.4.3基于SpringWebFlux实现SSE功能(代码示例)》SpringBoot3.4.3结合SpringWebFlux实现SSE功能,为实时数据推送提供... 目录1. SSE 简介1.1 什么是 SSE?1.2 SSE 的优点1.3 适用场景2. Spring WebFlu

springboot security快速使用示例详解

《springbootsecurity快速使用示例详解》:本文主要介绍springbootsecurity快速使用示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝... 目录创www.chinasem.cn建spring boot项目生成脚手架配置依赖接口示例代码项目结构启用s

java之Objects.nonNull用法代码解读

《java之Objects.nonNull用法代码解读》:本文主要介绍java之Objects.nonNull用法代码,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录Java之Objects.nonwww.chinasem.cnNull用法代码Objects.nonN

golang 日志log与logrus示例详解

《golang日志log与logrus示例详解》log是Go语言标准库中一个简单的日志库,本文给大家介绍golang日志log与logrus示例详解,感兴趣的朋友一起看看吧... 目录一、Go 标准库 log 详解1. 功能特点2. 常用函数3. 示例代码4. 优势和局限二、第三方库 logrus 详解1.

SpringBoot实现MD5加盐算法的示例代码

《SpringBoot实现MD5加盐算法的示例代码》加盐算法是一种用于增强密码安全性的技术,本文主要介绍了SpringBoot实现MD5加盐算法的示例代码,文中通过示例代码介绍的非常详细,对大家的学习... 目录一、什么是加盐算法二、如何实现加盐算法2.1 加盐算法代码实现2.2 注册页面中进行密码加盐2.

python+opencv处理颜色之将目标颜色转换实例代码

《python+opencv处理颜色之将目标颜色转换实例代码》OpenCV是一个的跨平台计算机视觉库,可以运行在Linux、Windows和MacOS操作系统上,:本文主要介绍python+ope... 目录下面是代码+ 效果 + 解释转HSV: 关于颜色总是要转HSV的掩膜再标注总结 目标:将红色的部分滤

在C#中调用Python代码的两种实现方式

《在C#中调用Python代码的两种实现方式》:本文主要介绍在C#中调用Python代码的两种实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录C#调用python代码的方式1. 使用 Python.NET2. 使用外部进程调用 Python 脚本总结C#调

Redis 中的热点键和数据倾斜示例详解

《Redis中的热点键和数据倾斜示例详解》热点键是指在Redis中被频繁访问的特定键,这些键由于其高访问频率,可能导致Redis服务器的性能问题,尤其是在高并发场景下,本文给大家介绍Redis中的热... 目录Redis 中的热点键和数据倾斜热点键(Hot Key)定义特点应对策略示例数据倾斜(Data S

JavaScript Array.from及其相关用法详解(示例演示)

《JavaScriptArray.from及其相关用法详解(示例演示)》Array.from方法是ES6引入的一个静态方法,用于从类数组对象或可迭代对象创建一个新的数组实例,本文将详细介绍Array... 目录一、Array.from 方法概述1. 方法介绍2. 示例演示二、结合实际场景的使用1. 初始化二