小波卷积:为计算机视觉任务开辟新的参数效率之路

2024-08-24 16:36

本文主要是介绍小波卷积:为计算机视觉任务开辟新的参数效率之路,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

论文复述

这篇论文介绍了一种创新的卷积神经网络层——WTConv,它通过小波变换技术显著扩展了CNN的感受野,同时保持了参数效率。WTConv层能够实现对输入数据的多频率响应,增强了模型对形状而非纹理的特征识别能力,提高了在图像分类、语义分割和目标检测等视觉任务中的性能和鲁棒性。论文通过广泛的实验验证了WTConv的有效性,并展示了其在不同视觉任务中的应用潜力。

论文地址: https://arxiv.org/abs/2407.05848

摘要

论文指出,近年来尝试通过增加卷积核的大小来模仿视觉变换器(Vision Transformers, ViTs)自注意力模块的全局感受野,但这种方法很快遇到了上限,并且在达到全局感受野之前就饱和了。作者展示了通过利用小波变换(WT),实际上可以不遭受过度参数化的问题,获得非常大的感受野。例如,对于一个k×k的感受野,所提出方法中可训练参数的数量仅以k的对数级增长。提出的层名为WTConv,可以作为现有架构中的替代品,有效响应多频率,并随着感受野大小的增加而优雅地扩展。通过在ConvNeXt和MobileNetV2架构中展示WTConv层的有效性,以及作为下游任务的骨干网络,并展示了它带来的额外属性,如对图像损坏的鲁棒性增加以及对形状而非纹理的响应增加。

引言

引言指出了卷积神经网络(CNN)在计算机视觉领域的统治地位正受到视觉变换器(ViTs)的挑战,特别是由于ViTs的多头自注意力层能够实现全局特征混合。为了缩小CNN和ViTs之间的性能差距,研究人员尝试通过增大卷积核来增加感受野,但这种方法遇到了饱和问题。论文提出了一个问题:是否有可能在不增加过多参数的情况下,利用信号处理工具有效增加卷积的感受野,从而提高性能。

总结

论文成功地利用小波变换(WT)提出了WTConv层,这是一种新的CNN层,能够在不大幅增加参数的情况下显著增加感受野。WTConv层通过在小波域中进行卷积操作,实现了对输入数据的多频率响应,这使得网络能够更好地捕捉低频信息,从而提高了对形状的敏感性,并增强了网络的鲁棒性。实验结果表明,WTConv层在多个视觉任务中都取得了性能提升,证明了其有效性。

全文要点

WTConv

WTConv(Wavelet Transform Convolution)是一种基于小波变换的卷积层,它旨在为卷积神经网络(CNN)提供更大的感受野,同时避免因使用大卷积核而带来的参数数量急剧增加的问题。WTConv是一种创新的卷积神经网络层,它通过小波变换技术实现了对输入数据的深层次和多尺度分析。以下是WTConv的几个关键特点和工作原理的详细概括:

  1. 小波变换的应用:WTConv使用小波变换对输入信号进行分解,这允许网络在不同的频率和空间尺度上捕捉信息。小波变换提供了一种将信号分解为可提供时间和频率信息的组成部分的方法。

  2. 感受野的显著扩展:通过小波变换的多级分解,WTConv能够在保持参数数量相对较低的同时,实现对输入数据更大范围的覆盖。这意味着即使是小的卷积核也能够通过小波变换捕捉到更广泛的上下文信息。

  3. 参数效率与性能提升:WTConv的设计减少了模型参数的数量,与传统的大卷积核相比,它以参数数量的对数级增长实现了感受野的扩展。这种效率的提升使得WTConv在保持计算成本较低的同时,能够提高模型在图像分类、语义分割等任务上的性能。

  4. 多频率特征的独立处理:WTConv允许网络对分解出的不同频率特征进行独立的卷积处理,这增强了模型对信号中不同特征的响应能力,特别是对低频特征的捕捉,这对于理解图像中的形状和结构非常重要。

  5. 小波反变换的集成:在小波域中处理完信号后,WTConv利用小波反变换将处理后的信号重新组合,以生成最终的输出。这一步骤确保了信号的完整性,并允许网络在原始域中进行最终的特征整合。

WTConv通过这些设计,有效地结合了小波变换的多尺度分析能力和卷积神经网络的深度学习能力,为解决计算机视觉中的复杂问题提供了一种新的工具。

wt(Wavelet Transform)

