Deit:知识蒸馏与vit的结合 学习笔记(附代码)

2024-01-13 23:12

本文主要是介绍Deit:知识蒸馏与vit的结合 学习笔记(附代码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 论文地址:https://arxiv.org/abs/2012.12877

代码地址:GitHub - facebookresearch/deit: Official DeiT repository

1.是什么?

DeiT(Data-efficient Image Transformer)是一种用于图像分类任务的神经网络模型,它基于Transformer架构。这个模型的主要目标是在参数较少的情况下实现高效的图像分类。相比于传统的卷积神经网络(CNN),DeiT采用了Transformer的注意力机制,使其能够更好地捕捉图像中的全局关系。
以下是DeiT模型的一些关键特点和组成部分:

  1. 1.Transformer 架构: DeiT采用了Transformer的架构,这是一种自注意力机制的模型。这种架构在自然语言处理任务中取得了显著的成功,DeiT将其成功地应用于图像分类领域。
  2. 2.小模型参数: 为了提高数据效率,DeiT设计为具有相对较少的参数。这使得模型在训练和推理时需要更少的计算资源。
  3. 3.Knowledge Distillation: DeiT使用知识蒸馏(Knowledge Distillation)的方法进行训练。这意味着它通过从一个大型预训练模型中传递知识来训练,而不是从头开始训练。这有助于在资源受限的情况下实现更好的性能。
  4. 4.Patch Embedding: 与传统的卷积层不同,DeiT使用了补丁嵌入(Patch Embedding)来将图像分割成小块,然后对这些块进行变换。
  5. 5.Positional Embeddings: 由于Transformer不涉及卷积层,它需要一种处理输入序列的方式。在DeiT中,位置嵌入(Positional Embeddings)用于为模型提供输入中元素的相对位置信息。

总体而言,DeiT是一个旨在通过Transformer的优势实现图像分类的轻量级模型,适用于数据受限的情况。通过知识蒸馏和小模型参数,它在参数较少的情况下达到了令人满意的性能。

2.为什么?

Transformer的输入是一个序列(Sequence),ViT 所采用的思路是把图像分块(patches),然后把每一块视为一个向量(vector),所有的向量并在一起就成为了一个序列(Sequence),ViT 使用的数据集包括了一个巨大的包含了 300 million images的 JFT-300,这个数据集是私有的,即外部研究者无法复现实验。而且在ViT的实验中作者明确地提到:

意思是当不使用 JFT-300 大数据集时,效果不如CNN模型。也就反映出Transformer结构若想取得理想的性能和泛化能力就需要这样大的数据集。DeiT 作者通过所提出的蒸馏的训练方案,只在 Imagenet 上进行训练,就产生了一个有竞争力的无卷积 Transformer。

3.怎么样?

在 DeiT 模型中,首先需要一个强力的图像分类模型作为teacher model。然后,引入了一个 Distillation Token,然后在 self-attention layers 中跟 class token,patch token 在 Transformer 结构中不断学习。Class token的目标是跟真实的label一致,而Distillation Token是要跟teacher model预测的label一致。蒸馏过程如下图所示。

3.1知识蒸馏

知识蒸馏(Knowledge Distillation)是一种模型训练的技术,旨在通过传递一个大型教师模型的知识来训练一个小型学生模型。这个方法的目标是使得学生模型能够获得与教师模型相似的性能,同时减少学生模型的复杂性和计算成本。
以下是知识蒸馏的关键思想和步骤:

  1. 1.教师模型: 首先,有一个在任务上表现良好的大型教师模型。这个模型通常拥有更多的参数和计算能力,以便更好地捕捉任务的复杂性和结构。
  2. 2.软目标(Soft Targets): 在传统的监督学习中,模型通常以硬标签(one-hot编码的标签)作为目标进行训练。而在知识蒸馏中,使用了软目标,这是由教师模型输出的概率分布。这样的软目标包含了关于样本的更丰富信息,使得学生模型可以学到更多的任务相关知识。
  3. 3.温度参数: 软目标的概率分布可以通过温度参数进行调节。较高的温度使概率分布更平滑,有助于学生模型更好地学到教师模型的知识。
  4. 4.学生模型: 有了教师模型和软目标,接下来就是训练学生模型。学生模型通常是一个比教师模型简化的小型模型,可以在资源受限的环境中更轻松地进行推理。
  5. 5.蒸馏损失: 为了引导学生模型学习教师模型的知识,引入了蒸馏损失。这个损失函数用于比较学生模型的输出概率分布和教师模型的输出概率分布,促使学生模型模仿教师模型的行为。

知识蒸馏的优势在于,通过传递教师模型的知识,可以在小型模型上实现接近教师模型性能的效果。这对于移动设备、嵌入式系统或其他计算资源受限的环境中的部署非常有用。

具体方法:

第一步是训练Net-T;第二步是在高温 T 下,蒸馏 Net-T 的知识到 Net-S。

训练 Net-T 的过程很简单,而高温蒸馏过程的目标函数由distill loss(对应soft target)和student loss(对应hard target)加权得到:

Deit 中使用 Conv-Based 架构作为教师网络,以 soft 的方式将归纳偏置传递给学生模型,将局部性的假设通过蒸馏方式引入 Transformer 中,取得了不错的效果。

3.2Distillation Token

Distillation Token 和 ViT 中的 class token 一起加入 Transformer 中,和class token 一样通过 self-attention 与其它的 embedding 一起计算,并且在最后一层之后由网络输出。

而 Distillation Token 对应的这个输出的目标函数就是蒸馏损失。Distillation Token 允许模型从教师网络的输出中学习,就像在常规的蒸馏中一样,同时也作为一种对class token的补充。

3.3代码实现

class DistilledVisionTransformer(VisionTransformer):def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))num_patches = self.patch_embed.num_patchesself.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()trunc_normal_(self.dist_token, std=.02)trunc_normal_(self.pos_embed, std=.02)self.head_dist.apply(self._init_weights)def forward_features(self, x):# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py# with slight modifications to add the dist_tokenB = x.shape[0]x = self.patch_embed(x)cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanksdist_token = self.dist_token.expand(B, -1, -1)x = torch.cat((cls_tokens, dist_token, x), dim=1)x = x + self.pos_embedx = self.pos_drop(x)for blk in self.blocks:x = blk(x)x = self.norm(x)return x[:, 0], x[:, 1]def forward(self, x):x, x_dist = self.forward_features(x)x = self.head(x)x_dist = self.head_dist(x_dist)if self.training:return x, x_distelse:# during inference, return the average of both classifier predictionsreturn (x + x_dist) / 2

参考:ViT、Deit这类视觉transformer是如何处理变长序列输入的?

DeiT:使用Attention蒸馏Transformer

这篇关于Deit:知识蒸馏与vit的结合 学习笔记(附代码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/603121

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

Java架构师知识体认识

源码分析 常用设计模式 Proxy代理模式Factory工厂模式Singleton单例模式Delegate委派模式Strategy策略模式Prototype原型模式Template模板模式 Spring5 beans 接口实例化代理Bean操作 Context Ioc容器设计原理及高级特性Aop设计原理Factorybean与Beanfactory Transaction 声明式事物

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

sqlite3 相关知识

WAL 模式 VS 回滚模式 特性WAL 模式回滚模式(Rollback Journal)定义使用写前日志来记录变更。使用回滚日志来记录事务的所有修改。特点更高的并发性和性能;支持多读者和单写者。支持安全的事务回滚,但并发性较低。性能写入性能更好,尤其是读多写少的场景。写操作会造成较大的性能开销,尤其是在事务开始时。写入流程数据首先写入 WAL 文件,然后才从 WAL 刷新到主数据库。数据在开始

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss