本文主要是介绍Gated Context Aggregation Network for Image Dehazing and Deraining(GCANet),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
1 总体概述
GCANet是端到端去雾的一篇代表性的文章,它摒弃以往使用手工设计的先验以及大气散射模型的使用,直接通过原始有雾图像估计出无雾图像J与有雾图像I之间的残差,图像恢复阶段直接使用网络输出的残差与输入有雾图像I之间的加和完成去雾过程。
文章本身最大的贡献:
1、借鉴并使用了平滑空洞卷积消除以往空洞卷积存在的网格伪影以及特征相关性不强的问题,提出了一个门限子网络,用于依据不同level的特征的权重进行特征加权融合
2、GCANet达到当前的SOTA,并且使用消融试验对不同模块重要性进行了分析
3、GCANet应用到去雨任务中依然获得了SOTA
2 灵感来源
之前的研究者利用扩张卷积来聚合上下文信息,可以获得更加细腻和准确的结果,主要原因是扩张卷积不损失分辨率,但是他也存在一些问题,比如网格伪影,远距离信息没有相关性。因此也有很多人去尝试改进上述问题,比如使用平滑空洞卷积来消除网格伪影;也有人使用不同level的特征图进行融合获得更好的去雾效果;也有使用使用gated fusion module ,但是它是直接使用原始图的拷贝而非中间所获取的特征图;GCANet借鉴了上述思想,使用了扩张卷积,借鉴了smooth 扩张卷积消除伪影,借鉴了融合特征图的思路,提出了一个门限子网络用于辨析不同level特征图的重要性
3 现有工作分析
去雾分为两种,其一是基于传统先验知识的去雾,其二是基于学习的方式, 区别就是第一种方案通过手工获取的先验知识在第二钟方案是通过学习获取
传统方案去雾
1、基于暗通道先验以及其对应的优化方案
2、最大对比度
3、颜色衰减先验等
深度学习去雾
1、使用端到端的深度学习方式,利用多尺度网络预测透射率图,但是透射率图估计的不准确导致去雾结果较差
2、将全局大气光值A以及透射率参数融合为一个参数,使用轻量级的网络进行预测
3、也有使用两个子网络分别预测全局大气光值A以及透射率参数,并依据大气散射模型进行图像去雾的
4 本文GCANet方法
整体架构是:首先使用编码模块将输入的有雾图像编码为特征;接着通过聚合上下文信息以及融合不同level的特征强化编码特征(主要使用平滑空洞卷积以及特殊设计的门限子网络);最后使用一个解码网络将特征映射回原图空间,并加上原始图就可以获取最终的去雾图像
4.1 Smoothed Dilated Convolution
什么是网格伪影?
由上图可知,最右边的这一幅特征图中的红蓝绿黄色四种小点来自于之前特征层对应颜色的独立特征,特征之间没有交互,没有融合,导致最终获取的当前层的特征的之间没有相关性可言,造成局部信息丢失,这对于pixel_level的预测来说是极其致命的。
消除网格伪影有两种方式,第一种是在使用空洞卷积之前,使用共享可分离卷积先进行特征之间的融合;另一种方式是在卷积后特征整合之前,使用类似于shuffleNet一样的方式进行特征交互,具体可以参考如下链接:总结-空洞卷积(Dilated/Atrous Convolution)
本文采用第一种方式完成空洞卷积的网格效应消除
class ResidualBlock(nn.Module):def __init__(self, channel_num, dilation=1, group=1):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group,bias=False)self.norm1 = nn.InstanceNorm2d(channel_num, affine=True)self.conv2 = nn.Conv2d(channel_num, channel_num, 3, 1, padding=dilation, dilation=dilation, groups=group,bias=False)self.norm2 = nn.InstanceNorm2d(channel_num, affine=True)def forward(self, x):y = F.relu(self.norm1(self.conv1(x)))y = self.norm2(self.conv2(y))return F.relu(x + y)
4.2 Gated Fusion Sub-network
其实现过程如下:首先提取低中高三个不同level的特征图,并设计一个gated fusion sub_network ,输出是三个层级的特征的权重,最后将三个不同层级特征图与对应权重线性连接即可。
具体公式如下:
文中提及Gated Fusion Sub-network 包含一个卷积核大小为3*3的卷积网络,输入是低中高三个level的特征通过通道维度进行连接,输出特征是3个通道
4.3 网络结构
首先使用三个卷积当作编码模块,对输入图像进行编码,最后一个卷积块特征分辨率减半;其次使用7个残差block对编码的特征进行特征增强;最后使用一个反卷积将特征图上采样两倍,接着使用两个反卷积将特征图映射回图像空间,这样就可以得到原图与无雾图的残差值;其中除了最后一个卷积层以及所设计的共享分离卷积层外,每个卷积后面都跟随一个instance normalization 以及一个ReLU激活函数。
PS:输入的参数除了原始的雾图外,还需要将图像的边缘提取后作为一个辅助信息加到输入信息中;实际操作的时候,可以提前将图片的边缘信息提取出来与原始图在通道上叠加进而送入网络,这样有利于网络学习
具体代码如下:
class GCANet(nn.Module):def __init__(self, in_c=4, out_c=3, only_residual=True):super(GCANet, self).__init__()self.conv1 = nn.Conv2d(in_c, 64, 3, 1, 1, bias=False)self.norm1 = nn.InstanceNorm2d(64, affine=True)self.conv2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)self.norm2 = nn.InstanceNorm2d(64, affine=True)self.conv3 = nn.Conv2d(64, 64, 3, 2, 1, bias=False)self.norm3 = nn.InstanceNorm2d(64, affine=True)self.res1 = SmoothDilatedResidualBlock(64, dilation=2)self.res2 = SmoothDilatedResidualBlock(64, dilation=2)self.res3 = SmoothDilatedResidualBlock(64, dilation=2)self.res4 = SmoothDilatedResidualBlock(64, dilation=4)self.res5 = SmoothDilatedResidualBlock(64, dilation=4)self.res6 = SmoothDilatedResidualBlock(64, dilation=4)self.res7 = ResidualBlock(64, dilation=1)self.gate = nn.Conv2d(64 * 3, 3, 3, 1, 1, bias=True)self.deconv3 = nn.ConvTranspose2d(64, 64, 4, 2, 1)self.norm4 = nn.InstanceNorm2d(64, affine=True)self.deconv2 = nn.Conv2d(64, 64, 3, 1, 1)self.norm5 = nn.InstanceNorm2d(64, affine=True)self.deconv1 = nn.Conv2d(64, out_c, 1)self.only_residual = only_residualdef forward(self, x):y = F.relu(self.norm1(self.conv1(x)))y = F.relu(self.norm2(self.conv2(y)))y1 = F.relu(self.norm3(self.conv3(y)))y = self.res1(y1)y = self.res2(y)y = self.res3(y)y2 = self.res4(y)y = self.res5(y2)y = self.res6(y)y3 = self.res7(y)gates = self.gate(torch.cat((y1, y2, y3), dim=1))gated_y = y1 * gates[:, [0], :, :] + y2 * gates[:, [1], :, :] + y3 * gates[:, [2], :, :]y = F.relu(self.norm4(self.deconv3(gated_y)))y = F.relu(self.norm5(self.deconv2(y)))if self.only_residual:y = self.deconv1(y)else:y = F.relu(self.deconv1(y))return y
4.4 损失函数
损失函数用的MSE Loss,作者提及可以使用其它的损失函数,例如perceptual loss或者GAN loss都可以提升最终的去雾效果,但是即使使用最简单的MSE也能得到SOTA的结果
- 后记
作者的改进重点:
发力在损失函数改进以及视频去雾方面
这篇关于Gated Context Aggregation Network for Image Dehazing and Deraining(GCANet)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!