超分辨重建——SRGAN网络训练自己数据集与推理测试(详细图文教程)

本文主要是介绍超分辨重建——SRGAN网络训练自己数据集与推理测试(详细图文教程),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

💪 专业从事且热爱图像处理,图像处理专栏更新如下👇:
📝《图像去噪》
📝《超分辨率重建》
📝《语义分割》
📝《风格迁移》
📝《目标检测》
📝《暗光增强》
📝《模型优化》
📝《模型实战部署》

😊总结不易,多多支持呀🌹感谢您的点赞👍收藏⭐评论✍️,您的三连是我持续更新的动力💖


在这里插入图片描述

目录

  • 一、SRGAN网络
    • 1.1 标题
    • 1.2 作者
    • 1.3 发表时间
    • 1.4 摘要
    • 1.5 主要内容
      • 1.5.1 生成对抗网络架构
      • 1.5.2 损失函数
      • 1.5.3 实验结果
    • 1.6 论文总结
  • 二、源码包准备
  • 三、环境准备
    • 3.1 报错:AttributeError: module 'torch' has no attribute 'compile'
    • 3.2 报错:RuntimeError: Windows not yet supported for torch.compile
  • 四、数据集准备
  • 五、训练
    • 5.1 预训练权重下载
    • 5.2 配置文件参数修改
    • 5.3 启动训练
    • 5.4 实时可视化训练过程损失函数走势
    • 5.5 训练结果
  • 六、测试
    • 6.1 测试配置文件修改
    • 6.2 启动测试
  • 七、推理速度
    • 7.1 GPU
    • 7.2 CPU
  • 八、超分效果展示
  • 九、总结

一、SRGAN网络

1.1 标题

“Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”

1.2 作者

Christian Ledig, Lucas Theis, Ferenc Huszár, Jose Caballero, Andrew Cunningham, Alejandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, Wenzhe Shi

1.3 发表时间

2017年

1.4 摘要

SRGAN通过利用生成对抗网络(GAN)来实现单图像超分辨率重建。传统的方法如基于均方误差(MSE)的优化通常会导致图像平滑且缺乏细节,而SRGAN通过引入感知损失函数(perceptual loss),使得重建的图像不仅在像素级别上更接近高分辨率图像,而且在感知质量上也更加接近真实图像。

1.5 主要内容

1.5.1 生成对抗网络架构

生成器(Generator):采用残差网络(ResNet)结构,能够有效地学习从低分辨率图像到高分辨率图像的映射。
判别器(Discriminator):判别器的任务是区分生成的高分辨率图像和真实的高分辨率图像。通过对抗训练,生成器能够学习生成更加逼真的图像。

1.5.2 损失函数

内容损失(Content Loss):利用VGG网络提取的特征来计算生成图像和真实图像之间的差异。
对抗损失(Adversarial Loss):来自GAN的对抗训练,使得生成器能够骗过判别器,从而生成更加逼真的图像。
感知损失(Perceptual Loss):

感知损失结合内容损失和对抗损失,旨在提高重建图像的感知质量,使其在视觉上更接近真实图像。

1.5.3 实验结果

SRGAN在多种数据集上进行了测试,结果表明,与传统方法(如基于MSE的方法)相比,SRGAN生成的图像在感知质量上有显著提升。在用户研究中,SRGAN生成的图像被评为更接近真实图像。

1.6 论文总结

SRGAN通过生成对抗网络和感知损失函数的结合,显著提升了单图像超分辨率重建的效果。该方法不仅在像素级别上达到了更高的精度,同时在视觉感知上也大幅提升,生成的图像更加逼真,细节更加丰富。

二、源码包准备

本配套教程源码包中已经下载好了测试模型和预训练模型,部分训练集和测试集。源码包获取方法文章末扫码到公众号「视觉研坊」中回复关键字:超分辨率重建SRGAN。获取下载链接。

Pytorch版的官网源码包地址:SRGAN

论文地址:论文

三、环境准备

下面是我自己训练和测试的环境,仅供参考,其它版本也行。

在这里插入图片描述

3.1 报错:AttributeError: module ‘torch’ has no attribute ‘compile’

