pytorch 实现 Restormer 主要模块(多头通道自注意力机制和门控制结构)

本文主要是介绍pytorch 实现 Restormer 主要模块(多头通道自注意力机制和门控制结构),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

        前面的博文读论文:Restormer: Efficient Transformer for High-Resolution Image Restoration 介绍了 Restormer 网络结构的网络技术特点,本文用 pytorch 实现其中的主要网络结构模块。

1. MDTA(Multi-Dconv Head Transposed Attention:多头注意力机制

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):def __init__(self, dim, num_heads, bias):super(Attention, self).__init__()self.num_heads = num_heads  # 注意力头的个数self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))  # 可学习系数# 1*1 升维self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)# 3*3 分组卷积self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)# 1*1 卷积self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)def forward(self, x):b,c,h,w = x.shape  # 输入的结构 batch 数,通道数和高宽qkv = self.qkv_dwconv(self.qkv(x))q,k,v = qkv.chunk(3, dim=1)  #  第 1 个维度方向切分成 3 块# 改变 q, k, v 的结构为 b head c (h w),将每个二维 plane 展平q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)q = torch.nn.functional.normalize(q, dim=-1)  # C 维度标准化,这里的 C 与通道维度略有不同k = torch.nn.functional.normalize(k, dim=-1)attn = (q @ k.transpose(-2, -1)) * self.temperature # @ 是矩阵乘attn = attn.softmax(dim=-1)out = (attn @ v)  # 注意力图(严格来说不算图)# 将展平后的注意力图恢复out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)# 真正的注意力图out = self.project_out(out)return out

2. GDFN( Gated-Dconv Feed-Forward Network) 

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):def __init__(self, dim, ffn_expansion_factor, bias):super(FeedForward, self).__init__()# 隐藏层特征维度等于输入维度乘以扩张因子hidden_features = int(dim*ffn_expansion_factor)# 1*1 升维self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)# 3*3 分组卷积self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)# 1*1 降维self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)def forward(self, x):x = self.project_in(x)x1, x2 = self.dwconv(x).chunk(2, dim=1)  # 第 1 个维度方向切分成 2 块x = F.gelu(x1) * x2  # gelu 相当于 relu+dropoutx = self.project_out(x)return x

3. TransformerBlock

## 就是标准的 Transformer 架构
class TransformerBlock(nn.Module):def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):super(TransformerBlock, self).__init__()self.norm1 = LayerNorm(dim, LayerNorm_type)  # 层标准化self.attn = Attention(dim, num_heads, bias)  # 自注意力self.norm2 = LayerNorm(dim, LayerNorm_type)  # 层表转化self.ffn = FeedForward(dim, ffn_expansion_factor, bias)  # FFNdef forward(self, x):x = x + self.attn(self.norm1(x))  # 残差x = x + self.ffn(self.norm2(x))  # 残差return x

4. 测试样例

model = Restormer()
print(model)  # 打印网络结构x = torch.randn((1, 3, 512, 512))  #随机生成输入图像
x = model(x)  # 送入网络
print(x.shape) # 打印网络输入的图像结构

这篇关于pytorch 实现 Restormer 主要模块(多头通道自注意力机制和门控制结构)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/543179

相关文章

C#实现将Excel表格转换为图片(JPG/ PNG)

《C#实现将Excel表格转换为图片(JPG/PNG)》Excel表格可能会因为不同设备或字体缺失等问题,导致格式错乱或数据显示异常,转换为图片后,能确保数据的排版等保持一致,下面我们看看如何使用C... 目录通过C# 转换Excel工作表到图片通过C# 转换指定单元格区域到图片知识扩展C# 将 Excel

基于Java实现回调监听工具类

《基于Java实现回调监听工具类》这篇文章主要为大家详细介绍了如何基于Java实现一个回调监听工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录监听接口类 Listenable实际用法打印结果首先,会用到 函数式接口 Consumer, 通过这个可以解耦回调方法,下面先写一个

使用Java将DOCX文档解析为Markdown文档的代码实现

《使用Java将DOCX文档解析为Markdown文档的代码实现》在现代文档处理中,Markdown(MD)因其简洁的语法和良好的可读性,逐渐成为开发者、技术写作者和内容创作者的首选格式,然而,许多文... 目录引言1. 工具和库介绍2. 安装依赖库3. 使用Apache POI解析DOCX文档4. 将解析

Qt中QGroupBox控件的实现

《Qt中QGroupBox控件的实现》QGroupBox是Qt框架中一个非常有用的控件,它主要用于组织和管理一组相关的控件,本文主要介绍了Qt中QGroupBox控件的实现,具有一定的参考价值,感兴趣... 目录引言一、基本属性二、常用方法2.1 构造函数 2.2 设置标题2.3 设置复选框模式2.4 是否

C++使用printf语句实现进制转换的示例代码

《C++使用printf语句实现进制转换的示例代码》在C语言中,printf函数可以直接实现部分进制转换功能,通过格式说明符(formatspecifier)快速输出不同进制的数值,下面给大家分享C+... 目录一、printf 原生支持的进制转换1. 十进制、八进制、十六进制转换2. 显示进制前缀3. 指

springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法

《springboot整合阿里云百炼DeepSeek实现sse流式打印的操作方法》:本文主要介绍springboot整合阿里云百炼DeepSeek实现sse流式打印,本文给大家介绍的非常详细,对大... 目录1.开通阿里云百炼,获取到key2.新建SpringBoot项目3.工具类4.启动类5.测试类6.测

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

SpringBoot集成Milvus实现数据增删改查功能

《SpringBoot集成Milvus实现数据增删改查功能》milvus支持的语言比较多,支持python,Java,Go,node等开发语言,本文主要介绍如何使用Java语言,采用springboo... 目录1、Milvus基本概念2、添加maven依赖3、配置yml文件4、创建MilvusClient

python logging模块详解及其日志定时清理方式

《pythonlogging模块详解及其日志定时清理方式》:本文主要介绍pythonlogging模块详解及其日志定时清理方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录python logging模块及日志定时清理1.创建logger对象2.logging.basicCo

JS+HTML实现在线图片水印添加工具

《JS+HTML实现在线图片水印添加工具》在社交媒体和内容创作日益频繁的今天,如何保护原创内容、展示品牌身份成了一个不得不面对的问题,本文将实现一个完全基于HTML+CSS构建的现代化图片水印在线工具... 目录概述功能亮点使用方法技术解析延伸思考运行效果项目源码下载总结概述在社交媒体和内容创作日益频繁的