PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录

本文主要是介绍PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

在这里插入图片描述

在PyTorch框架下使用F.cross_entropy()函数时,偶尔会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed

错误信息类似下面打印信息:

/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu line=83 error=59 : device-side assert triggered
Traceback (most recent call last):File "tutorial.py", line 100, in <module>model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)File "tutorial.py", line 80, in train_modelloss = criterion(outputs, labels)File "python3.7/site-packages/torch/nn/modules/module.py", line 206, in __call__result = self.forward(*input, **kwargs)File "python3.7/site-packages/torch/nn/modules/loss.py", line 313, in forwardself.weight, self.size_average)File "python3.7/site-packages/torch/nn/functional.py", line 509, in cross_entropyreturn nll_loss(log_softmax(input), target, weight, size_average)File "python3.7/site-packages/torch/nn/functional.py", line 477, in nll_lossreturn f(input, target)File "python3.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forwardoutput, *self.additional_args)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:83

通常情况下,这是由于求交叉熵函数在计算时遇到了类别错误的问题,即不满足t >= 0 && t < n_classes条件。

t >= 0 && t < n_classes条件

在分类任务中,需要调用torch.nn.functional.cross_entropy()函数求交叉熵,从PyTorch官网可以看到该函数定义:
在这里插入图片描述

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

可以注意到有一个key-value是ignore_index=-100。这是在交叉熵计算时被跳过的部分。通常是在数据增强中的填充值。

而在代码运行中报错ClassNLLCriterion Assertion `t >= 0 && t < n_classes ` failed,大部分都是由于没有正确处理好label(ground truth)导致的。例如在数据增强中,填充数据使用了负数,或者使用了某大正数(如255),而在调用torch.nn.functional.cross_entropy()方法时却没有传入正确的ignore_index。这就会导致运行过程中的Assertion Error。

在这里插入图片描述

代码示例

数据增强部分

import torchvision.transforms.functional as tftf.pad(cropped_img, padding_tuple, padding_mode="reflect"),
tf.affine(mask, translate=(-x_offset, -y_offset), scale=1.0, angle=0.0, shear=0.0,fillcolor=250,)

求交叉熵部分

import torch
import torch.nn.functional as F
import torch.nn as nndef cross_entropy2d(input, target, weight=None, reduction='none'):n, c, h, w = input.size()nt, ht, wt = target.size()if h != ht or w != wt:input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)target = target.view(-1)loss = F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=255)return loss

分析

可以看到在数据增强时的填充值为250(fillcolor=250),但在求交叉熵时却传入了ignore_index=255。因此在代码运行时,F.cross_entropy部分便会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed。只需要统一好label部分填充数据和计算交叉熵时需要忽略的class就可以避免出现这一问题。

其他

在PyTorch框架下,使用无用label值进行填充和处理时,要注意在使用scatter_函数时也需要注意对无用label进行提前处理,否则在使用data.scatter_()时同样也会报类似类别index错误。

labels = labels[:, :, :].view(size[0], 1, size[1], size[2])
oneHot_size = (size[0], classes, size[1], size[2])
labels_real = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
# ignore_index=255
# labels[labels.data[::] == ignore_index] = 0
labels_real = labels_real.scatter_(1, labels.data.long().cuda(), 1.0)

在这里插入图片描述

参考资料

[1] torch.nn.functional — PyTorch 1.8.0 documentation
[2] Pytorch里的CrossEntropyLoss详解 - marsggbo - 博客园
[3] RuntimeError: cuda runtime error (59) : device-side assert triggered when running transfer_learning_tutorial · Issue #1204 · pytorch/pytorch
[4] PyTorch 中,nn 与 nn.functional 有什么区别? - 知乎
[5] FaceParsing.PyTorch/augmentations.py at master · TracelessLe/FaceParsing.PyTorch

这篇关于PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/922145

相关文章

线上Java OOM问题定位与解决方案超详细解析

《线上JavaOOM问题定位与解决方案超详细解析》OOM是JVM抛出的错误,表示内存分配失败,:本文主要介绍线上JavaOOM问题定位与解决方案的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录一、OOM问题核心认知1.1 OOM定义与技术定位1.2 OOM常见类型及技术特征二、OOM问题定位工具

Python使用FastAPI实现大文件分片上传与断点续传功能

《Python使用FastAPI实现大文件分片上传与断点续传功能》大文件直传常遇到超时、网络抖动失败、失败后只能重传的问题,分片上传+断点续传可以把大文件拆成若干小块逐个上传,并在中断后从已完成分片继... 目录一、接口设计二、服务端实现(FastAPI)2.1 运行环境2.2 目录结构建议2.3 serv

Spring Security简介、使用与最佳实践

《SpringSecurity简介、使用与最佳实践》SpringSecurity是一个能够为基于Spring的企业应用系统提供声明式的安全访问控制解决方案的安全框架,本文给大家介绍SpringSec... 目录一、如何理解 Spring Security?—— 核心思想二、如何在 Java 项目中使用?——

springboot中使用okhttp3的小结

《springboot中使用okhttp3的小结》OkHttp3是一个JavaHTTP客户端,可以处理各种请求类型,比如GET、POST、PUT等,并且支持高效的HTTP连接池、请求和响应缓存、以及异... 在 Spring Boot 项目中使用 OkHttp3 进行 HTTP 请求是一个高效且流行的方式。

Java使用Javassist动态生成HelloWorld类

《Java使用Javassist动态生成HelloWorld类》Javassist是一个非常强大的字节码操作和定义库,它允许开发者在运行时创建新的类或者修改现有的类,本文将简单介绍如何使用Javass... 目录1. Javassist简介2. 环境准备3. 动态生成HelloWorld类3.1 创建CtC

使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解

《使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解》本文详细介绍了如何使用Python通过ncmdump工具批量将.ncm音频转换为.mp3的步骤,包括安装、配置ffmpeg环... 目录1. 前言2. 安装 ncmdump3. 实现 .ncm 转 .mp34. 执行过程5. 执行结

Java使用jar命令配置服务器端口的完整指南

《Java使用jar命令配置服务器端口的完整指南》本文将详细介绍如何使用java-jar命令启动应用,并重点讲解如何配置服务器端口,同时提供一个实用的Web工具来简化这一过程,希望对大家有所帮助... 目录1. Java Jar文件简介1.1 什么是Jar文件1.2 创建可执行Jar文件2. 使用java

C#使用Spire.Doc for .NET实现HTML转Word的高效方案

《C#使用Spire.Docfor.NET实现HTML转Word的高效方案》在Web开发中,HTML内容的生成与处理是高频需求,然而,当用户需要将HTML页面或动态生成的HTML字符串转换为Wor... 目录引言一、html转Word的典型场景与挑战二、用 Spire.Doc 实现 HTML 转 Word1

Vue3绑定props默认值问题

《Vue3绑定props默认值问题》使用Vue3的defineProps配合TypeScript的interface定义props类型,并通过withDefaults设置默认值,使组件能安全访问传入的... 目录前言步骤步骤1:使用 defineProps 定义 Props步骤2:设置默认值总结前言使用T

Java中的抽象类与abstract 关键字使用详解

《Java中的抽象类与abstract关键字使用详解》:本文主要介绍Java中的抽象类与abstract关键字使用详解,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友跟随小编一起看看吧... 目录一、抽象类的概念二、使用 abstract2.1 修饰类 => 抽象类2.2 修饰方法 => 抽象方法,没有