该报错是因为yTorch 版本不支持 torch.compile 方法。这种方法是在 PyTorch 2.0 版本中引入的,而我使用的Pytorch为1.12版本

在windows电脑上我安装了2.0.1版Pytorch,继续报错。

3.2 报错:RuntimeError: Windows not yet supported for torch.compile

安装了2.0.1版本Pytorch,见下:

在这里插入图片描述

报错见下:

在这里插入图片描述

报错原因:在 PyTorch 2.0 中,torch.compile 目前不支持在 Windows 上运行。

解决办法:网络训练过程不加速,把compile关闭,具体见下:

在这里插入图片描述

关闭后,后续训练和测试,我继续在之前Pytotch1.12.1版本上操作。

解决该问题还有中方式使用 torch.jit.trace 替代torch.compile,后续没调试。

四、数据集准备

直接运行代码会自动下载数据集,某些情况下会下载中断,而且很慢,可以把数据集下载链接拷贝到迅雷中,速度较快,找数据集链接的方法见下,原论文中的数据集下载链接为:https://huggingface.co/datasets/goodfellowliu/SRGAN_ImageNet/resolve/main/SRGAN_ImageNet.zip

在这里插入图片描述

数据集下载好后,先通过split_images.py脚本将各种分辨率的图像裁剪为统一尺寸图片并保存到指定路径中。关于split_images.py脚本的具体用法,以及数据集的样子参考另外一篇博文:高分辨率图像分割成大小均匀图像

测试集的路径见下:

在这里插入图片描述

五、训练

源码中有net网络和gan网络,我主要讲解gan网络的训练和测试,net网络的训练和测试类同。源码中有2倍,4倍,8倍超分,本教程主要讲解4倍超分,其它超分类同。

5.1 预训练权重下载

直接运行脚本,代码也会自动下载预训练模型,如果自动下载出了问题,去下面文件中找到预训练模型下载链接:

在这里插入图片描述

自己下载的模型权重文件,存放到results\pretrained_models路径中:

在这里插入图片描述

5.2 配置文件参数修改

下面是常用参数,其它参数学生根据自己情况自行修改。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

5.3 启动训练

gan网络训练的主脚本为train_gan.py,在此脚本中修改训练用的配置文件路径,见下:

在这里插入图片描述
直接运行train_gan.py脚本开始训练:

在这里插入图片描述

部分训练过程见下:

在这里插入图片描述

5.4 实时可视化训练过程损失函数走势

在终端使用下面命令启动tensorboard实时可视化训练过程损失函数走势:

tensorboard --logdir=samples/logs/SRGAN_x4-SRGAN_ImageNet --port 6007

在这里插入图片描述

具体的可视化走势图见下:

在这里插入图片描述

5.5 训练结果

训练过程的模型权重文件自动保存到results\SRGAN_x4-SRGAN_ImageNet路径下:

在这里插入图片描述

训练过程中每一轮的模型权重文件保存到samples\SRGAN_x4-SRGAN_ImageNet路径下:

在这里插入图片描述

六、测试

6.1 测试配置文件修改

下面参数学者根据自己情况调整修改。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

6.2 启动测试

在这里插入图片描述

将上面required设置为False后,直接运行test.py脚本:

在这里插入图片描述

输出的评价指标如下:

在这里插入图片描述

测试结果保存到result_images\SRGAN_x4-SRGAN_ImageNet-Set14路径下:

在这里插入图片描述

七、推理速度

7.1 GPU

GPU测试环境:Nvidia GeForce RTX 3050。

120*90图像超分4倍 GPU平均推理时间:7.69ms/fps。

在这里插入图片描述

7.2 CPU

12th Gen Intel® Core™ i7-12700H 2.30 GHz。

下面是120*90图像超分4倍,CPU平均推理时间:302.31ms/fps。

在这里插入图片描述

八、超分效果展示

下面左图为bicubic上采样4倍,中间为原图,右图为SRGAN网络超分4倍结果图。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

九、总结

以上就是超分辨重建SRGAN网络训练自己数据集与推理测试详细过程,超分效果与我超分专栏里的其他网络做对比。

