带掩码的自编码器MAE详解和代码实现

2024-01-16 10:50

本文主要是介绍带掩码的自编码器MAE详解和代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

监督学习是训练机器学习模型的传统方法,它在训练时每一个观察到的数据都需要有标注好的标签。如果我们有一种训练机器学习模型的方法不需要收集标签,会怎么样?如果我们从收集的相同数据中提取标签呢?这种类型的学习算法被称为自监督学习。这种方法在自然语言处理中工作得很好。一个例子是BERT¹,谷歌自2019年以来一直在其搜索引擎中使用BERT¹。不幸的是,对于计算机视觉来说,情况并非如此。

Facebook AI的kaiming大神等人提出了一种带掩码自编码器(MAE)²,它基于(ViT)³架构。他们的方法在ImageNet上的表现要好于从零开始训练的VIT。在本文中,我们将深入研究他们的方法,并了解如何在代码中实现它。

带掩码自编码器(MAE)

对输入图像的patches进行随机掩码,然后重建缺失的像素。MAE基于两个核心设计。首先,开发了一个非对称的编码器-解码器架构,其中编码器仅对可见的patches子集(没有掩码的tokens)进行操作,同时还有一个轻量级的解码器,可以从潜在表示和掩码tokens重建原始图像。其次,发现对输入图像进行高比例的掩码,例如75%,会产生有意义的自监督任务。将这两种设计结合起来,能够高效地训练大型模型:加快模型训练速度(3倍甚至更多)并提高精度。

此阶段称为预训练,因为 MAE 模型稍后将用于下游任务,例如图像分类。 模型在pretext上的表现在自监督中并不重要, 这些任务的重点是让模型学习一个预期包含良好语义的中间表示。 在预训练阶段之后,解码器将被多层感知器 (MLP) 头或线性层取代,作为分类器输出对下游任务的预测。

模型架构

编码器

编码器是 ViT。 它接受张量形状为 (batch_size, RGB_channels, height, width) 的图像。 通过执行线性投影为每个Patch获得嵌入, 这是通过 2D 卷积层来完成。 然后张量在最后一个维度被展平(压扁),变成 (batch_size, encoder_embed_dim, num_visible_patches),并 转置为形状(batch_size、num_visible_patches、encoder_embed_dim)的张量。

class PatchEmbed(nn.Module):""" Image to Patch Embedding """def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):super().__init__()self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)def forward(self, x, **kwargs):B, C, H, W = x.shapeassert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."x = self.proj(x).flatten(2).transpose(1, 2)return x

正如原始 Transformer 论文中提到的,位置编码添加了有关每个Patch位置的信息。 作者使用“sine-cosine”版本而不是可学习的位置嵌入。 下面的这个实现是一维版本。

def get_sinusoid_encoding_table(n_position, d_hid): def get_position_angle_vec(position): return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 return torch.FloatTensor(sinusoid_table).unsqueeze(0) 

与 Transformer 类似,每个块由norm层、多头注意力模块和前馈层组成。 中间输出形状是(batch_size、num_visible_patches、encoder_embed_dim)。 多头注意力模块的代码如下:

class Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., attn_head_dim=None):super().__init__()self.num_heads = num_headshead_dim = attn_head_dim if attn_head_dim is not None else dim // num_headsall_head_dim = head_dim * self.num_headsself.scale = qk_scale or head_dim ** -0.5self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else Noneself.v_bias = nn.Parameter(torch.zeros(all_head_dim)) if qkv_bias else Noneself.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(all_head_dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, N, C = x.shapeqkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) if self.q_bias is not None else Noneqkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1)).softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, -1)x = self.proj_drop(self.proj(x))return x

Transformer 模块的代码如下:

class Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_head_dim=None):super().__init__()self.norm1 = norm_layer(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)self.norm2 = norm_layer(dim)self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), act_layer(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(attn_drop))def forward(self, x):x = x + self.attn(self.norm1(x))x = x + self.mlp(self.norm2(x))return x

这部分仅用于下游任务的微调。 论文的模型遵循 ViT 架构,该架构具有用于分类的类令牌(patch)。 因此,他们添加了一个虚拟令牌,但是论文中也说到他们的方法在没有它的情况下也可以运行良好,因为对其他令牌执行了平均池化操作。 在这里也包含了实现的平均池化版本。 之后,添加一个线性层作为分类器。 最终的张量形状是 (batch_size, num_classes)。

综上所述,编码器实现如下:

class Encoder(nn.Module)def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=0, **block_kwargs):super().__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models# Patch embeddingself.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)num_patches = self.patch_embed.num_patches# Positional encodingself.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)# Transformer blocksself.blocks = nn.ModuleList([Block(**block_kwargs) for i in range(depth)])  # various arguments are not shown here for brevity purposesself.norm =  norm_layer(embed_dim)# Classifier (for fine-tuning only)self.fc_norm = norm_layer(embed_dim)self.head = nn.Linear(embed_dim, num_classes)def forward(self, x, mask):x = self.patch_embed(x)x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()B, _, C = x.shapeif mask is not None:  # for pretraining onlyx = x[~mask].reshape(B, -1, C) # ~mask means visiblefor blk in self.blocks:x = blk(x)x = self.norm(x)if self.num_classes > 0:  # for fine-tuning onlyx = self.fc_norm(x.mean(1))  # average poolingx = self.head(x)return x

解码器

与编码器类似,解码器由一系列transformer 块组成。 在解码器的末端,有一个由norm层和前馈层组成的分类器。 输入张量的形状为 batch_size, num_patches,decoder_embed_dim) 而最终输出张量的形状为 (batch_size, num_patches, 3 * patch_size ** 2)。

class Decoder(nn.Module):def __init__(self, patch_size=16, embed_dim=768, norm_layer=nn.LayerNorm, num_classes=768, **block_kwargs):super().__init__()self.num_classes = num_classesassert num_classes == 3 * patch_size ** 2self.num_features = self.embed_dim = embed_dimself.patch_size = patch_sizeself.blocks = nn.ModuleList([Block(**block_kwargs) for i in range(depth)])  # various arguments are not shown here for brevity purposesself.norm =  norm_layer(embed_dim)self.head = nn.Linear(embed_dim, num_classes)def forward(self, x, return_token_num):for blk in self.blocks:x = blk(x)if return_token_num > 0:x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixelselse:x = self.head(self.norm(x))return x

把所有东西放在一起——MAE架构

MAE 用于对掩码图像进行预训练。首先,屏蔽的输入被发送到编码器。然后,它们被传递到前馈层以更改嵌入维度以匹配解码器。 在传递给解码器之前,被掩码的Patch被输入进去。 位置编码再次应用于完整的图像块集,包括可见的和被掩码遮盖的。

在论文中,作者对包含所有Patch的列表进行了打乱,以便正确插入Patch的掩码。 这部分在本篇文章中没有完成,因为在 PyTorch 上实现并不简单。所以这里使用的是位置编码在被添加到Patch之前被相应地打乱的做法。

class MAE(nn.Module):def __init__(self, ...):  # various arguments are not shown here for brevity purposessuper().__init__()self.encoder = Encoder(img_size, patch_size, in_chans, embed_dim, norm_layer, num_classes=0, **block_kwargs)self.decoder = Decoder(patch_size, embed_dim, norm_layer, num_classes, **block_kwargs)self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=False)self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))self.pos_embed = get_sinusoid_encoding_table(self.encoder.patch_embed.num_patches, decoder_embed_dim)def forward(self, x, mask):x_vis = self.encoder(x, mask)x_vis = self.encoder_to_decoder(x_vis)B, N, C = x_vis.shapeexpand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)x_full = torch.cat([x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]return x

训练过程

对于自监督预训练,论文发现简单的逐像素平均绝对损失作为目标函数效果很好。 并且他们使用的数据集是 ImageNet-1K 训练集。

在下游的微调阶段,解码器被移除,编码器在相同的数据集上进行训练。 数据与预训练略有不同,因为编码器现在使用完整的图像块集(没有屏蔽)。 因此,现在的Patch数量与预训练阶段不同。

如果您你知道用于预训练的模型是否仍然可以用于微调,答案是肯定的。 编码器主要由注意力模块、norm层和前馈层组成。 要检查Patch数量(索引 1)的变化是否影响前向传递,我们需要查看每一层的参数张量的形状。

  • norm层中的参数的形状为(batch, 1, encoder_embed_dim)。 它可以在前向传播期间沿着补丁维度(索引 1)进行广播,因此它不依赖于补丁维度的大小。
  • 前馈层有一个形状为(in_channels, out_channels)的权重矩阵和一个形状为(out_channels,)的偏置矩阵,两者都不依赖于patch的数量。
  • 注意力模块本质上执行一系列线性投影。 因此,出于同样的原因,patch的数量也不会影响参数张量的形状。

由于并行处理允许将数据分批输入,所以批处理中的Patch数量是需要保持一致的。

结果

让我们看看原始论文中报道的预训练阶段的重建图像。看起来MAE在重建图像方面做得很好,即使80%的像素被遮蔽了。

ImageNet验证图像的示例结果。从左到右:遮蔽图像、重建图像、真实图像。掩蔽率为80%。

MAE 在微调的下游任务上也表现良好,例如 ImageNet-1K 数据集上的图像分类。 与监督方式相比,在使用 MAE 预训练进行训练时比使用的基线 ViT-Large 实际上表现更好。

论文中还包括对下游任务和各种消融研究的迁移学习实验的基准结果。有兴趣的可以再看看原论文。

讨论

