本文主要是介绍关于Balanced-MixUp是自定义的交叉熵损失函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
Balanced-MixUp的自定义的交叉熵损失函数
def cross_entropy_loss(input: torch.Tensor,target: torch.Tensor) -> torch.Tensor:return -(input.log_softmax(dim=-1) * target).sum(dim=-1).mean()
官方的 nn.CrossEntropyLoss()
import torch.nn as nncriterion = nn.CrossEntropyLoss()
主要区别
-
输入类型和形状:
- 自定义函数:假设
input
是一个具有 logits 的张量,形状为(batch_size, num_classes)
,而target
是独热编码的标签,形状也是(batch_size, num_classes)
。 - 官方函数:假设
input
是一个具有 logits 的张量,形状为(batch_size, num_classes)
,而target
是一个长整型张量,形状为(batch_size)
,每个值表示对应样本的类索引。
- 自定义函数:假设
mixup的相关操作是怎样将类索引标签变成独热编码的
mixed_labels = (1 - lam) * F.one_hot(labels, n_classes) + lam * F.one_hot(balanced_labels, n_classes)
这篇关于关于Balanced-MixUp是自定义的交叉熵损失函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!