感谢您阅读到最后!关注公众号「视觉研坊」,获取干货教程、实战案例、技术解答、行业资讯!

这篇关于超分辨重建——SRGAN网络训练自己数据集与推理测试(详细图文教程)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1028107

相关文章

javaScript在表单提交时获取表单数据的示例代码

《javaScript在表单提交时获取表单数据的示例代码》本文介绍了五种在JavaScript中获取表单数据的方法:使用FormData对象、手动提取表单数据、使用querySelector获取单个字... 方法 1:使用 FormData 对象FormData 是一个方便的内置对象,用于获取表单中的键值

MySQL zip安装包配置教程

《MySQLzip安装包配置教程》这篇文章详细介绍了如何使用zip安装包在Windows11上安装MySQL8.0,包括下载、解压、配置环境变量、初始化数据库、安装服务以及更改密码等步骤,感兴趣的朋... 目录mysql zip安装包配置教程1、下载zip安装包:2、安装2.1 解压zip包到安装目录2.2

Java集合中的List超详细讲解

《Java集合中的List超详细讲解》本文详细介绍了Java集合框架中的List接口,包括其在集合中的位置、继承体系、常用操作和代码示例,以及不同实现类(如ArrayList、LinkedList和V... 目录一,List的继承体系二,List的常用操作及代码示例1,创建List实例2,增加元素3,访问元

Rust中的BoxT之堆上的数据与递归类型详解

《Rust中的BoxT之堆上的数据与递归类型详解》本文介绍了Rust中的BoxT类型,包括其在堆与栈之间的内存分配,性能优势,以及如何利用BoxT来实现递归类型和处理大小未知类型,通过BoxT,Rus... 目录1. Box<T> 的基础知识1.1 堆与栈的分工1.2 性能优势2.1 递归类型的问题2.2

Java使用Tesseract-OCR实战教程

《Java使用Tesseract-OCR实战教程》本文介绍了如何在Java中使用Tesseract-OCR进行文本提取,包括Tesseract-OCR的安装、中文训练库的配置、依赖库的引入以及具体的代... 目录Java使用Tesseract-OCRTesseract-OCR安装配置中文训练库引入依赖代码实

SpringBoot整合easy-es的详细过程

《SpringBoot整合easy-es的详细过程》本文介绍了EasyES,一个基于Elasticsearch的ORM框架,旨在简化开发流程并提高效率,EasyES支持SpringBoot框架,并提供... 目录一、easy-es简介二、实现基于Spring Boot框架的应用程序代码1.添加相关依赖2.添

Python使用Pandas对比两列数据取最大值的五种方法

《Python使用Pandas对比两列数据取最大值的五种方法》本文主要介绍使用Pandas对比两列数据取最大值的五种方法,包括使用max方法、apply方法结合lambda函数、函数、clip方法、w... 目录引言一、使用max方法二、使用apply方法结合lambda函数三、使用np.maximum函数

SpringBoot中整合RabbitMQ(测试+部署上线最新完整)的过程

《SpringBoot中整合RabbitMQ(测试+部署上线最新完整)的过程》本文详细介绍了如何在虚拟机和宝塔面板中安装RabbitMQ,并使用Java代码实现消息的发送和接收,通过异步通讯,可以优化... 目录一、RabbitMQ安装二、启动RabbitMQ三、javascript编写Java代码1、引入

Nginx设置连接超时并进行测试的方法步骤

《Nginx设置连接超时并进行测试的方法步骤》在高并发场景下,如果客户端与服务器的连接长时间未响应,会占用大量的系统资源,影响其他正常请求的处理效率,为了解决这个问题,可以通过设置Nginx的连接... 目录设置连接超时目的操作步骤测试连接超时测试方法:总结:设置连接超时目的设置客户端与服务器之间的连接

Java调用DeepSeek API的最佳实践及详细代码示例

《Java调用DeepSeekAPI的最佳实践及详细代码示例》:本文主要介绍如何使用Java调用DeepSeekAPI,包括获取API密钥、添加HTTP客户端依赖、创建HTTP请求、处理响应、... 目录1. 获取API密钥2. 添加HTTP客户端依赖3. 创建HTTP请求4. 处理响应5. 错误处理6.