如果您熟悉 BERT,您可能会注意到 BERT 和 MAE 的方法之间的相似之处。在 BERT 的预训练中,我们遮蔽了一部分文本,模型的任务是预测它们。此外,由于我们现在使用的是基于 Transformer 的架构,因此说这种方法在视觉上与 BERT 等效也不是不合适的。

但是论文中说这种方法早于 BERT。例如,过去对图像自监督的尝试使用堆叠去噪自编码器和图像修复作为pretext task。 MAE 本身也使用自动编码器作为模型和类似于图像修复的pretext task。

如果是这样的话,是什么让 MAE 工作比以前模型好呢?我认为关键在于 ViT 架构。在他们的论文中,作者提到卷积神经网络在将掩码标记和位置嵌入等“指标”集成到其中时存在问题,而 ViT 解决了这种架构差距。如果是这样,那么我们将看到在自然语言处理中开发的另一个想法在计算机视觉中成功实现。之前是attention机制,然后Transformer的概念以Vision Transformers的形式借用到计算机视觉中,现在是整个BERT预训练过程。

结论

我对未来自监督的视觉必须提供的东西感到兴奋。鉴于 BERT 在自然语言处理方面的成功,像 MAE 这样的掩码建模方法将有益于计算机视觉。图像数据很容易获得,但标记它们可能很耗时。通过这种方法,人们可以通过管理比 ImageNet 大得多的数据集来扩展预训练过程,而无需担心标记。潜力是无限的。我们是否会见证计算机视觉的另一次复兴,只有时间才能证明。

引用

  1. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pretraining of deep bidirectional transformers for language understanding. In NAACL, 2019.
  2. Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross Girshick. Masked autoencoders are scalable vision learners. arXiv:2111.06377, 2021.
  3. Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In ICLR, 2021.

作者:Stephen Lau

这篇关于带掩码的自编码器MAE详解和代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/612335

相关文章

Spring Security基于数据库验证流程详解

Spring Security 校验流程图 相关解释说明(认真看哦) AbstractAuthenticationProcessingFilter 抽象类 /*** 调用 #requiresAuthentication(HttpServletRequest, HttpServletResponse) 决定是否需要进行验证操作。* 如果需要验证,则会调用 #attemptAuthentica

hdu1043(八数码问题,广搜 + hash(实现状态压缩) )

利用康拓展开将一个排列映射成一个自然数,然后就变成了普通的广搜题。 #include<iostream>#include<algorithm>#include<string>#include<stack>#include<queue>#include<map>#include<stdio.h>#include<stdlib.h>#include<ctype.h>#inclu

OpenHarmony鸿蒙开发( Beta5.0)无感配网详解

1、简介 无感配网是指在设备联网过程中无需输入热点相关账号信息,即可快速实现设备配网,是一种兼顾高效性、可靠性和安全性的配网方式。 2、配网原理 2.1 通信原理 手机和智能设备之间的信息传递,利用特有的NAN协议实现。利用手机和智能设备之间的WiFi 感知订阅、发布能力,实现了数字管家应用和设备之间的发现。在完成设备间的认证和响应后,即可发送相关配网数据。同时还支持与常规Sof

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

【Prometheus】PromQL向量匹配实现不同标签的向量数据进行运算

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全栈,前后端开发,小程序开发,人工智能,js逆向,App逆向,网络系统安全,数据分析,Django,fastapi

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

让树莓派智能语音助手实现定时提醒功能

最初的时候是想直接在rasa 的chatbot上实现,因为rasa本身是带有remindschedule模块的。不过经过一番折腾后,忽然发现,chatbot上实现的定时,语音助手不一定会有响应。因为,我目前语音助手的代码设置了长时间无应答会结束对话,这样一来,chatbot定时提醒的触发就不会被语音助手获悉。那怎么让语音助手也具有定时提醒功能呢? 我最后选择的方法是用threading.Time

Android实现任意版本设置默认的锁屏壁纸和桌面壁纸(两张壁纸可不一致)

客户有些需求需要设置默认壁纸和锁屏壁纸  在默认情况下 这两个壁纸是相同的  如果需要默认的锁屏壁纸和桌面壁纸不一样 需要额外修改 Android13实现 替换默认桌面壁纸: 将图片文件替换frameworks/base/core/res/res/drawable-nodpi/default_wallpaper.*  (注意不能是bmp格式) 替换默认锁屏壁纸: 将图片资源放入vendo

C#实战|大乐透选号器[6]:实现实时显示已选择的红蓝球数量

哈喽,你好啊,我是雷工。 关于大乐透选号器在前面已经记录了5篇笔记,这是第6篇; 接下来实现实时显示当前选中红球数量,蓝球数量; 以下为练习笔记。 01 效果演示 当选择和取消选择红球或蓝球时,在对应的位置显示实时已选择的红球、蓝球的数量; 02 标签名称 分别设置Label标签名称为:lblRedCount、lblBlueCount

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n