transfomer中Decoder和Encoder的base_layer的源码实现

2024-01-16 11:28

本文主要是介绍transfomer中Decoder和Encoder的base_layer的源码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

简介

Encoder和Decoder共同组成transfomer,分别对应图中左右浅绿色框内的部分.
在这里插入图片描述
Encoder:
目的:将输入的特征图转换为一系列自注意力的输出。
工作原理:首先,通过卷积神经网络(CNN)提取输入图像的特征。然后,这些特征通过一系列自注意力的变换层进行处理,每个变换层都会将特征映射进行编码并产生一个新的特征映射。这个过程旨在捕捉图像中的空间和通道依赖关系。
作用:通过处理输入特征,提取图像特征并进行自注意力操作,为后续的目标检测任务提供必要的特征信息。
Decoder:
目的:接受Encoder的输出,并生成对目标类别和边界框的预测。
工作原理:首先,它接收Encoder的输出,然后使用一系列解码器层对目标对象之间的关系和全局图像上下文进行推理。这些解码器层将最终的目标类别和边界框的预测作为输出。
作用:基于Encoder的输出和全局上下文信息,生成目标类别和边界框的预测结果。
总结:Encoder就是特征提取类似卷积;Decoder用于生成box,类似head

源码实现:

Encoder 通常是6个encoder_layer组成,Decoder 通常是6个decoder_layer组成
我实现了核心的BaseTransformerLayer层,可以用来定义encoder_layer和decoder_layer

具体源码及其注释如下,配好环境可直接运行(运行依赖于上一个博客的代码):

import torch
from torch import nn
from ZMultiheadAttention import MultiheadAttention  # 来自上一次写的attensionclass FFN(nn.Module):def __init__(self,embed_dim=256,feedforward_channels=1024,act_cfg='ReLU',ffn_drop=0.,):super(FFN, self).__init__()self.l1 = nn.Linear(in_features=embed_dim, out_features=feedforward_channels)if act_cfg == 'ReLU':self.act1 = nn.ReLU(inplace=True)else:self.act1 = nn.SiLU(inplace=True)self.d1 = nn.Dropout(p=ffn_drop)self.l2 = nn.Linear(in_features=feedforward_channels, out_features=embed_dim)self.d2 = nn.Dropout(p=ffn_drop)def forward(self, x):tmp = self.d1(self.act1(self.l1(x)))tmp = self.d2(self.l2(tmp))x = tmp + xreturn x# transfomer encode和decode的最小循环单元,用于打包self_attention或者cross_attention
class BaseTransformerLayer(nn.Module):def __init__(self,attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=128, act_cfg='ReLU', ffn_drop=0.),operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')):super(BaseTransformerLayer, self).__init__()self.attentions = nn.ModuleList()# 搭建att层for attn_cfg in attn_cfgs:self.attentions.append(MultiheadAttention(**attn_cfg))self.embed_dims = self.attentions[0].embed_dim# 统计norm数量 并搭建self.norms = nn.ModuleList()num_norms = operation_order.count('norm')for _ in range(num_norms):self.norms.append(nn.LayerNorm(normalized_shape=self.embed_dims))# 统计ffn数量 并搭建self.ffns = nn.ModuleList()self.ffns.append(FFN(**fnn_cfg))self.operation_order = operation_orderdef forward(self, query, key=None, value=None, query_pos=None, key_pos=None):attn_index = 0norm_index = 0ffn_index = 0for order in self.operation_order:if order == 'self_attn':temp_key = temp_value = query  # 不用担心三个值一样,在attention里面会重映射qkvquery, attention = self.attentions[attn_index](query,temp_key,temp_value,query_pos=query_pos,key_pos=query_pos)attn_index += 1elif order == 'cross_attn':query, attention = self.attentions[attn_index](query,key,value,query_pos=query_pos,key_pos=key_pos)attn_index += 1elif order == 'norm':query = self.norms[norm_index](query)norm_index += 1elif order == 'ffn':query = self.ffns[ffn_index](query)ffn_index += 1return queryif __name__ == '__main__':query = torch.rand(size=(10, 2, 64))key = torch.rand(size=(5, 2, 64))value = torch.rand(size=(5, 2, 64))query_pos = torch.rand(size=(10, 2, 64))key_pos = torch.rand(size=(5, 2, 64))# encoder 通常是6个encoder_layer组成 每个encoder_layer['self_attn', 'norm', 'ffn', 'norm']encoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'ffn', 'norm'))encoder_layer_output = encoder_layer(query=query, query_pos=query_pos, key_pos=key_pos)# decoder 通常是6个decoder_layer组成 每个decoder_layer['self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm']decoder_layer = BaseTransformerLayer(attn_cfgs=[dict(embed_dim=64, num_heads=4), dict(embed_dim=64, num_heads=4)],fnn_cfg=dict(embed_dim=64, feedforward_channels=1024, act_cfg='ReLU',ffn_drop=0.),operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))decoder_layer_output = decoder_layer(query=query, key=key, value=value, query_pos=query_pos, key_pos=key_pos)pass

具体流程说明:

Encoder 通常是6个encoder_layer组成,每个encoder_layer[‘self_attn’, ‘norm’, ‘ffn’, ‘norm’]
Decoder 通常是6个decoder_layer组成,每个decoder_layer[‘self_attn’, ‘norm’, ‘cross_attn’, ‘norm’, ‘ffn’, ‘norm’]
按照以上方式搭建网络即可
其中norm为LayerNorm,在样本内部进行归一化。

这篇关于transfomer中Decoder和Encoder的base_layer的源码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/612428

相关文章

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

JAVA智听未来一站式有声阅读平台听书系统小程序源码

智听未来,一站式有声阅读平台听书系统 🌟&nbsp;开篇:遇见未来,从“智听”开始 在这个快节奏的时代,你是否渴望在忙碌的间隙,找到一片属于自己的宁静角落?是否梦想着能随时随地,沉浸在知识的海洋,或是故事的奇幻世界里?今天,就让我带你一起探索“智听未来”——这一站式有声阅读平台听书系统,它正悄悄改变着我们的阅读方式,让未来触手可及! 📚&nbsp;第一站:海量资源,应有尽有 走进“智听

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略

Kubernetes PodSecurityPolicy:PSP能实现的5种主要安全策略 1. 特权模式限制2. 宿主机资源隔离3. 用户和组管理4. 权限提升控制5. SELinux配置 💖The Begin💖点点关注,收藏不迷路💖 Kubernetes的PodSecurityPolicy(PSP)是一个关键的安全特性,它在Pod创建之前实施安全策略,确保P

Java ArrayList扩容机制 (源码解读)

结论:初始长度为10,若所需长度小于1.5倍原长度,则按照1.5倍扩容。若不够用则按照所需长度扩容。 一. 明确类内部重要变量含义         1:数组默认长度         2:这是一个共享的空数组实例,用于明确创建长度为0时的ArrayList ,比如通过 new ArrayList<>(0),ArrayList 内部的数组 elementData 会指向这个 EMPTY_EL

如何在Visual Studio中调试.NET源码

今天偶然在看别人代码时,发现在他的代码里使用了Any判断List<T>是否为空。 我一般的做法是先判断是否为null,再判断Count。 看了一下Count的源码如下: 1 [__DynamicallyInvokable]2 public int Count3 {4 [__DynamicallyInvokable]5 get