DEiT中如何处理mask数据的?与MAE的不同

2024-03-18 05:20
文章标签 mask 数据 处理 不同 deit mae

本文主要是介绍DEiT中如何处理mask数据的?与MAE的不同,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

在DeiT里面,是通过mask的方式,将mask+unmasked的patches输出进ViT中,但其实在下游任务输入的patches还是和训练时patches的数量N是一致的(encoder所有的patches)。

而MAE是在encoder中只encoder未被mask的patches

通过什么方式支持的?

  • 在处理文本时,可以根据最长的句子在批次中动态padding或截断长句子
  • 而在处理图像(如使用ViT)时,可以将图像划分为大小相等的patches,数量可以根据图像的大小动态变化。

在训练阶段,部分patches被mask为0,但是处理的所有patches加起来的总长度还是一样的。被mask的位置在模型内部仍然占位,保持了输入序列的“框架”。这样,即使实际参与计算的只是部分元素,模型也能够适应在推理时使用全部元素的情况。

具体的计算步骤如下:

  1. 确定mask哪些patches
  2. 将mask的patches位置设置为0
  3. 这些被mask和未被mask的所有patches一起被输入进attention模块
  4. 将被mask的patches的注意力分数手动设置为“无穷大负数”(-inf)
  5. 这些被mask的patches的softmax值就会变为0,也就意味着这些patches并未参与注意力的计算
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MaskedSelfAttention(nn.Module):def __init__(self, embed_size):super(MaskedSelfAttention, self).__init__()self.query = nn.Linear(embed_size, embed_size)self.key = nn.Linear(embed_size, embed_size)self.value = nn.Linear(embed_size, embed_size)def forward(self, x, mask=None):Q = self.query(x)K = self.key(x)V = self.value(x)# 计算自注意力得分attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(Q.size(-1), dtype=torch.float32))# 将mask值为0的位置在attention_scores中设置为一个非常大的负数attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))# 使得这些位置的softmax结果接近0attention_weights = F.softmax(attention_scores, dim=-1)# 算最终的注意力加权和output = torch.matmul(attention_weights, V)return output# 假设嵌入大小为512
embed_size = 512
# 创建一个mask,假设我们有4个patches,我们想要mask掉第2个和第4个patches
mask = torch.tensor([[1, 0, 1, 0]])
# 扩展mask维度以适应attention_scores的形状(假设批大小为1,序列长度为4),mask需要与attention_scores形状匹配,即(batch_size, 1, 1, seq_length)
mask = mask.unsqueeze(1).unsqueeze(2)# 初始化模型和数据
sa = MaskedSelfAttention(embed_size)
x = torch.randn(1, 4, embed_size)  # 假设有一个批大小为1,序列长度为4的输入output = sa(x, mask)
print(output)

这篇关于DEiT中如何处理mask数据的?与MAE的不同的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/821290

相关文章

大模型研发全揭秘:客服工单数据标注的完整攻略

在人工智能(AI)领域,数据标注是模型训练过程中至关重要的一步。无论你是新手还是有经验的从业者,掌握数据标注的技术细节和常见问题的解决方案都能为你的AI项目增添不少价值。在电信运营商的客服系统中,工单数据是客户问题和解决方案的重要记录。通过对这些工单数据进行有效标注,不仅能够帮助提升客服自动化系统的智能化水平,还能优化客户服务流程,提高客户满意度。本文将详细介绍如何在电信运营商客服工单的背景下进行

基于MySQL Binlog的Elasticsearch数据同步实践

一、为什么要做 随着马蜂窝的逐渐发展,我们的业务数据越来越多,单纯使用 MySQL 已经不能满足我们的数据查询需求,例如对于商品、订单等数据的多维度检索。 使用 Elasticsearch 存储业务数据可以很好的解决我们业务中的搜索需求。而数据进行异构存储后,随之而来的就是数据同步的问题。 二、现有方法及问题 对于数据同步,我们目前的解决方案是建立数据中间表。把需要检索的业务数据,统一放到一张M

关于数据埋点,你需要了解这些基本知识

产品汪每天都在和数据打交道,你知道数据来自哪里吗? 移动app端内的用户行为数据大多来自埋点,了解一些埋点知识,能和数据分析师、技术侃大山,参与到前期的数据采集,更重要是让最终的埋点数据能为我所用,否则可怜巴巴等上几个月是常有的事。   埋点类型 根据埋点方式,可以区分为: 手动埋点半自动埋点全自动埋点 秉承“任何事物都有两面性”的道理:自动程度高的,能解决通用统计,便于统一化管理,但个性化定

无人叉车3d激光slam多房间建图定位异常处理方案-墙体画线地图切分方案

墙体画线地图切分方案 针对问题:墙体两侧特征混淆误匹配,导致建图和定位偏差,表现为过门跳变、外月台走歪等 ·解决思路:预期的根治方案IGICP需要较长时间完成上线,先使用切分地图的工程化方案,即墙体两侧切分为不同地图,在某一侧只使用该侧地图进行定位 方案思路 切分原理:切分地图基于关键帧位置,而非点云。 理论基础:光照是直线的,一帧点云必定只能照射到墙的一侧,无法同时照到两侧实践考虑:关

使用SecondaryNameNode恢复NameNode的数据

1)需求: NameNode进程挂了并且存储的数据也丢失了,如何恢复NameNode 此种方式恢复的数据可能存在小部分数据的丢失。 2)故障模拟 (1)kill -9 NameNode进程 [lytfly@hadoop102 current]$ kill -9 19886 (2)删除NameNode存储的数据(/opt/module/hadoop-3.1.4/data/tmp/dfs/na

异构存储(冷热数据分离)

异构存储主要解决不同的数据,存储在不同类型的硬盘中,达到最佳性能的问题。 异构存储Shell操作 (1)查看当前有哪些存储策略可以用 [lytfly@hadoop102 hadoop-3.1.4]$ hdfs storagepolicies -listPolicies (2)为指定路径(数据存储目录)设置指定的存储策略 hdfs storagepolicies -setStoragePo

Hadoop集群数据均衡之磁盘间数据均衡

生产环境,由于硬盘空间不足,往往需要增加一块硬盘。刚加载的硬盘没有数据时,可以执行磁盘数据均衡命令。(Hadoop3.x新特性) plan后面带的节点的名字必须是已经存在的,并且是需要均衡的节点。 如果节点不存在,会报如下错误: 如果节点只有一个硬盘的话,不会创建均衡计划: (1)生成均衡计划 hdfs diskbalancer -plan hadoop102 (2)执行均衡计划 hd

2. c#从不同cs的文件调用函数

1.文件目录如下: 2. Program.cs文件的主函数如下 using System;using System.Collections.Generic;using System.Linq;using System.Threading.Tasks;using System.Windows.Forms;namespace datasAnalysis{internal static

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

uva 10061 How many zero's and how many digits ?(不同进制阶乘末尾几个0)+poj 1401

题意是求在base进制下的 n!的结果有几位数,末尾有几个0。 想起刚开始的时候做的一道10进制下的n阶乘末尾有几个零,以及之前有做过的一道n阶乘的位数。 当时都是在10进制下的。 10进制下的做法是: 1. n阶位数:直接 lg(n!)就是得数的位数。 2. n阶末尾0的个数:由于2 * 5 将会在得数中以0的形式存在,所以计算2或者计算5,由于因子中出现5必然出现2,所以直接一