KAGGLE · GETTING STARTED CODE COMPETITION 图像风格迁移 示例代码阅读

本文主要是介绍KAGGLE · GETTING STARTED CODE COMPETITION 图像风格迁移 示例代码阅读,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本博文阅读的代码来自于I’m Something of a Painter Myself | Kaggle倾情推荐:

Monet CycleGAN Tutorial | Kaggle

数据集说明

I’m Something of a Painter Myself | Kaggle

Files

  • monet_jpg - 300 Monet paintings sized 256x256 in JPEG format
  • monet_tfrec - 300 Monet paintings sized 256x256 in TFRecord format
  • photo_jpg - 7028 photos sized 256x256 in JPEG format
  • photo_tfrec - 7028 photos sized 256x256 in TFRecord format

简单介绍一下,就是有两种类型的数据提供使用,一个是JPEG格式,一个是TFRecord,训练集的size是300,测试集的size是7028。并且每张图片的大小都是256×256。

代码阅读

首先是一些说明,这个代码使用的是TensorFlow,所以也就大概看看,后面会搬家到PyTorch写写看。使用的是CycleGAN,这个很合理,因为这里是无监督学习,不过GAN的种类有超多哎,也许会有更好的GAN可以选择呢?anyway,CycleGAN也是很经典的方法啦。

先放一张PPT:

加载数据

MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
print('Monet TFRecord Files:', len(MONET_FILENAMES))PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))
tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))

tf.io.gfile.glob(pattern): Returns a list of files that match the given pattern(s).

查找匹配pattern的文件并以列表的形式返回,pattern可以是一个具体的文件名,也可以是包含通配符的正则表达式。

参考链接:

tf.io.gfile.glob  |  TensorFlow v2.15.0.post1 (google.cn)
TensorFlow函数教程:tf.io.gfile.glob_w3cschool
Tensorflow 2.0 gfile 文件操作 - 知乎 (zhihu.com)
tf.io.gfile.glob 遍历文件-CSDN博客
TensorFlow函数:tf.io.gfile.glob_tf.io.gfile.remove函数参数-CSDN博客

def decode_image(image):image = tf.image.decode_jpeg(image, channels=3)image = (tf.cast(image, tf.float32) / 127.5) - 1image = tf.reshape(image, [*IMAGE_SIZE, 3])return image
tf.image.decode_jpeg(image, channels=3)

将JPEG编码的图像解码为unit8的Tensor,channels取3表示返回的是RGB图像,channels取1表示返回的是灰度图像,channels取0表示使用JPEG编码图像中的通道数量。

参考链接:

加载和预处理图像  |  TensorFlow Core (google.cn)
TensorFlow函数:tf.image.decode_jpeg_w3cschool
tf.image.decode_jpeg函数与tf.image.encode_jpeg函数用法-CSDN博客

tf.cast(image, tf.float32)

将前面得到的Tensor数据类型从unit8改到float32

参考链接:

Tensorflow中 tf.cast()的用法_tensorflow.cast-CSDN博客
tensorflow——tf.cast()详解_tensorflow的cast-CSDN博客
tf.cast - TensorFlow Python - W3cubDocs

tf.reshape(image, [*IMAGE_SIZE, 3])

对image(也就是前面处理的Tensor)进行维度的调整。比如:

参考链接:

tf.reshape函数用法&理解-CSDN博客
【tensorflow】tf.reshape函数说明:重塑张量_tensorflow reshape 变大-CSDN博客
TensorFlow:使用tf.reshape函数重塑张量_w3cschool
tf.reshape(x, [-1, 28, 28, 1])_reshape((-1, 28, 28, 1)-CSDN博客
Python的reshape的用法:reshape(1,-1)-CSDN博客

def read_tfrecord(example):tfrecord_format = {"image_name": tf.io.FixedLenFeature([], tf.string),"image": tf.io.FixedLenFeature([], tf.string),"target": tf.io.FixedLenFeature([], tf.string)}example = tf.io.parse_single_example(example, tfrecord_format)image = decode_image(example['image'])return image
tf.io.FixedLenFeature([], tf.string)

解析每个输入样本的每一列数据

参考链接:

TFRecord 中 FixedLenFeature、VarLenFeature、FixedLenSequenceFeature 说明_tf.fixedlenfeature-CSDN博客
Tensorflow2.0之TFRecord文件的写入与读取_tensorflow2 使用tfrecord-CSDN博客
TensorFlow函数教程:tf.io.FixedLenFeature_w3cschool

tf.io.parse_single_example(example, tfrecord_format)

输入一个Tensor,输出一个dict

使用tf.parse_single_example() 按照schema解析dataset中每个样本;

schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。

参考链接:

tensorflow2.0 环境下的tfrecord读写及tf.io.parse_example和tf.io.parse_single_example的区别-CSDN博客
TensorFlow2.0 TFrecord数据集的写入、读取和训练示例详解_tensorflow将图片数据写入tfrecord-CSDN博客
TensorFlow函数教程:tf.io.parse_single_example_w3cschool
Tensorflow之TFRecord的原理和使用心得 - 知乎 (zhihu.com)

def load_dataset(filenames, labeled=True, ordered=False):dataset = tf.data.TFRecordDataset(filenames)dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)return dataset
tf.data.TFRecordDataset(filenames)

