损失函数总结(四):NLLLoss、CTCLoss

2023-10-25 17:04

本文主要是介绍损失函数总结(四):NLLLoss、CTCLoss,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

损失函数总结(四):NLLLoss、CTCLoss

  • 1 引言
  • 2 损失函数
    • 2.1 NLLLoss
    • 2.2 CTCLoss
  • 3 总结

1 引言

在前面的文章中已经介绍了介绍了一系列损失函数 (L1LossMSELossBCELossCrossEntropyLoss)。在这篇文章中,会接着上文提到的众多损失函数继续进行介绍,给大家带来更多不常见的损失函数的介绍。这里放一张损失函数的机理图:
在这里插入图片描述

2 损失函数

2.1 NLLLoss

NLLLoss(Negative Log Likelihood Loss,负对数似然损失)通常用于训练分类模型,尤其是在多类别分类任务中。它是一种用于度量模型的类别概率分布实际类别分布之间的差距的损失函数。NLLLoss 的数学表达式如下:
L NLL ( Y , Y ′ ) = − 1 n ∑ i = 1 n ∑ j = 1 C y i j log ⁡ ( y i j ′ ) L_{\text{NLL}}(Y, Y') = -\frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{C} y_{ij} \log(y_{ij}') LNLL(Y,Y)=n1i=1nj=1Cyijlog(yij)

其中:

  • L CE ( Y , Y ′ ) L_{\text{CE}}(Y, Y') LCE(Y,Y) 是整个数据集上的交叉熵损失
  • n n n 是样本数量。
  • C C C 是类别数量。
  • y i j y_{ij} yij 是第 i i i 个样本的实际类别分布,通常是一个独热编码(one-hot encoding)向量,表示实际类别
  • y i j ′ y_{ij}' yij 是第 i i i 个样本的模型预测的类别概率分布,通常是一个概率向量,表示模型对每个类别的预测概率

注意:上面的公式和 CrossEntropyLoss 公式相同,但实际上是不同的。实际关系为:
NLLLoss + LogSoftmax = CrossEntropyLoss

代码实现(Pytorch):

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
output.backward()
# 2D loss example (used, for example, with image inputs)
N, C = 5, 4
loss = nn.NLLLoss()
# input is of size N x C x height x width
data = torch.randn(N, 16, 10, 10)
conv = nn.Conv2d(16, C, (3, 3))
m = nn.LogSoftmax(dim=1)
# each element in target has to have 0 <= value < C
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
output = loss(m(conv(data)), target)
output.backward()

NLLLoss 通常用于分类任务,特别是当模型输出的是类别概率分布时。NLLLoss 和 CrossEntropyLoss 是等价的,可以相互替换。。。

2.2 CTCLoss

论文链接:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks

CTC Loss(Connectionist Temporal Classification Loss,连接时序分类损失)通常用于训练序列到序列(sequence-to-sequence)模型,尤其是在语音识别自然语言处理中的任务,其中输出序列的长度与输入序列的长度不一致。CTC Loss 的主要目标是将模型的输出与目标序列对齐,以度量它们之间的相似度。CTCLoss 的数学表达式如下:
L CTC ( S ) = − ln ⁡ ∑ ( x , z ) ∈ S p ( z ∣ x ) = − ∑ ( x , z ) ∈ S l n p ( z ∣ x ) L_{\text{CTC}}(S) = -\ln \sum_{(x,z) \in S} p(z|x) = -\sum_{(x,z) \in S} lnp(z|x) LCTC(S)=ln(x,z)Sp(zx)=(x,z)Slnp(zx)

其中:

  • S S S 表示训练集
  • L CTC ( S ) L_{\text{CTC}}(S) LCTC(S) 表示 给定标签序列和输入,最终输出正确序列的概率

代码实现(Pytorch):

# Target are to be padded
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# Target are to be un-padded
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# Target are to be un-padded and unbatched (effectively N=1)
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
# Initialize random batch of input vectors, for *size = (T,C)
input = torch.randn(T, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.tensor(T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

CTCLoss 在语音识别自然语言处理中具有广泛的应用,可以广泛用于sequence-to-sequence任务。

3 总结

到此,使用 损失函数总结(四) 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。如果存在没有提及的损失函数也可以在评论区提出,后续会对其进行添加!!!!

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

这篇关于损失函数总结(四):NLLLoss、CTCLoss的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/283968

相关文章

Python itertools中accumulate函数用法及使用运用详细讲解

《Pythonitertools中accumulate函数用法及使用运用详细讲解》:本文主要介绍Python的itertools库中的accumulate函数,该函数可以计算累积和或通过指定函数... 目录1.1前言:1.2定义:1.3衍生用法:1.3Leetcode的实际运用:总结 1.1前言:本文将详

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意

Python中连接不同数据库的方法总结

《Python中连接不同数据库的方法总结》在数据驱动的现代应用开发中,Python凭借其丰富的库和强大的生态系统,成为连接各种数据库的理想编程语言,下面我们就来看看如何使用Python实现连接常用的几... 目录一、连接mysql数据库二、连接PostgreSQL数据库三、连接SQLite数据库四、连接Mo

Git提交代码详细流程及问题总结

《Git提交代码详细流程及问题总结》:本文主要介绍Git的三大分区,分别是工作区、暂存区和版本库,并详细描述了提交、推送、拉取代码和合并分支的流程,文中通过代码介绍的非常详解,需要的朋友可以参考下... 目录1.git 三大分区2.Git提交、推送、拉取代码、合并分支详细流程3.问题总结4.git push

Java function函数式接口的使用方法与实例

《Javafunction函数式接口的使用方法与实例》:本文主要介绍Javafunction函数式接口的使用方法与实例,函数式接口如一支未完成的诗篇,用Lambda表达式作韵脚,将代码的机械美感... 目录引言-当代码遇见诗性一、函数式接口的生物学解构1.1 函数式接口的基因密码1.2 六大核心接口的形态学

Kubernetes常用命令大全近期总结

《Kubernetes常用命令大全近期总结》Kubernetes是用于大规模部署和管理这些容器的开源软件-在希腊语中,这个词还有“舵手”或“飞行员”的意思,使用Kubernetes(有时被称为“... 目录前言Kubernetes 的工作原理为什么要使用 Kubernetes?Kubernetes常用命令总

Python中实现进度条的多种方法总结

《Python中实现进度条的多种方法总结》在Python编程中,进度条是一个非常有用的功能,它能让用户直观地了解任务的进度,提升用户体验,本文将介绍几种在Python中实现进度条的常用方法,并通过代码... 目录一、简单的打印方式二、使用tqdm库三、使用alive-progress库四、使用progres

Oracle的to_date()函数详解

《Oracle的to_date()函数详解》Oracle的to_date()函数用于日期格式转换,需要注意Oracle中不区分大小写的MM和mm格式代码,应使用mi代替分钟,此外,Oracle还支持毫... 目录oracle的to_date()函数一.在使用Oracle的to_date函数来做日期转换二.日

Android数据库Room的实际使用过程总结

《Android数据库Room的实际使用过程总结》这篇文章主要给大家介绍了关于Android数据库Room的实际使用过程,详细介绍了如何创建实体类、数据访问对象(DAO)和数据库抽象类,需要的朋友可以... 目录前言一、Room的基本使用1.项目配置2.创建实体类(Entity)3.创建数据访问对象(DAO