知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

本文主要是介绍知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a - a.mean(1).unsqueeze(1),b - b.mean(1).unsqueeze(1), eps)def inter_class_relation(soft_student_outputs, soft_teacher_outputs):return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()def intra_class_relation(soft_student_outputs, soft_teacher_outputs):return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))class DIST(nn.Module):def __init__(self, beta=1.0, gamma=1.0, temp=1.0):super(DIST, self).__init__()self.beta = betaself.gamma = gammaself.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = (student_preds / self.temp).softmax(dim=1)soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)kd_loss = self.beta * inter_loss + self.gamma * intra_lossreturn kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):def __init__(self, temp=1.0):super(KLDiv, self).__init__()self.temp = tempdef forward(self, student_preds, teacher_preds, **kwargs):soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()kd_loss *= self.temp ** 2return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as Fdef dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):gt_mask = _get_gt_mask(logits_student, target)other_mask = _get_other_mask(logits_student, target)pred_student = F.softmax(logits_student / temperature, dim=1)pred_teacher = F.softmax(logits_teacher / temperature, dim=1)pred_student = cat_mask(pred_student, gt_mask, other_mask)pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)log_pred_student = torch.log(pred_student)tckd_loss = (F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')* (temperature ** 2))pred_teacher_part2 = F.softmax(logits_teacher / temperature - 1000.0 * gt_mask, dim=1)log_pred_student_part2 = F.log_softmax(logits_student / temperature - 1000.0 * gt_mask, dim=1)nckd_loss = (F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')* (temperature ** 2))return alpha * tckd_loss + beta * nckd_lossdef _get_gt_mask(logits, target):target = target.reshape(-1)mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()return maskdef _get_other_mask(logits, target):target = target.reshape(-1)mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()return maskdef cat_mask(t, mask1, mask2):t1 = (t * mask1).sum(dim=1, keepdims=True)t2 = (t * mask2).sum(1, keepdims=True)rt = torch.cat([t1, t2], dim=1)return rtclass DKD(nn.Module):def __init__(self, alpha=1., beta=2., temperature=1.):super(DKD, self).__init__()self.alpha = alphaself.beta = betaself.temperature = temperaturedef forward(self, z_s, z_t, **kwargs):target = kwargs['target']if len(target.shape) == 2:  # mixup / smoothingtarget = target.max(1)[1]kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

这篇关于知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/456622

相关文章

Nginx 访问控制的多种方法

《Nginx访问控制的多种方法》本文系统介绍了Nginx实现Web访问控制的多种方法,包括IP黑白名单、路径/方法/参数控制、HTTP基本认证、防盗链机制、客户端证书校验、限速限流、地理位置控制等基... 目录一、IP 白名单与黑名单1. 允许/拒绝指定IP2. 全局黑名单二、基于路径、方法、参数的访问控制

Nginx服务器部署详细代码实例

《Nginx服务器部署详细代码实例》Nginx是一个高性能的HTTP和反向代理web服务器,同时也提供了IMAP/POP3/SMTP服务,:本文主要介绍Nginx服务器部署的相关资料,文中通过代码... 目录Nginx 服务器SSL/TLS 配置动态脚本反向代理总结Nginx 服务器Nginx是一个‌高性

Python中Request的安装以及简单的使用方法图文教程

《Python中Request的安装以及简单的使用方法图文教程》python里的request库经常被用于进行网络爬虫,想要学习网络爬虫的同学必须得安装request这个第三方库,:本文主要介绍P... 目录1.Requests 安装cmd 窗口安装为pycharm安装在pycharm设置中为项目安装req

nginx跨域访问配置的几种方法实现

《nginx跨域访问配置的几种方法实现》本文详细介绍了Nginx跨域配置方法,包括基本配置、只允许指定域名、携带Cookie的跨域、动态设置允许的Origin、支持不同路径的跨域控制、静态资源跨域以及... 目录一、基本跨域配置二、只允许指定域名跨域三、完整示例四、配置后重载 nginx五、注意事项六、支持

MySQL查看表的历史SQL的几种实现方法

《MySQL查看表的历史SQL的几种实现方法》:本文主要介绍多种查看MySQL表历史SQL的方法,包括通用查询日志、慢查询日志、performance_schema、binlog、第三方工具等,并... 目录mysql 查看某张表的历史SQL1.查看MySQL通用查询日志(需提前开启)2.查看慢查询日志3.

MySQL底层文件的查看和修改方法

《MySQL底层文件的查看和修改方法》MySQL底层文件分为文本类(可安全查看/修改)和二进制类(禁止手动操作),以下按「查看方法、修改方法、风险管控三部分详细说明,所有操作均以Linux环境为例,需... 目录引言一、mysql 底层文件的查看方法1. 先定位核心文件路径(基础前提)2. 文本类文件(可直

Java实现字符串大小写转换的常用方法

《Java实现字符串大小写转换的常用方法》在Java中,字符串大小写转换是文本处理的核心操作之一,Java提供了多种灵活的方式来实现大小写转换,适用于不同场景和需求,本文将全面解析大小写转换的各种方法... 目录前言核心转换方法1.String类的基础方法2. 考虑区域设置的转换3. 字符级别的转换高级转换

使用Python实现局域网远程监控电脑屏幕的方法

《使用Python实现局域网远程监控电脑屏幕的方法》文章介绍了两种使用Python在局域网内实现远程监控电脑屏幕的方法,方法一使用mss和socket,方法二使用PyAutoGUI和Flask,每种方... 目录方法一:使用mss和socket实现屏幕共享服务端(被监控端)客户端(监控端)方法二:使用PyA

HTML5的input标签的`type`属性值详解和代码示例

《HTML5的input标签的`type`属性值详解和代码示例》HTML5的`input`标签提供了多种`type`属性值,用于创建不同类型的输入控件,满足用户输入的多样化需求,从文本输入、密码输入、... 目录一、引言二、文本类输入类型2.1 text2.2 password2.3 textarea(严格

检查 Nginx 是否启动的几种方法

《检查Nginx是否启动的几种方法》本文主要介绍了检查Nginx是否启动的几种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学... 目录1. 使用 systemctl 命令(推荐)2. 使用 service 命令3. 检查进程是否存在4