小波卷积:为计算机视觉任务开辟新的参数效率之路

2024-08-24 16:36

本文主要是介绍小波卷积:为计算机视觉任务开辟新的参数效率之路,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

论文复述

这篇论文介绍了一种创新的卷积神经网络层——WTConv,它通过小波变换技术显著扩展了CNN的感受野,同时保持了参数效率。WTConv层能够实现对输入数据的多频率响应,增强了模型对形状而非纹理的特征识别能力,提高了在图像分类、语义分割和目标检测等视觉任务中的性能和鲁棒性。论文通过广泛的实验验证了WTConv的有效性,并展示了其在不同视觉任务中的应用潜力。

论文地址: https://arxiv.org/abs/2407.05848

摘要

论文指出,近年来尝试通过增加卷积核的大小来模仿视觉变换器(Vision Transformers, ViTs)自注意力模块的全局感受野,但这种方法很快遇到了上限,并且在达到全局感受野之前就饱和了。作者展示了通过利用小波变换(WT),实际上可以不遭受过度参数化的问题,获得非常大的感受野。例如,对于一个k×k的感受野,所提出方法中可训练参数的数量仅以k的对数级增长。提出的层名为WTConv,可以作为现有架构中的替代品,有效响应多频率,并随着感受野大小的增加而优雅地扩展。通过在ConvNeXt和MobileNetV2架构中展示WTConv层的有效性,以及作为下游任务的骨干网络,并展示了它带来的额外属性,如对图像损坏的鲁棒性增加以及对形状而非纹理的响应增加。

引言

引言指出了卷积神经网络(CNN)在计算机视觉领域的统治地位正受到视觉变换器(ViTs)的挑战,特别是由于ViTs的多头自注意力层能够实现全局特征混合。为了缩小CNN和ViTs之间的性能差距,研究人员尝试通过增大卷积核来增加感受野,但这种方法遇到了饱和问题。论文提出了一个问题:是否有可能在不增加过多参数的情况下,利用信号处理工具有效增加卷积的感受野,从而提高性能。

总结

论文成功地利用小波变换(WT)提出了WTConv层,这是一种新的CNN层,能够在不大幅增加参数的情况下显著增加感受野。WTConv层通过在小波域中进行卷积操作,实现了对输入数据的多频率响应,这使得网络能够更好地捕捉低频信息,从而提高了对形状的敏感性,并增强了网络的鲁棒性。实验结果表明,WTConv层在多个视觉任务中都取得了性能提升,证明了其有效性。

全文要点

WTConv

WTConv(Wavelet Transform Convolution)是一种基于小波变换的卷积层,它旨在为卷积神经网络(CNN)提供更大的感受野,同时避免因使用大卷积核而带来的参数数量急剧增加的问题。WTConv是一种创新的卷积神经网络层,它通过小波变换技术实现了对输入数据的深层次和多尺度分析。以下是WTConv的几个关键特点和工作原理的详细概括:

  1. 小波变换的应用:WTConv使用小波变换对输入信号进行分解,这允许网络在不同的频率和空间尺度上捕捉信息。小波变换提供了一种将信号分解为可提供时间和频率信息的组成部分的方法。

  2. 感受野的显著扩展:通过小波变换的多级分解,WTConv能够在保持参数数量相对较低的同时,实现对输入数据更大范围的覆盖。这意味着即使是小的卷积核也能够通过小波变换捕捉到更广泛的上下文信息。

  3. 参数效率与性能提升:WTConv的设计减少了模型参数的数量,与传统的大卷积核相比,它以参数数量的对数级增长实现了感受野的扩展。这种效率的提升使得WTConv在保持计算成本较低的同时,能够提高模型在图像分类、语义分割等任务上的性能。

  4. 多频率特征的独立处理:WTConv允许网络对分解出的不同频率特征进行独立的卷积处理,这增强了模型对信号中不同特征的响应能力,特别是对低频特征的捕捉,这对于理解图像中的形状和结构非常重要。

  5. 小波反变换的集成:在小波域中处理完信号后,WTConv利用小波反变换将处理后的信号重新组合,以生成最终的输出。这一步骤确保了信号的完整性,并允许网络在原始域中进行最终的特征整合。

