Gated Context Aggregation Network for Image Dehazing and Deraining(GCANet)

2023-11-11 04:36

本文主要是介绍Gated Context Aggregation Network for Image Dehazing and Deraining(GCANet),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1 总体概述

GCANet是端到端去雾的一篇代表性的文章,它摒弃以往使用手工设计的先验以及大气散射模型的使用,直接通过原始有雾图像估计出无雾图像J与有雾图像I之间的残差,图像恢复阶段直接使用网络输出的残差与输入有雾图像I之间的加和完成去雾过程。

文章本身最大的贡献:
1、借鉴并使用了平滑空洞卷积消除以往空洞卷积存在的网格伪影以及特征相关性不强的问题,提出了一个门限子网络,用于依据不同level的特征的权重进行特征加权融合
2、GCANet达到当前的SOTA,并且使用消融试验对不同模块重要性进行了分析
3、GCANet应用到去雨任务中依然获得了SOTA

2 灵感来源

之前的研究者利用扩张卷积来聚合上下文信息,可以获得更加细腻和准确的结果,主要原因是扩张卷积不损失分辨率,但是他也存在一些问题,比如网格伪影,远距离信息没有相关性。因此也有很多人去尝试改进上述问题,比如使用平滑空洞卷积来消除网格伪影;也有人使用不同level的特征图进行融合获得更好的去雾效果;也有使用使用gated fusion module ,但是它是直接使用原始图的拷贝而非中间所获取的特征图;GCANet借鉴了上述思想,使用了扩张卷积,借鉴了smooth 扩张卷积消除伪影,借鉴了融合特征图的思路,提出了一个门限子网络用于辨析不同level特征图的重要性

3 现有工作分析

去雾分为两种,其一是基于传统先验知识的去雾,其二是基于学习的方式, 区别就是第一种方案通过手工获取的先验知识在第二钟方案是通过学习获取

传统方案去雾
1、基于暗通道先验以及其对应的优化方案
2、最大对比度
3、颜色衰减先验等

深度学习去雾
1、使用端到端的深度学习方式,利用多尺度网络预测透射率图,但是透射率图估计的不准确导致去雾结果较差
2、将全局大气光值A以及透射率参数融合为一个参数,使用轻量级的网络进行预测
3、也有使用两个子网络分别预测全局大气光值A以及透射率参数,并依据大气散射模型进行图像去雾的

4 本文GCANet方法

在这里插入图片描述整体架构是:首先使用编码模块将输入的有雾图像编码为特征;接着通过聚合上下文信息以及融合不同level的特征强化编码特征(主要使用平滑空洞卷积以及特殊设计的门限子网络);最后使用一个解码网络将特征映射回原图空间,并加上原始图就可以获取最终的去雾图像

4.1 Smoothed Dilated Convolution

什么是网格伪影?
在这里插入图片描述
由上图可知,最右边的这一幅特征图中的红蓝绿黄色四种小点来自于之前特征层对应颜色的独立特征,特征之间没有交互,没有融合,导致最终获取的当前层的特征的之间没有相关性可言,造成局部信息丢失,这对于pixel_level的预测来说是极其致命的。

消除网格伪影有两种方式,第一种是在使用空洞卷积之前,使用共享可分离卷积先进行特征之间的融合;另一种方式是在卷积后特征整合之前,使用类似于shuffleNet一样的方式进行特征交互,具体可以参考如下链接:总结-空洞卷积(Dilated/Atrous Convolution)

本文采用第一种方式完成空洞卷积的网格效应消除

class ResidualBlock(nn.Module):def __init__(self, channel_num, dilation=1, group=1):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group,bias=False)self.norm1 = nn.InstanceNorm2d(channel_num, affine=True)self.conv2 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group,bias=False)self.norm2 = nn.InstanceNorm2d(channel_num, affine=True)def forward(self, x):y = F.relu(self.norm1(self.conv1(x)))y = self.norm2(self.conv2(y))return F.relu(x + y)

4.2 Gated Fusion Sub-network