该数据集以字节形式从文件中加载 TFRecord,与写入时完全相同。 TFRecordDataset 本身不进行任何解析或解码。可以通过在 TFRecordDataset 之后应用 Dataset.map 转换来完成解析和解码。

参考链接:

tensorfow学习(一) ——tf.data.TFRecordDataset的使用-CSDN博客
TensorFlow2.0 TFrecord数据集的写入、读取和训练示例详解_tensorflow将图片数据写入tfrecord-CSDN博客
tensorflow入门:tfrecord 和tf.data.TFRecordDataset-CSDN博客
TFRecord + Dataset 进行数据的写入和读取 - 知乎 (zhihu.com)TensorFlow - tf.data.TFRecordDataset (runebook.dev)

创建generator

为了创建我们的generator,首先定义downsample和upsample方法。downsample,顾名思义,通过步幅减少图像的2D维度。upsample与downsample相反,增加图像的尺寸。Conv2DTranspose基本上与Conv2D层相反。

initializer = tf.random_normal_initializer(0., 0.02)

生成一组符合标准正态分布的Tensor的初始化器,类似的也可以初始化成别的形式(按照逻辑来说可能normal distribution并不是一个最好的选择,but anyway,既然GAN都能train起来,我更愿意相信……这玩意儿就是炼丹)

参考链接:

tensorflow和pytorch中的参数初始化调用方法-CSDN博客
tf.random_normal_initializer:TensorFlow初始化器_w3cschool
Tensorflow API——tf.random_normal_initializer_python中tf.random_normal_initializer什么意思-CSDN博客

gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

初始化定义了设置 Keras 各层权重随机初始值的方法。

按照正态分布生成随机张量的初始化器。

参数

  • mean: 一个 Python 标量或者一个标量张量。要生成的随机值的平均数。
  • stddev: 一个 Python 标量或者一个标量张量。要生成的随机值的标准差。
  • seed: 一个 Python 整数。用于设置随机数种子。

参考链接:

初始化 Initializers - Keras 中文文档 (kldivergence.github.io)
Layer weight initializers (keras.io)
Keras教学(6):Keras的初始化Initializers,看这一篇就够了_bias_initializer": {"module": "keras.initializers"-CSDN博客

result = keras.Sequential()
result.add(layers.Conv2DTranspose(filters, size, strides=2,padding='same',kernel_initializer=initializer,use_bias=False))

将图像恢复到原来的尺寸(上采样)->实现图像从小分辨率到大分辨率映射的操作。(有很多上采样的方法,反卷积只是其中的一种方法)

参考链接:

深度卷积生成对抗网络  |  TensorFlow Core (google.cn)
Conv2DTranspose layer (keras.io)
反卷积操作Conv2DTranspose-CSDN博客

有了upsampling和downsampling方法之后就可以创建generator了。在这里采用了skip的方法,提到这样是为了缓解梯度消失(一些resnet开始在脑海里旋转)

layers.Concatenate()([x, skip])

Concatenates a list of inputs.

It takes as input a list of tensors, all of the same shape except for the concatenation axis, and returns a single tensor that is the concatenation of all inputs.

简要来说就是把两个Tensor拼起来(所以也不算resnet死灰复燃×)按照axis取值的不同决定拼接的方式。如果没有指定axis的值,default=-1,就是从倒数第一个维度进行拼接。

创建discriminator

discriminator接收输入图像并将其分类为真实或虚假(生成)。鉴别器不是输出单个节点,而是输出一个较小的2D图像,其中像素值较高表示真实分类,像素值较低表示虚假分类。

with strategy.scope():monet_generator = Generator() # transforms photos to Monet-esque paintingsphoto_generator = Generator() # transforms Monet paintings to be more like photosmonet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintingsphoto_discriminator = Discriminator() # differentiates real photos and generated photos

从最开始的cycleGAN的图片可以看出,对于generator和discriminator,是双向的,所以两个方向都需要定义。

因为此刻generator尚未训练,所以自然生成出来的图片只能说是皇帝的新图×

创建CycleGAN

在训练步骤中,模型将照片转换为莫奈的画作,然后再转换回照片。原始照片和经过两次变换的照片之间的区别是循环一致性损失。我们希望原始照片和经过两次变换的照片彼此相似。