WTConv通过这些设计,有效地结合了小波变换的多尺度分析能力和卷积神经网络的深度学习能力,为解决计算机视觉中的复杂问题提供了一种新的工具。

wt(Wavelet Transform)

小波变换(Wavelet Transform, WT)是一种数学变换,用于将信号分解成不同时间尺度上的成分,这些成分能够提供信号的时频信息。它广泛应用于信号处理、图像分析、数据压缩和其他许多领域。以下是小波变换的几个关键特点:

  1. 时频联合表示:与仅提供频率信息的傅里叶变换相比,小波变换能够同时提供信号的时间(或空间)和频率信息,使得它在分析非平稳信号时特别有用。

  2. 多尺度分析能力:小波变换通过在不同的尺度上分析信号,能够揭示信号在不同分辨率下的特性。这种多尺度分解使得小波变换能够适应信号的局部变化,捕捉到重要细节。

  3. 正交小波基:在某些小波变换中,如Haar小波变换,变换基是正交的,这允许无失真地从变换后的系数重构原始信号,保证了变换的逆过程的准确性。

  4. 稀疏性优势:小波变换通常能够产生稀疏的系数矩阵,其中许多系数为零或很小,这不仅有助于数据压缩,还可以在信号去噪和特征提取中发挥作用。

  5. 计算效率:小波变换可以通过快速算法实现,如快速小波变换(FWT),它减少了计算量,提高了处理速度。

小波变换的这些特性使其成为分析和处理信号的理想选择,特别是在需要同时考虑时间和频率信息的复杂场景中。

iwt

小波反变换(Inverse Wavelet Transform, IWT)是小波变换的逆过程,它用于从小波变换的系数中重构原始信号。以下是IWT的关键特点和工作原理:

  1. 信号重构:IWT的主要目的是将小波变换产生的系数转换回原始的信号或数据。这是通过使用小波变换时定义的相同小波函数来实现的,但是以相反的顺序。

  2. 逆过程:IWT是小波变换的逆操作,它利用了小波变换的正交性质,特别是当使用正交小波基时,可以确保信号的精确重构。

  3. 多尺度合成:在多级小波分解的情况下,IWT通过逐步合成不同尺度(或分辨率)上的细节信息来重构信号。这包括将低频和高频成分重新组合。

  4. 系数的整合:IWT通过整合小波变换产生的所有系数,包括近似系数(Approximation coefficients)和细节系数(Detail coefficients),来恢复原始数据。

  5. 计算流程:IWT的计算通常涉及从最粗糙的尺度开始,逐步向上细化至更高尺度的过程。每一步都涉及到将当前尺度的系数与小波函数相结合,以及将从更粗糙尺度上恢复的信息逐步添加进来。

  6. 稀疏性利用:如果小波变换产生了稀疏系数,IWT可以利用这一特性来减少计算量,因为许多接近零的系数可以被忽略或近似处理。

  7. 与WT的兼容性:IWT与小波变换紧密兼容,确保了变换和反变换过程的一致性,这对于保持信号的完整性至关重要。

小波反变换是小波分析中不可或缺的一部分,它确保了小波变换的实用性和有效性,特别是在需要从变换后的系数中恢复原始信号的场景中。

pytorch代码实现

源自:https://github.com/BGU-CS-VIL/WTConv