其实现过程如下:首先提取低中高三个不同level的特征图,并设计一个gated fusion sub_network ,输出是三个层级的特征的权重,最后将三个不同层级特征图与对应权重线性连接即可。
具体公式如下:
在这里插入图片描述
文中提及Gated Fusion Sub-network 包含一个卷积核大小为3*3的卷积网络,输入是低中高三个level的特征通过通道维度进行连接,输出特征是3个通道
在这里插入图片描述

4.3 网络结构

首先使用三个卷积当作编码模块,对输入图像进行编码,最后一个卷积块特征分辨率减半;其次使用7个残差block对编码的特征进行特征增强;最后使用一个反卷积将特征图上采样两倍,接着使用两个反卷积将特征图映射回图像空间,这样就可以得到原图与无雾图的残差值;其中除了最后一个卷积层以及所设计的共享分离卷积层外,每个卷积后面都跟随一个instance normalization 以及一个ReLU激活函数。

PS:输入的参数除了原始的雾图外,还需要将图像的边缘提取后作为一个辅助信息加到输入信息中;实际操作的时候,可以提前将图片的边缘信息提取出来与原始图在通道上叠加进而送入网络,这样有利于网络学习
具体代码如下:

class GCANet(nn.Module):def __init__(self, in_c=4, out_c=3, only_residual=True):super(GCANet, self).__init__()self.conv1 = nn.Conv2d(in_c, 64, 3, 1, 1, bias=False)self.norm1 = nn.InstanceNorm2d(64, affine=True)self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)self.norm2 = nn.InstanceNorm2d(64, affine=True)self.conv3 = nn.Conv2d(64, 64, 3, 2, 1, bias=False)self.norm3 = nn.InstanceNorm2d(64, affine=True)self.res1 = SmoothDilatedResidualBlock(64, dilation=2)self.res2 = SmoothDilatedResidualBlock(64, dilation=2)self.res3 = SmoothDilatedResidualBlock(64, dilation=2)self.res4 = SmoothDilatedResidualBlock(64, dilation=4)self.res5 = SmoothDilatedResidualBlock(64, dilation=4)self.res6 = SmoothDilatedResidualBlock(64, dilation=4)self.res7 = ResidualBlock(64, dilation=1)self.gate = nn.Conv2d(64 * 3, 3, 3, 1, 1, bias=True)self.deconv3 = nn.ConvTranspose2d(64, 64, 4, 2, 1)self.norm4 = nn.InstanceNorm2d(64, affine=True)self.deconv2 = nn.Conv2d(64, 64, 3, 1, 1)self.norm5 = nn.InstanceNorm2d(64, affine=True)self.deconv1 = nn.Conv2d(64, out_c, 1)self.only_residual = only_residualdef forward(self, x):y = F.relu(self.norm1(self.conv1(x)))y = F.relu(self.norm2(self.conv2(y)))y1 = F.relu(self.norm3(self.conv3(y)))y = self.res1(y1)y = self.res2(y)y = self.res3(y)y2 = self.res4(y)y = self.res5(y2)y = self.res6(y)y3 = self.res7(y)gates = self.gate(torch.cat((y1, y2, y3), dim=1))gated_y = y1 * gates[:, [0], :, :] + y2 * gates[:, [1], :, :] + y3 * gates[:, [2], :, :]y = F.relu(self.norm4(self.deconv3(gated_y)))y = F.relu(self.norm5(self.deconv2(y)))if self.only_residual:y = self.deconv1(y)else:y = F.relu(self.deconv1(y))return y

4.4 损失函数

损失函数用的MSE Loss,作者提及可以使用其它的损失函数,例如perceptual loss或者GAN loss都可以提升最终的去雾效果,但是即使使用最简单的MSE也能得到SOTA的结果
在这里插入图片描述

  • 后记
    作者的改进重点
    发力在损失函数改进以及视频去雾方面

这篇关于Gated Context Aggregation Network for Image Dehazing and Deraining(GCANet)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/387586

相关文章

poj 2349 Arctic Network uva 10369(prim or kruscal最小生成树)

题目很麻烦,因为不熟悉最小生成树的算法调试了好久。 感觉网上的题目解释都没说得很清楚,不适合新手。自己写一个。 题意:给你点的坐标,然后两点间可以有两种方式来通信:第一种是卫星通信,第二种是无线电通信。 卫星通信:任何两个有卫星频道的点间都可以直接建立连接,与点间的距离无关; 无线电通信:两个点之间的距离不能超过D,无线电收发器的功率越大,D越大,越昂贵。 计算无线电收发器D

