深度学习实践:图像去雨网络实现Pytoch

2023-11-11 15:11

本文主要是介绍深度学习实践:图像去雨网络实现Pytoch,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

第二集教程链接:http://t.csdn.cn/QxpgD (更详细)

      本文引用 听 风、的博客 图像去雨:超详细手把手写 pytorch 实现代码(带注释)的网络框架,并进行了优化,主要加入了BatchNormalized模块。优化了代码整体框架和书写规范,加入了更多注释。

代码链接:

Kaggle:Derain_Study | Kaggle

Github:Learn_Pytorch/derain-study.ipynb at main · DLee0102/Learn_Pytorch (github.com)

        改进后的代码加入了验证集以观察训练的模型是否过拟合。同时使用了tqdm工具包,方便观察训练进度。在保存模型方面使用了更高效的方法,即保存在验证集上损失最小的模型。

354b4a5cc1ce42178a8b6262567421fd.png

 

 

        数据集采用的是Kaggle上的JRDR - Deraining Dataset的Light数据集,使用了更优化的dataset方法,以使input和label的图片能准确匹配。

4f19311e363f476b9afec88681d7a237.png

 

import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
import re'''
Dataset for Training.
'''
class MyTrainDataset(Dataset):def __init__(self, input_path, label_path):self.input_path = input_pathself.input_files = os.listdir(input_path)self.label_path = label_pathself.label_files = os.listdir(label_path)self.transforms = transforms.Compose([transforms.CenterCrop([64, 64]), transforms.ToTensor(),])def __len__(self):return len(self.input_files)def __getitem__(self, index):label_image_path = os.path.join(self.label_path, self.label_files[index])label_image = Image.open(label_image_path).convert('RGB')'''Ensure input and label are in couple.'''temp = self.label_files[index][:-4]self.input_files[index] = temp + 'x2.png'input_image_path = os.path.join(self.input_path, self.input_files[index])input_image = Image.open(input_image_path).convert('RGB')input = self.transforms(input_image)label = self.transforms(label_image)return input, label'''
Dataset for testing.
'''
class MyValidDataset(Dataset):def __init__(self, input_path, label_path):self.input_path = input_pathself.input_files = os.listdir(input_path)self.label_path = label_pathself.label_files = os.listdir(label_path)self.transforms = transforms.Compose([transforms.CenterCrop([64, 64]), transforms.ToTensor(),])def __len__(self):return len(self.input_files)def __getitem__(self, index):label_image_path = os.path.join(self.label_path, self.label_files[index])label_image = Image.open(label_image_path).convert('RGB')temp = self.label_files[index][:-4]self.input_files[index] = temp + 'x2.png'input_image_path = os.path.join(self.input_path, self.input_files[index])input_image = Image.open(input_image_path).convert('RGB')input = self.transforms(input_image)label = self.transforms(label_image)return input, label
'''
Residual_Network with BatchNormalized.
'''
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv0 = nn.Sequential(nn.Conv2d(6, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU())self.res_conv1 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU())self.res_conv2 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU())self.res_conv3 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU())self.res_conv4 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU())self.res_conv5 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 32, 3, 1, 1),nn.BatchNorm2d(32),nn.ReLU())self.conv = nn.Sequential(nn.Conv2d(32, 3, 3, 1, 1),)def forward(self, input):x = inputfor i in range(6):  # Won't change the number of parameters'''Different from Classification.'''x = torch.cat((input, x), 1)x = self.conv0(x)x = F.relu(self.res_conv1(x) + x)x = F.relu(self.res_conv2(x) + x)x = F.relu(self.res_conv3(x) + x)x = F.relu(self.res_conv4(x) + x)x = F.relu(self.res_conv5(x) + x)x = self.conv(x)x = x + inputreturn x
'''
Check the number of GPU.
'''
print("Let's use", torch.cuda.device_count(), "GPUs!")
'''
Path of Dataset.
'''
input_path = "../input/jrdr-deraining-dataset/JRDR/rain_data_train_Light/rain"
label_path = "../input/jrdr-deraining-dataset/JRDR/rain_data_train_Light/norain"
valid_input_path = '../input/jrdr-deraining-dataset/JRDR/rain_data_test_Light/rain/X2'
valid_label_path = '../input/jrdr-deraining-dataset/JRDR/rain_data_test_Light/norain''''
Check the device.
'''
device = 'cpu'
if torch.cuda.is_available():device = 'cuda''''
Move the Network to the CUDA.
'''
net = Net().to(device)'''
Hyper Parameters.TODO: fine-tuning.
'''
learning_rate = 1e-3
batch_size = 50
epoch = 100
patience = 30
stale = 0
best_valid_loss = 10000'''
Prepare for plt.
'''
Loss_list = []
Valid_Loss_list = []'''
Define optimizer and Loss Function.
'''
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
loss_f = nn.MSELoss()'''
Check the model.
'''
if os.path.exists('./model.pth'): print('Continue train with last model...')net.load_state_dict(torch.load('./model.pth'))
else: print("Restart...")'''
Prepare DataLoaders.Attension:'pin_numbers=True' can accelorate CUDA computing.
'''
dataset_train = MyTrainDataset(input_path, label_path)
dataset_valid = MyValidDataset(valid_input_path, valid_label_path)
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_loader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=True, pin_memory=True)'''
START Training ...
'''
for i in range(epoch):
# ---------------Train----------------net.train()train_losses = []'''tqdm is a toolkit for progress bar.'''for batch in tqdm(train_loader):inputs, labels = batchoutputs = net(inputs.to(device))loss = loss_f(outputs, labels.to(device))optimizer.zero_grad()loss.backward()'''Avoid grad to be too BIG.'''grad_norm = nn.utils.clip_grad_norm_(net.parameters(), max_norm=10)optimizer.step()'''Attension:We need set 'loss.item()' to turn Tensor into Numpy, or plt will not work.'''train_losses.append(loss.item())train_loss = sum(train_losses)Loss_list.append(train_loss)print(f"[ Train | {i + 1:03d}/{epoch:03d} ] loss = {train_loss:.5f}")# -------------Validation-------------
'''
Validation is a step to ensure training process is working.
You can also exploit Validation to see if your net work is overfitting.Firstly, you should set model.eval(), to ensure parameters not training.
'''net.eval()valid_losses = []for batch in tqdm(valid_loader):inputs, labels = batch'''Cancel gradient decent.'''with torch.no_grad():outputs = net(inputs.to(device))loss = loss_f(outputs, labels.to(device))valid_losses.append(loss.item())valid_loss = sum(valid_losses)Valid_Loss_list.append(valid_loss)print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] loss = {valid_loss:.5f}")'''Update Logs and save the best model.Patience is also checked.'''if valid_loss < best_valid_loss:print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] loss = {valid_loss:.5f} -> best")else:print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] loss = {valid_loss:.5f}")if valid_loss < best_valid_loss:print(f'Best model found at epoch {i+1}, saving model')torch.save(net.state_dict(), f'model_best.ckpt')best_valid_loss = valid_lossstale = 0else:stale += 1if stale > patience:print(f'No improvement {patience} consecutive epochs, early stopping.')break'''
Use plt to draw Loss curves.
'''
plt.figure(dpi=500)
x = range(epoch)
y = Loss_list
plt.plot(x, y, 'ro-', label='Train Loss')
plt.plot(range(epoch), Valid_Loss_list, 'bs-', label='Valid Loss')
plt.ylabel('Loss')
plt.xlabel('epochs')
plt.legend()
plt.show()

训练结果如下:(显示效果不太好)

ad7d16c58efd4d5d96ab18fb8bb35693.png

 test上实际去雨效果:

原图:

2d801216f83f43a398a996a7f6f2f537.png

 未加入BatchNormalize的效果:

4eddbb740144479dba93683c4f51bb87.jpeg

 加入BatchNormalize后的结果:

5e846e3a16c243a38c07aa3b14e49e9e.jpeg

 可以看到,同样训练论述的情况下,加入BatchNormalize后雨线数目明显减少

 

第二集教程链接:http://t.csdn.cn/QxpgD (更详细)

 

这篇关于深度学习实践:图像去雨网络实现Pytoch的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/390938

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

基于人工智能的图像分类系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 图像分类是计算机视觉中的一个重要任务,目标是自动识别图像中的对象类别。通过卷积神经网络(CNN)等深度学习技术,我们可以构建高效的图像分类系统,广泛应用于自动驾驶、医疗影像诊断、监控分析等领域。本文将介绍如何构建一个基于人工智能的图像分类系统,包括环境

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time