pytorch 实现 Restormer 主要模块(多头通道自注意力机制和门控制结构)

本文主要是介绍pytorch 实现 Restormer 主要模块(多头通道自注意力机制和门控制结构),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

        前面的博文读论文:Restormer: Efficient Transformer for High-Resolution Image Restoration 介绍了 Restormer 网络结构的网络技术特点,本文用 pytorch 实现其中的主要网络结构模块。

1. MDTA(Multi-Dconv Head Transposed Attention:多头注意力机制

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):def __init__(self, dim, num_heads, bias):super(Attention, self).__init__()self.num_heads = num_heads  # 注意力头的个数self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))  # 可学习系数# 1*1 升维self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)# 3*3 分组卷积self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)# 1*1 卷积self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)def forward(self, x):b,c,h,w = x.shape  # 输入的结构 batch 数,通道数和高宽qkv = self.qkv_dwconv(self.qkv(x))q,k,v = qkv.chunk(3, dim=1)  #  第 1 个维度方向切分成 3 块# 改变 q, k, v 的结构为 b head c (h w),将每个二维 plane 展平q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)q = torch.nn.functional.normalize(q, dim=-1)  # C 维度标准化,这里的 C 与通道维度略有不同k = torch.nn.functional.normalize(k, dim=-1)attn = (q @ k.transpose(-2, -1)) * self.temperature # @ 是矩阵乘attn = attn.softmax(dim=-1)out = (attn @ v)  # 注意力图(严格来说不算图)# 将展平后的注意力图恢复out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)# 真正的注意力图out = self.project_out(out)return out

2. GDFN( Gated-Dconv Feed-Forward Network) 

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):def __init__(self, dim, ffn_expansion_factor, bias):super(FeedForward, self).__init__()# 隐藏层特征维度等于输入维度乘以扩张因子hidden_features = int(dim*ffn_expansion_factor)# 1*1 升维self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)# 3*3 分组卷积self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)# 1*1 降维self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)def forward(self, x):x = self.project_in(x)x1, x2 = self.dwconv(x).chunk(2, dim=1)  # 第 1 个维度方向切分成 2 块x = F.gelu(x1) * x2  # gelu 相当于 relu+dropoutx = self.project_out(x)return x

3. TransformerBlock

## 就是标准的 Transformer 架构
class TransformerBlock(nn.Module):def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):super(TransformerBlock, self).__init__()self.norm1 = LayerNorm(dim, LayerNorm_type)  # 层标准化self.attn = Attention(dim, num_heads, bias)  # 自注意力self.norm2 = LayerNorm(dim, LayerNorm_type)  # 层表转化self.ffn = FeedForward(dim, ffn_expansion_factor, bias)  # FFNdef forward(self, x):x = x + self.attn(self.norm1(x))  # 残差x = x + self.ffn(self.norm2(x))  # 残差return x

4. 测试样例

model = Restormer()
print(model)  # 打印网络结构x = torch.randn((1, 3, 512, 512))  #随机生成输入图像
x = model(x)  # 送入网络
print(x.shape) # 打印网络输入的图像结构

这篇关于pytorch 实现 Restormer 主要模块(多头通道自注意力机制和门控制结构)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/543179

相关文章

JVM 的类初始化机制

前言 当你在 Java 程序中new对象时,有没有考虑过 JVM 是如何把静态的字节码(byte code)转化为运行时对象的呢,这个问题看似简单,但清楚的同学相信也不会太多,这篇文章首先介绍 JVM 类初始化的机制,然后给出几个易出错的实例来分析,帮助大家更好理解这个知识点。 JVM 将字节码转化为运行时对象分为三个阶段,分别是:loading 、Linking、initialization

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过滤推荐功能2.基于用户的协同过滤推荐功能 前言     在信息过载的时代,推荐系统成为连接用户与内容的桥梁。本文聚焦于

便携式气象仪器的主要特点

TH-BQX9】便携式气象仪器,也称为便携式气象仪或便携式自动气象站,是一款高度集成、低功耗、可快速安装、便于野外监测使用的高精度自动气象观测设备。以下是关于便携式气象仪器的详细介绍:   主要特点   高精度与多功能:便携式气象仪器能够采集多种气象参数,包括但不限于风速、风向、温度、湿度、气压等,部分高级型号还能监测雨量和辐射等。数据采集与存储:配备微电脑气象数据采集仪,具有实时时钟、数据存

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount