本文主要是介绍PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
前言
在PyTorch框架下使用F.cross_entropy()函数时,偶尔会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed
。
错误信息类似下面打印信息:
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu line=83 error=59 : device-side assert triggered
Traceback (most recent call last):File "tutorial.py", line 100, in <module>model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)File "tutorial.py", line 80, in train_modelloss = criterion(outputs, labels)File "python3.7/site-packages/torch/nn/modules/module.py", line 206, in __call__result = self.forward(*input, **kwargs)File "python3.7/site-packages/torch/nn/modules/loss.py", line 313, in forwardself.weight, self.size_average)File "python3.7/site-packages/torch/nn/functional.py", line 509, in cross_entropyreturn nll_loss(log_softmax(input), target, weight, size_average)File "python3.7/site-packages/torch/nn/functional.py", line 477, in nll_lossreturn f(input, target)File "python3.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forwardoutput, *self.additional_args)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:83
通常情况下,这是由于求交叉熵函数在计算时遇到了类别错误的问题,即不满足t >= 0 && t < n_classes
条件。
t >= 0 && t < n_classes条件
在分类任务中,需要调用torch.nn.functional.cross_entropy()
函数求交叉熵,从PyTorch官网可以看到该函数定义:
torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
可以注意到有一个key-value是ignore_index=-100。这是在交叉熵计算时被跳过的部分。通常是在数据增强中的填充值。
而在代码运行中报错ClassNLLCriterion Assertion `t >= 0 && t < n_classes ` failed
,大部分都是由于没有正确处理好label(ground truth)导致的。例如在数据增强中,填充数据使用了负数,或者使用了某大正数(如255),而在调用torch.nn.functional.cross_entropy()
方法时却没有传入正确的ignore_index。这就会导致运行过程中的Assertion Error。
代码示例
数据增强部分
import torchvision.transforms.functional as tftf.pad(cropped_img, padding_tuple, padding_mode="reflect"),
tf.affine(mask, translate=(-x_offset, -y_offset), scale=1.0, angle=0.0, shear=0.0,fillcolor=250,)
求交叉熵部分
import torch
import torch.nn.functional as F
import torch.nn as nndef cross_entropy2d(input, target, weight=None, reduction='none'):n, c, h, w = input.size()nt, ht, wt = target.size()if h != ht or w != wt:input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)target = target.view(-1)loss = F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=255)return loss
分析
可以看到在数据增强时的填充值为250(fillcolor=250),但在求交叉熵时却传入了ignore_index=255。因此在代码运行时,F.cross_entropy部分便会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed
。只需要统一好label部分填充数据和计算交叉熵时需要忽略的class就可以避免出现这一问题。
其他
在PyTorch框架下,使用无用label值进行填充和处理时,要注意在使用scatter_
函数时也需要注意对无用label进行提前处理,否则在使用data.scatter_()
时同样也会报类似类别index错误。
labels = labels[:, :, :].view(size[0], 1, size[1], size[2])
oneHot_size = (size[0], classes, size[1], size[2])
labels_real = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
# ignore_index=255
# labels[labels.data[::] == ignore_index] = 0
labels_real = labels_real.scatter_(1, labels.data.long().cuda(), 1.0)
参考资料
[1] torch.nn.functional — PyTorch 1.8.0 documentation
[2] Pytorch里的CrossEntropyLoss详解 - marsggbo - 博客园
[3] RuntimeError: cuda runtime error (59) : device-side assert triggered when running transfer_learning_tutorial · Issue #1204 · pytorch/pytorch
[4] PyTorch 中,nn 与 nn.functional 有什么区别? - 知乎
[5] FaceParsing.PyTorch/augmentations.py at master · TracelessLe/FaceParsing.PyTorch
这篇关于PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!