知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

本文主要是介绍知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')* (temperature ** 2))pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')* (temperature ** 2))return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rtclass DKD(nn.Module):def __init__(self, alpha=1., beta=2., temperature=1.):super(DKD, self).__init__()self.alpha = alphaself.beta = betaself.temperature = temperaturedef forward(self, z_s, z_t, **kwargs):target = kwargs['target']if len(target.shape) == 2:  # mixup / smoothingtarget = target.max(1)[1]kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

这篇关于知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/456622

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

Java架构师知识体认识

源码分析 常用设计模式 Proxy代理模式Factory工厂模式Singleton单例模式Delegate委派模式Strategy策略模式Prototype原型模式Template模板模式 Spring5 beans 接口实例化代理Bean操作 Context Ioc容器设计原理及高级特性Aop设计原理Factorybean与Beanfactory Transaction 声明式事物

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

sqlite3 相关知识

WAL 模式 VS 回滚模式 特性WAL 模式回滚模式(Rollback Journal)定义使用写前日志来记录变更。使用回滚日志来记录事务的所有修改。特点更高的并发性和性能;支持多读者和单写者。支持安全的事务回滚,但并发性较低。性能写入性能更好,尤其是读多写少的场景。写操作会造成较大的性能开销,尤其是在事务开始时。写入流程数据首先写入 WAL 文件,然后才从 WAL 刷新到主数据库。数据在开始

【C++】_list常用方法解析及模拟实现

相信自己的力量,只要对自己始终保持信心,尽自己最大努力去完成任何事,就算事情最终结果是失败了,努力了也不留遗憾。💓💓💓 目录   ✨说在前面 🍋知识点一:什么是list? •🌰1.list的定义 •🌰2.list的基本特性 •🌰3.常用接口介绍 🍋知识点二:list常用接口 •🌰1.默认成员函数 🔥构造函数(⭐) 🔥析构函数 •🌰2.list对象

活用c4d官方开发文档查询代码

当你问AI助手比如豆包,如何用python禁止掉xpresso标签时候,它会提示到 这时候要用到两个东西。https://developers.maxon.net/论坛搜索和开发文档 比如这里我就在官方找到正确的id描述 然后我就把参数标签换过来

浅谈主机加固,六种有效的主机加固方法

在数字化时代,数据的价值不言而喻,但随之而来的安全威胁也日益严峻。从勒索病毒到内部泄露,企业的数据安全面临着前所未有的挑战。为了应对这些挑战,一种全新的主机加固解决方案应运而生。 MCK主机加固解决方案,采用先进的安全容器中间件技术,构建起一套内核级的纵深立体防护体系。这一体系突破了传统安全防护的局限,即使在管理员权限被恶意利用的情况下,也能确保服务器的安全稳定运行。 普适主机加固措施:

webm怎么转换成mp4?这几种方法超多人在用!

webm怎么转换成mp4?WebM作为一种新兴的视频编码格式,近年来逐渐进入大众视野,其背后承载着诸多优势,但同时也伴随着不容忽视的局限性,首要挑战在于其兼容性边界,尽管WebM已广泛适应于众多网站与软件平台,但在特定应用环境或老旧设备上,其兼容难题依旧凸显,为用户体验带来不便,再者,WebM格式的非普适性也体现在编辑流程上,由于它并非行业内的通用标准,编辑过程中可能会遭遇格式不兼容的障碍,导致操

poj 1258 Agri-Net(最小生成树模板代码)

感觉用这题来当模板更适合。 题意就是给你邻接矩阵求最小生成树啦。~ prim代码:效率很高。172k...0ms。 #include<stdio.h>#include<algorithm>using namespace std;const int MaxN = 101;const int INF = 0x3f3f3f3f;int g[MaxN][MaxN];int n