import torch
import torch.nn as nn
import torch.nn.functional as F
import pywt
import pywt.datafrom functools import partialdef create_wavelet_filter(wave, in_size, out_size, type=torch.float):w = pywt.Wavelet(wave)dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type)dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type)dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1),dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1),dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1),dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0)dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1)rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0])rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0])rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1),rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1),rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1),rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0)rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1)return dec_filters, rec_filtersdef wavelet_transform(x, filters):b, c, h, w = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = F.conv2d(x, filters, stride=2, groups=c, padding=pad)x = x.reshape(b, c, 4, h // 2, w // 2)return xdef inverse_wavelet_transform(x, filters):b, c, _, h_half, w_half = x.shapepad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)x = x.reshape(b, c * 4, h_half, w_half)x = F.conv_transpose2d(x, filters, stride=2, groups=c, padding=pad)return xclass _ScaleModule(nn.Module):def __init__(self, dims, init_scale=1.0, init_bias=0):super(_ScaleModule, self).__init__()self.dims = dimsself.weight = nn.Parameter(torch.ones(*dims) * init_scale)self.bias = Nonedef forward(self, x):return torch.mul(self.weight, x)class WTConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1'):super(WTConv2d, self).__init__()assert in_channels == out_channelsself.in_channels = in_channelsself.wt_levels = wt_levelsself.stride = strideself.dilation = 1self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)self.wt_function = partial(wavelet_transform, filters=self.wt_filter)self.iwt_function = partial(inverse_wavelet_transform, filters=self.iwt_filter)self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels, bias=bias)self.base_scale = _ScaleModule([1, in_channels, 1, 1])self.wavelet_convs = nn.ModuleList([nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)])self.wavelet_scale = nn.ModuleList([_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)])if self.stride > 1:self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,groups=in_channels)else:self.do_stride = Nonedef forward(self, x):x_ll_in_levels = []x_h_in_levels = []shapes_in_levels = []curr_x_ll = xfor i in range(self.wt_levels):curr_shape = curr_x_ll.shapeshapes_in_levels.append(curr_shape)if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)curr_x_ll = F.pad(curr_x_ll, curr_pads)curr_x = self.wt_function(curr_x_ll)curr_x_ll = curr_x[:, :, 0, :, :]shape_x = curr_x.shapecurr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))curr_x_tag = curr_x_tag.reshape(shape_x)x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])next_x_ll = 0for i in range(self.wt_levels - 1, -1, -1):curr_x_ll = x_ll_in_levels.pop()curr_x_h = x_h_in_levels.pop()curr_shape = shapes_in_levels.pop()curr_x_ll = curr_x_ll + next_x_llcurr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)next_x_ll = self.iwt_function(curr_x)next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]x_tag = next_x_llassert len(x_ll_in_levels) == 0x = self.base_scale(self.base_conv(x))x = x + x_tagif self.do_stride is not None:x = self.do_stride(x)return xx = torch.randn((4, 64, 128, 128))
model = WTConv2d(in_channels=64, out_channels=64)
out = model(x)
print(out.shape)

这篇关于小波卷积:为计算机视觉任务开辟新的参数效率之路的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1103016

相关文章

Andrej Karpathy最新采访:认知核心模型10亿参数就够了,AI会打破教育不公的僵局

夕小瑶科技说 原创  作者 | 海野 AI圈子的红人,AI大神Andrej Karpathy,曾是OpenAI联合创始人之一,特斯拉AI总监。上一次的动态是官宣创办一家名为 Eureka Labs 的人工智能+教育公司 ,宣布将长期致力于AI原生教育。 近日,Andrej Karpathy接受了No Priors(投资博客)的采访,与硅谷知名投资人 Sara Guo 和 Elad G

C++11第三弹:lambda表达式 | 新的类功能 | 模板的可变参数

🌈个人主页: 南桥几晴秋 🌈C++专栏: 南桥谈C++ 🌈C语言专栏: C语言学习系列 🌈Linux学习专栏: 南桥谈Linux 🌈数据结构学习专栏: 数据结构杂谈 🌈数据库学习专栏: 南桥谈MySQL 🌈Qt学习专栏: 南桥谈Qt 🌈菜鸡代码练习: 练习随想记录 🌈git学习: 南桥谈Git 🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈🌈�

如何在页面调用utility bar并传递参数至lwc组件

1.在app的utility item中添加lwc组件: 2.调用utility bar api的方式有两种: 方法一,通过lwc调用: import {LightningElement,api ,wire } from 'lwc';import { publish, MessageContext } from 'lightning/messageService';import Ca

4B参数秒杀GPT-3.5:MiniCPM 3.0惊艳登场!

​ 面壁智能 在 AI 的世界里,总有那么几个时刻让人惊叹不已。面壁智能推出的 MiniCPM 3.0,这个仅有4B参数的"小钢炮",正在以惊人的实力挑战着 GPT-3.5 这个曾经的AI巨人。 MiniCPM 3.0 MiniCPM 3.0 MiniCPM 3.0 目前的主要功能有: 长上下文功能:原生支持 32k 上下文长度,性能完美。我们引入了

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点赞 👍 收藏 ⭐评论 📝 🍅 文末获取源码联系 👇🏻 精彩专栏推荐订阅 👇🏻 不然下次找不到哟~Java毕业设计项目~热门选题推荐《1000套》 目录 1.技术选型 2.开发工具 3.功能

AI(文生语音)-TTS 技术线路探索学习:从拼接式参数化方法到Tacotron端到端输出

AI(文生语音)-TTS 技术线路探索学习:从拼接式参数化方法到Tacotron端到端输出 在数字化时代,文本到语音(Text-to-Speech, TTS)技术已成为人机交互的关键桥梁,无论是为视障人士提供辅助阅读,还是为智能助手注入声音的灵魂,TTS 技术都扮演着至关重要的角色。从最初的拼接式方法到参数化技术,再到现今的深度学习解决方案,TTS 技术经历了一段长足的进步。这篇文章将带您穿越时

如何确定 Go 语言中 HTTP 连接池的最佳参数?

确定 Go 语言中 HTTP 连接池的最佳参数可以通过以下几种方式: 一、分析应用场景和需求 并发请求量: 确定应用程序在特定时间段内可能同时发起的 HTTP 请求数量。如果并发请求量很高,需要设置较大的连接池参数以满足需求。例如,对于一个高并发的 Web 服务,可能同时有数百个请求在处理,此时需要较大的连接池大小。可以通过压力测试工具模拟高并发场景,观察系统在不同并发请求下的性能表现,从而

多路转接之select(fd_set介绍,参数详细介绍),实现非阻塞式网络通信

目录 多路转接之select 引入 介绍 fd_set 函数原型 nfds readfds / writefds / exceptfds readfds  总结  fd_set操作接口  timeout timevalue 结构体 传入值 返回值 代码 注意点 -- 调用函数 select的参数填充  获取新连接 注意点 -- 通信时的调用函数 添加新fd到

计算机视觉工程师所需的基本技能

一、编程技能 熟练掌握编程语言 Python:在计算机视觉领域广泛应用,有丰富的库如 OpenCV、TensorFlow、PyTorch 等,方便进行算法实现和模型开发。 C++:运行效率高,适用于对性能要求严格的计算机视觉应用。 数据结构与算法 掌握常见的数据结构(如数组、链表、栈、队列、树、图等)和算法(如排序、搜索、动态规划等),能够优化代码性能,提高算法效率。 二、数学基础

FreeRTOS学习笔记(二)任务基础篇

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、 任务的基本内容1.1 任务的基本特点1.2 任务的状态1.3 任务控制块——任务的“身份证” 二、 任务的实现2.1 定义任务函数2.2 创建任务2.3 启动任务调度器2.4 任务的运行与切换2.4.1 利用延时函数2.4.2 利用中断 2.5 任务的通信与同步2.6 任务的删除2.7 任务的通知2