损失函数总结(四):NLLLoss、CTCLoss

2023-10-25 17:04

本文主要是介绍损失函数总结(四):NLLLoss、CTCLoss,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

损失函数总结(四):NLLLoss、CTCLoss

  • 1 引言
  • 2 损失函数
    • 2.1 NLLLoss
    • 2.2 CTCLoss
  • 3 总结

1 引言

在前面的文章中已经介绍了介绍了一系列损失函数 (L1LossMSELossBCELossCrossEntropyLoss)。在这篇文章中,会接着上文提到的众多损失函数继续进行介绍,给大家带来更多不常见的损失函数的介绍。这里放一张损失函数的机理图:
在这里插入图片描述

2 损失函数

2.1 NLLLoss

NLLLoss(Negative Log Likelihood Loss,负对数似然损失)通常用于训练分类模型,尤其是在多类别分类任务中。它是一种用于度量模型的类别概率分布实际类别分布之间的差距的损失函数。NLLLoss 的数学表达式如下:
L NLL ( Y , Y ′ ) = − 1 n ∑ i = 1 n ∑ j = 1 C y i j log ⁡ ( y i j ′ ) L_{\text{NLL}}(Y, Y') = -\frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{C} y_{ij} \log(y_{ij}') LNLL(Y,Y)=n1i=1nj=1Cyijlog(yij)

其中:

  • L CE ( Y , Y ′ ) L_{\text{CE}}(Y, Y') LCE(Y,Y) 是整个数据集上的交叉熵损失
  • n n n 是样本数量。
  • C C C 是类别数量。
  • y i j y_{ij} yij 是第 i i i 个样本的实际类别分布,通常是一个独热编码(one-hot encoding)向量,表示实际类别
  • y i j ′ y_{ij}' yij 是第 i i i 个样本的模型预测的类别概率分布,通常是一个概率向量,表示模型对每个类别的预测概率

注意:上面的公式和 CrossEntropyLoss 公式相同,但实际上是不同的。实际关系为:
NLLLoss + LogSoftmax = CrossEntropyLoss

代码实现(Pytorch):

m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = loss(m(input), target)
output.backward()
# 2D loss example (used, for example, with image inputs)
N, C = 5, 4
loss = nn.NLLLoss()
# input is of size N x C x height x width
data = torch.randn(N, 16, 10, 10)
conv = nn.Conv2d(16, C, (3, 3))
m = nn.LogSoftmax(dim=1)
# each element in target has to have 0 <= value < C
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
output = loss(m(conv(data)), target)
output.backward()

NLLLoss 通常用于分类任务,特别是当模型输出的是类别概率分布时。NLLLoss 和 CrossEntropyLoss 是等价的,可以相互替换。。。

2.2 CTCLoss

论文链接:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks

CTC Loss(Connectionist Temporal Classification Loss,连接时序分类损失)通常用于训练序列到序列(sequence-to-sequence)模型,尤其是在语音识别自然语言处理中的任务,其中输出序列的长度与输入序列的长度不一致。CTC Loss 的主要目标是将模型的输出与目标序列对齐,以度量它们之间的相似度。CTCLoss 的数学表达式如下:
L CTC ( S ) = − ln ⁡ ∑ ( x , z ) ∈ S p ( z ∣ x ) = − ∑ ( x , z ) ∈ S l n p ( z ∣ x ) L_{\text{CTC}}(S) = -\ln \sum_{(x,z) \in S} p(z|x) = -\sum_{(x,z) \in S} lnp(z|x) LCTC(S)=ln(x,z)Sp(zx)=(x,z)Slnp(zx)

其中:

  • S S S 表示训练集
  • L CTC ( S ) L_{\text{CTC}}(S) LCTC(S) 表示 给定标签序列和输入,最终输出正确序列的概率

代码实现(Pytorch):

# Target are to be padded
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# Target are to be un-padded
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# Target are to be un-padded and unbatched (effectively N=1)
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
# Initialize random batch of input vectors, for *size = (T,C)
input = torch.randn(T, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.tensor(T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

CTCLoss 在语音识别自然语言处理中具有广泛的应用,可以广泛用于sequence-to-sequence任务。

3 总结

到此,使用 损失函数总结(四) 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。如果存在没有提及的损失函数也可以在评论区提出,后续会对其进行添加!!!!

如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

这篇关于损失函数总结(四):NLLLoss、CTCLoss的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/283968

相关文章

PostgreSQL中rank()窗口函数实用指南与示例

《PostgreSQL中rank()窗口函数实用指南与示例》在数据分析和数据库管理中,经常需要对数据进行排名操作,PostgreSQL提供了强大的窗口函数rank(),可以方便地对结果集中的行进行排名... 目录一、rank()函数简介二、基础示例:部门内员工薪资排名示例数据排名查询三、高级应用示例1. 每

全面掌握 SQL 中的 DATEDIFF函数及用法最佳实践

《全面掌握SQL中的DATEDIFF函数及用法最佳实践》本文解析DATEDIFF在不同数据库中的差异,强调其边界计算原理,探讨应用场景及陷阱,推荐根据需求选择TIMESTAMPDIFF或inte... 目录1. 核心概念:DATEDIFF 究竟在计算什么?2. 主流数据库中的 DATEDIFF 实现2.1

MySQL中的LENGTH()函数用法详解与实例分析

《MySQL中的LENGTH()函数用法详解与实例分析》MySQLLENGTH()函数用于计算字符串的字节长度,区别于CHAR_LENGTH()的字符长度,适用于多字节字符集(如UTF-8)的数据验证... 目录1. LENGTH()函数的基本语法2. LENGTH()函数的返回值2.1 示例1:计算字符串

Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式

《Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式》本文详细介绍如何使用Java通过JDBC连接MySQL数据库,包括下载驱动、配置Eclipse环境、检测数据库连接等关键步骤,... 目录一、下载驱动包二、放jar包三、检测数据库连接JavaJava 如何使用 JDBC 连接 mys

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Python函数作用域示例详解

《Python函数作用域示例详解》本文介绍了Python中的LEGB作用域规则,详细解析了变量查找的四个层级,通过具体代码示例,展示了各层级的变量访问规则和特性,对python函数作用域相关知识感兴趣... 目录一、LEGB 规则二、作用域实例2.1 局部作用域(Local)2.2 闭包作用域(Enclos

JavaSE正则表达式用法总结大全

《JavaSE正则表达式用法总结大全》正则表达式就是由一些特定的字符组成,代表的是一个规则,:本文主要介绍JavaSE正则表达式用法的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考下... 目录常用的正则表达式匹配符正则表China编程达式常用的类Pattern类Matcher类PatternSynta

MySQL count()聚合函数详解

《MySQLcount()聚合函数详解》MySQL中的COUNT()函数,它是SQL中最常用的聚合函数之一,用于计算表中符合特定条件的行数,本文给大家介绍MySQLcount()聚合函数,感兴趣的朋... 目录核心功能语法形式重要特性与行为如何选择使用哪种形式?总结深入剖析一下 mysql 中的 COUNT

MySQL 中 ROW_NUMBER() 函数最佳实践

《MySQL中ROW_NUMBER()函数最佳实践》MySQL中ROW_NUMBER()函数,作为窗口函数为每行分配唯一连续序号,区别于RANK()和DENSE_RANK(),特别适合分页、去重... 目录mysql 中 ROW_NUMBER() 函数详解一、基础语法二、核心特点三、典型应用场景1. 数据分