简要来说,为了好理解,就是auto-encoder

具体请参考文章开头处的链接

定义损失函数

with strategy.scope():def discriminator_loss(real, generated):real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)generated_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.zeros_like(generated), generated)total_disc_loss = real_loss + generated_lossreturn total_disc_loss * 0.5

将真实图像比作1的矩阵,将假图像比作0的矩阵。完美的discriminator将为真实图像输出所有的1,为假图像输出所有的0。discriminator损耗输出实际损耗和生成损耗的平均值。

tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)

计算真实标签和预测标签之间的交叉熵损失。将这种交叉熵损失用于二值(0或1)分类应用程序。

参考链接:

Probabilistic losses (keras.io)
TF2.0—tf.keras.losses.BinaryCrossentropy-CSDN博客
tf.keras.losses.BinaryCrossentropy函数-CSDN博客

这里延伸一下:

为什么分类要用交叉熵而不用MSE

比较直觉的解释可以是,假设我们有100个类别,假设类别1被分类成类别2,这与类别1被分成类别100其实是一样的,都是分错了,但是MSE就会觉得分成类别100错的更离谱,这是显然不合理的。

非直觉的理由(不好算)

  1. MSE作为分类的损失函数会有梯度消失的问题。
  2. MSE是非凸的,存在很多局部极小值点。

请参考链接:

分类为什么用CE而不是MSE - 知乎 (zhihu.com)

为什么分类问题不能使用mse损失函数_为什么分类不用mse-CSDN博客

训练和可视化

——————————————————————————

后续就是会用PyTorch自己写写,看情况放不放链接吧

这篇关于KAGGLE · GETTING STARTED CODE COMPETITION 图像风格迁移 示例代码阅读的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/634446

相关文章

C++使用栈实现括号匹配的代码详解

《C++使用栈实现括号匹配的代码详解》在编程中,括号匹配是一个常见问题,尤其是在处理数学表达式、编译器解析等任务时,栈是一种非常适合处理此类问题的数据结构,能够精确地管理括号的匹配问题,本文将通过C+... 目录引言问题描述代码讲解代码解析栈的状态表示测试总结引言在编程中,括号匹配是一个常见问题,尤其是在

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.

Android 悬浮窗开发示例((动态权限请求 | 前台服务和通知 | 悬浮窗创建 )

《Android悬浮窗开发示例((动态权限请求|前台服务和通知|悬浮窗创建)》本文介绍了Android悬浮窗的实现效果,包括动态权限请求、前台服务和通知的使用,悬浮窗权限需要动态申请并引导... 目录一、悬浮窗 动态权限请求1、动态请求权限2、悬浮窗权限说明3、检查动态权限4、申请动态权限5、权限设置完毕后

在 Spring Boot 中使用 @Autowired和 @Bean注解的示例详解

《在SpringBoot中使用@Autowired和@Bean注解的示例详解》本文通过一个示例演示了如何在SpringBoot中使用@Autowired和@Bean注解进行依赖注入和Bean... 目录在 Spring Boot 中使用 @Autowired 和 @Bean 注解示例背景1. 定义 Stud

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

oracle DBMS_SQL.PARSE的使用方法和示例

《oracleDBMS_SQL.PARSE的使用方法和示例》DBMS_SQL是Oracle数据库中的一个强大包,用于动态构建和执行SQL语句,DBMS_SQL.PARSE过程解析SQL语句或PL/S... 目录语法示例注意事项DBMS_SQL 是 oracle 数据库中的一个强大包,它允许动态地构建和执行

Python中顺序结构和循环结构示例代码

《Python中顺序结构和循环结构示例代码》:本文主要介绍Python中的条件语句和循环语句,条件语句用于根据条件执行不同的代码块,循环语句用于重复执行一段代码,文章还详细说明了range函数的使... 目录一、条件语句(1)条件语句的定义(2)条件语句的语法(a)单分支 if(b)双分支 if-else(

在不同系统间迁移Python程序的方法与教程

《在不同系统间迁移Python程序的方法与教程》本文介绍了几种将Windows上编写的Python程序迁移到Linux服务器上的方法,包括使用虚拟环境和依赖冻结、容器化技术(如Docker)、使用An... 目录使用虚拟环境和依赖冻结1. 创建虚拟环境2. 冻结依赖使用容器化技术(如 docker)1. 创

Python中Markdown库的使用示例详解

《Python中Markdown库的使用示例详解》Markdown库是一个用于处理Markdown文本的Python工具,这篇文章主要为大家详细介绍了Markdown库的具体使用,感兴趣的... 目录一、背景二、什么是 Markdown 库三、如何安装这个库四、库函数使用方法1. markdown.mark

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意