lvgl8.3.6 控件垂直布局 label控件在image控件的下方显示

在使用 LVGL 8.3.6 创建一个垂直布局,其中 label 控件位于 image 控件下方,你可以使用 lv_obj_set_flex_flow 来设置布局为垂直,并确保 label 控件在 image 控件后添加。这里是如何步骤性地实现它的一个基本示例: 创建父容器:首先创建一个容器对象,该对象将作为布局的基础。设置容器为垂直布局:使用 lv_obj_set_flex_flow 设置容器

图神经网络框架DGL实现Graph Attention Network (GAT)笔记

参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 ——基础操作&消息传递 [3]Cora数据集介绍+python读取 一、DGL实现GAT分类机器学习论文 程序摘自[1],该程序实现了利用图神经网络框架——DGL,实现图注意网络(GAT)。应用demo为对机器学习论文数据集——Cora,对论文所属类别进行分类。(下图摘自[3]) 1. 程序 Ubuntu:18.04

context:component-scan使用说明!

<!-- 使用annotation 自动注册bean, 并保证@Required、@Autowired的属性被注入 --> <context:component-scan base-package="com.yuanls"/> 在xml配置了这个标签后,spring可以自动去扫描base-pack下面或者子包下面的java文件,如果扫描到有@Component @Controll

深度学习--对抗生成网络(GAN, Generative Adversarial Network)

对抗生成网络(GAN, Generative Adversarial Network)是一种深度学习模型,由Ian Goodfellow等人在2014年提出。GAN主要用于生成数据,通过两个神经网络相互对抗,来生成以假乱真的新数据。以下是对GAN的详细阐述,包括其概念、作用、核心要点、实现过程、代码实现和适用场景。 1. 概念 GAN由两个神经网络组成:生成器(Generator)和判别器(D

React的context学习总结

context是干什么的?为什么会存在这么一个东西? context字面意思是上下文,在react中存在是为了解决深层次组件传值困难的问题 这里涉及到组件的传值问题,大体商说分三总:兄弟间传值(通过父组件),父往子传值(通过props),子往父传(props函数回调),这是基础的传值问题,但是如果组件嵌套的太深,那么传值就变的非常麻烦,为了解决这样的问题才产生了context  这是cont

Neighborhood Homophily-based Graph Convolutional Network

#paper/ccfB 推荐指数: #paper/⭐ #pp/图结构学习 流程 重定义同配性指标: N H i k = ∣ N ( i , k , c m a x ) ∣ ∣ N ( i , k ) ∣ with c m a x = arg ⁡ max ⁡ c ∈ [ 1 , C ] ∣ N ( i , k , c ) ∣ NH_i^k=\frac{|\mathcal{N}(i,k,c_{

兔子--The method setLatestEventInfo(Context, CharSequence, CharSequence, PendingIntent) from the type

notification.setLatestEventInfo(context, title, message, pendingIntent);     不建议使用 低于API Level 11版本,也就是Android 2.3.3以下的系统中,setLatestEventInfo()函数是唯一的实现方法。  Intent  intent = new Intent(

F12抓包05:Network接口测试(抓包篡改请求)

课程大纲         使用线上接口测试网站演示操作,浏览器F12检查工具如何进行简单的接口测试:抓包、复制请求、篡改数据、发送新请求。         测试地址:https://httpbin.org/forms/post ① 抓包:鼠标右键打开“检查”工具(F12),tab导航选择“网络”(Network),输入前3项点击提交,可看到录制的请求和返回数据。

OpenSNN推文:神经网络(Neural Network)相关论文最新推荐(九月份)(一)

基于卷积神经网络的活动识别分析系统及应用 论文链接:oalib简介:  活动识别技术在智能家居、运动评估和社交等领域得到广泛应用。本文设计了一种基于卷积神经网络的活动识别分析与应用系统,通过分析基于Android搭建的前端采所集的三向加速度传感器数据,对用户的当前活动进行识别。实验表明活动识别准确率满足了应用需求。本文基于识别的活动进行卡路里消耗计算,根据用户具体的活动、时间以及体重计算出相应活