首页
Python
Java
前端
数据库
Linux
Chatgpt专题
开发者工具箱
kldiv专题
知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)
有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型: 利用教师模型的输出概率(基于logits的方法) 该类方法损失函数为: DIST Tao Huang,Shan You,Fei Wang,
阅读更多...
知识蒸馏dist和KLDiv代码和使用方法
DIST 实现方法 import torch.nn as nndef cosine_similarity(a, b, eps=1e-8):return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)def pearson_correlation(a, b, eps=1e-8):return cosine_similarity(a -
阅读更多...