小波变换(Wavelet Transform, WT)是一种数学变换,用于将信号分解成不同时间尺度上的成分,这些成分能够提供信号的时频信息。它广泛应用于信号处理、图像分析、数据压缩和其他许多领域。以下是小波变换的几个关键特点:

  1. 时频联合表示:与仅提供频率信息的傅里叶变换相比,小波变换能够同时提供信号的时间(或空间)和频率信息,使得它在分析非平稳信号时特别有用。

  2. 多尺度分析能力:小波变换通过在不同的尺度上分析信号,能够揭示信号在不同分辨率下的特性。这种多尺度分解使得小波变换能够适应信号的局部变化,捕捉到重要细节。

  3. 正交小波基:在某些小波变换中,如Haar小波变换,变换基是正交的,这允许无失真地从变换后的系数重构原始信号,保证了变换的逆过程的准确性。

  4. 稀疏性优势:小波变换通常能够产生稀疏的系数矩阵,其中许多系数为零或很小,这不仅有助于数据压缩,还可以在信号去噪和特征提取中发挥作用。

  5. 计算效率:小波变换可以通过快速算法实现,如快速小波变换(FWT),它减少了计算量,提高了处理速度。

小波变换的这些特性使其成为分析和处理信号的理想选择,特别是在需要同时考虑时间和频率信息的复杂场景中。

iwt

小波反变换(Inverse Wavelet Transform, IWT)是小波变换的逆过程,它用于从小波变换的系数中重构原始信号。以下是IWT的关键特点和工作原理:

  1. 信号重构:IWT的主要目的是将小波变换产生的系数转换回原始的信号或数据。这是通过使用小波变换时定义的相同小波函数来实现的,但是以相反的顺序。

  2. 逆过程:IWT是小波变换的逆操作,它利用了小波变换的正交性质,特别是当使用正交小波基时,可以确保信号的精确重构。

  3. 多尺度合成:在多级小波分解的情况下,IWT通过逐步合成不同尺度(或分辨率)上的细节信息来重构信号。这包括将低频和高频成分重新组合。

  4. 系数的整合:IWT通过整合小波变换产生的所有系数,包括近似系数(Approximation coefficients)和细节系数(Detail coefficients),来恢复原始数据。

  5. 计算流程:IWT的计算通常涉及从最粗糙的尺度开始,逐步向上细化至更高尺度的过程。每一步都涉及到将当前尺度的系数与小波函数相结合,以及将从更粗糙尺度上恢复的信息逐步添加进来。

  6. 稀疏性利用:如果小波变换产生了稀疏系数,IWT可以利用这一特性来减少计算量,因为许多接近零的系数可以被忽略或近似处理。

  7. 与WT的兼容性:IWT与小波变换紧密兼容,确保了变换和反变换过程的一致性,这对于保持信号的完整性至关重要。

小波反变换是小波分析中不可或缺的一部分,它确保了小波变换的实用性和有效性,特别是在需要从变换后的系数中恢复原始信号的场景中。

pytorch代码实现

源自:https://github.com/BGU-CS-VIL/WTConv

import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import pywt.datafrom functools import partialdef create_wavelet_filter(wave, in_size, out_size, type=torch.float):w = pywt.Wavelet(wave)dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)return dec_filters, rec_filtersdef wavelet_transform(x, filters):b, c, h, w = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)x = x.reshape(b, c, 4, h // 2, w // 2)return xdef inverse_wavelet_transform(x, filters):b, c, _, h_half, w_half = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = x.reshape(b, c * 4, h_half, w_half)x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)return xclass _ScaleModule(nn.Module):def __init__(self, dims, init_scale=1.0, init_bias=0):super(_ScaleModule, self).__init__()self.dims = dimsself.weight = nn.Parameter(torch.ones(*dims) * init_scale)self.bias = Nonedef forward(self, x):return torch.mul(self.weight, x)class WTConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):super(WTConv2d, self).__init__()assert in_channels == out_channelsself.in_channels = in_channelsself.wt_levels = wt_levelsself.stride = strideself.dilation = 1self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)self.wt_function = partial(wavelet_transform, filters=self.wt_filter)self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels, bias=bias)self.base_scale = _ScaleModule([1, in_channels, 1, 1])self.wavelet_convs = nn.ModuleList([nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)])self.wavelet_scale = nn.ModuleList([_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)])if self.stride > 1:self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,groups=in_channels)else:self.do_stride = Nonedef forward(self, x):x_ll_in_levels = []x_h_in_levels = []shapes_in_levels = []curr_x_ll = xfor i in range(self.wt_levels):curr_shape = curr_x_ll.shapeshapes_in_levels.append(curr_shape)if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)curr_x_ll = F.pad(curr_x_ll, curr_pads)curr_x = self.wt_function(curr_x_ll)curr_x_ll = curr_x[:, :, 0, :, :]shape_x = curr_x.shapecurr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))curr_x_tag = curr_x_tag.reshape(shape_x)x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])next_x_ll = 0for i in range(self.wt_levels - 1, -1, -1):curr_x_ll = x_ll_in_levels.pop()curr_x_h = x_h_in_levels.pop()curr_shape = shapes_in_levels.pop()curr_x_ll = curr_x_ll + next_x_llcurr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)next_x_ll = self.iwt_function(curr_x)next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]x_tag = next_x_llassert len(x_ll_in_levels) == 0x = self.base_scale(self.base_conv(x))x = x + x_tagif self.do_stride is not None:x = self.do_stride(x)return xx = torch.randn((4, 64, 128, 128))
model = WTConv2d(in_channels=64, out_channels=64)
out = model(x)
print(out.shape)

这篇关于小波卷积:为计算机视觉任务开辟新的参数效率之路的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1103016

相关文章

MySQL中时区参数time_zone解读

《MySQL中时区参数time_zone解读》MySQL时区参数time_zone用于控制系统函数和字段的DEFAULTCURRENT_TIMESTAMP属性,修改时区可能会影响timestamp类型... 目录前言1.时区参数影响2.如何设置3.字段类型选择总结前言mysql 时区参数 time_zon

C#使用yield关键字实现提升迭代性能与效率

《C#使用yield关键字实现提升迭代性能与效率》yield关键字在C#中简化了数据迭代的方式,实现了按需生成数据,自动维护迭代状态,本文主要来聊聊如何使用yield关键字实现提升迭代性能与效率,感兴... 目录前言传统迭代和yield迭代方式对比yield延迟加载按需获取数据yield break显式示迭

Python如何使用seleniumwire接管Chrome查看控制台中参数

《Python如何使用seleniumwire接管Chrome查看控制台中参数》文章介绍了如何使用Python的seleniumwire库来接管Chrome浏览器,并通过控制台查看接口参数,本文给大家... 1、cmd打开控制台,启动谷歌并制定端口号,找不到文件的加环境变量chrome.exe --rem

Python Invoke自动化任务库的使用

《PythonInvoke自动化任务库的使用》Invoke是一个强大的Python库,用于编写自动化脚本,本文就来介绍一下PythonInvoke自动化任务库的使用,具有一定的参考价值,感兴趣的可以... 目录什么是 Invoke?如何安装 Invoke?Invoke 基础1. 运行测试2. 构建文档3.

解决Cron定时任务中Pytest脚本无法发送邮件的问题

《解决Cron定时任务中Pytest脚本无法发送邮件的问题》文章探讨解决在Cron定时任务中运行Pytest脚本时邮件发送失败的问题,先优化环境变量,再检查Pytest邮件配置,接着配置文件确保SMT... 目录引言1. 环境变量优化:确保Cron任务可以正确执行解决方案:1.1. 创建一个脚本1.2. 修

Linux中Curl参数详解实践应用

《Linux中Curl参数详解实践应用》在现代网络开发和运维工作中,curl命令是一个不可或缺的工具,它是一个利用URL语法在命令行下工作的文件传输工具,支持多种协议,如HTTP、HTTPS、FTP等... 目录引言一、基础请求参数1. -X 或 --request2. -d 或 --data3. -H 或

Java实现任务管理器性能网络监控数据的方法详解

《Java实现任务管理器性能网络监控数据的方法详解》在现代操作系统中,任务管理器是一个非常重要的工具,用于监控和管理计算机的运行状态,包括CPU使用率、内存占用等,对于开发者和系统管理员来说,了解这些... 目录引言一、背景知识二、准备工作1. Maven依赖2. Gradle依赖三、代码实现四、代码详解五

如何使用celery进行异步处理和定时任务(django)

《如何使用celery进行异步处理和定时任务(django)》文章介绍了Celery的基本概念、安装方法、如何使用Celery进行异步任务处理以及如何设置定时任务,通过Celery,可以在Web应用中... 目录一、celery的作用二、安装celery三、使用celery 异步执行任务四、使用celery

什么是cron? Linux系统下Cron定时任务使用指南

《什么是cron?Linux系统下Cron定时任务使用指南》在日常的Linux系统管理和维护中,定时执行任务是非常常见的需求,你可能需要每天执行备份任务、清理系统日志或运行特定的脚本,而不想每天... 在管理 linux 服务器的过程中,总有一些任务需要我们定期或重复执行。就比如备份任务,通常会选在服务器资

如何测试计算机的内存是否存在问题? 判断电脑内存故障的多种方法

《如何测试计算机的内存是否存在问题?判断电脑内存故障的多种方法》内存是电脑中非常重要的组件之一,如果内存出现故障,可能会导致电脑出现各种问题,如蓝屏、死机、程序崩溃等,如何判断内存是否出现故障呢?下... 如果你的电脑是崩溃、冻结还是不稳定,那么它的内存可能有问题。要进行检查,你可以使用Windows 11