本文主要是介绍长尾问题之LDAM,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
做法&代码&公式
step1: 全连接层的权重W和特征向量X都归一化,相乘 W * X = P (得到各个类别的概率)
# 定义权重,初始化
weight = nn.Parameter(torch.FloatTensor(num_classes, num_features))
weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)# 归一化W,X ; W * X = P
Z = F.linear(F.normalize(x), F.normalize(self.weight))
step2: 损失参数计算及设置
计算 margin
其中nj是对应类别j的数量,C是一个超参常数。代码里C是最多样本类别的0.5倍
以此类推,类别1,类别2 的分别margin
,
代码部分
s = 30 # 设置缩放系数 s=30
num_class_list=[13000,450,4231,8000 ... ] # 各个类别样本数量max_m = 0.5
m_list = 1.0 / np.sqrt(np.sqrt(num_class_list))
m_list = m_list * (max_m / np.max(m_list)) # ignore是0影响计算
re-weight
epoch∈(0~160)时为
betas = [0, 0.9999]
re_weights = (1.0 - betas[0]) / (1.0 - np.power(betas[0], num_class_list))
re_weights = re_weights / np.sum(per_cls_weights) * num_class)
epoch∈(160~maxepoch)时为:
betas = [0, 0.9999]
re_weights = (1.0 - betas[1]) / (1.0 - np.power(betas[1], num_class_list))
re_weights = re_weights / np.sum(re_weights) * num_class)
step4:计算损失
# index: 哪些位置是y_true
x_m = Z - m_list
outputs = torch.where(index, x_m, Z)
F.cross_entropy(s * outputs, targets, weight=re_weight)
参考:https://github.com/zhangyongshun/BagofTricks-LT/blob/main/documents/trick_gallery.md
MMPretrain实现 不要积分的,免费的
https://download.csdn.net/download/magic_shuang/88632302
mmpretrain/models/heads/ldam_head.py
mmpretrain/models/losses/ldam_loss.py
这篇关于长尾问题之LDAM的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!