Pytorch中交叉熵Loss趣解

2024-05-24 20:32
文章标签 pytorch loss 交叉 趣解

本文主要是介绍Pytorch中交叉熵Loss趣解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

点击上方「蓝字」关注我们

背景

最近一直在总结Pytorch中Loss的各种用法,交叉熵是深度学习中最常用的计算方法,写这个稿子把交叉熵的来龙去脉做一个总结。

什么是交叉熵

信息量

引用百度百科中信息量的例子来看,

在日常生活中,极少发生的事件一旦发生是容易引起人们关注的,而司空见惯的事不会引起注意,也就是说,极少见的事件所带来的信息量多。如果用统计学的术语来描述,就是出现概率小的事件信息量多。因此,事件出现得概率越小,信息量愈大。即信息量的多少是与事件发生频繁(即概率大小)成反比。

故越小概率的事情发生的事件本身具有的信息量就越大。例如在去年夏天,小卡拒了湖人投奔快船还捎带打劫雷霆了一个泡椒,这种闷声大发财的事情就有很大的信息量。

信息量的计算公式为:

信息熵

理解了信息量之后,信息熵的理解也就不再困难了。熵原本是热力学中的一个概念,是用来衡量混乱程度的物理量。信息熵则是借用热力学的概念,衡量在事件发生前对于产生信息量的期望。即信息量是确定的具体事件发生后的信息的度量,信息熵是事件发生前预估的期望。

信息熵的计算公式为:

可以看到信息熵是一个求和的函数,是求得信息量的期望。还是以小卡为例,小卡转会前,假设去湖人的概率是0.4,去其他30支球队的概率分别为 (计算方便),猛龙概率为0(心疼...),那么小卡转会的信息熵为

所以小卡转会这个事件预计的信息量为,但是实际小卡去了快船,实际的信息量为 。因为这是一个非常轰动的事件,所以实际的信息量大于了估计所得的期望。

KL散度与交叉熵

理解了信息量和信息熵之后,接下来就是交叉熵的概念了。介绍交叉熵之前,Loss是绕不开的。Loss的通俗解释就是预测值和真实值的差异,然后有各种各样的方法来衡量这个差异有多大,本文所介绍的交叉熵也是一种衡量Loss的方法。

KL散度

在讲交叉熵之前,有一个类似的东西叫KL散度,KL散度是用来衡量两个分布之间差异的指标,计算公式为

公式里面 是真实值, 是预测值,如果 相同时 ,即两者之间没有差异。

交叉熵

现在我们将KL散度的公式进行变形,

其中 是真实值的信息熵,第二项

就是多分类的交叉熵。因此KL散度也被成为相对熵。对于二分类而言,交叉熵为

二分类交叉熵

Pytorch总共提供了两种二分类交叉熵,一种是nn.BCELoss,另一种是nn.BCEWithLogitsLoss,这两个的差别非常细微,nn.BCEWithLogitsLoss=nn.Sigmoid+nn.BCELoss。这里结合Pytorch的代码做一下验证,首先先验证nn.BCELoss

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
# tensor([1.5051, 2.5170, 0.7961], requires_grad=True)
target = torch.empty(3).random_(2)
# tensor([1., 1., 1.])
m(input)
# tensor([0.8183, 0.9253, 0.6891], grad_fn=<SigmoidBackward>)
output = loss(m(input), target)
# tensor(0.2168, grad_fn=<BinaryCrossEntropyBackward>)

对于nn.BCEWithLogitsLoss而言,使用的代码为

# tensor([1.5051, 2.5170, 0.7961], requires_grad=True)
loss1 = nn.BCEWithLogitsLoss()
output1 = loss1(input, target)
# tensor(0.2168, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

可以看到两者的输入完全相同,输出nn.BCEWithLogitsLoss完全等于nn.BCELoss加上nn.Sigmoid

多分类交叉熵

对于多分类交叉熵函数而言,一般使用nn.CrossEntropyLoss,该函数的计算流程为:

  1. 在输入值上施加nn.Softmax函数

  2. 对于第一步所得结果使用log函数,将较为耗时的乘法运算改为加法运算,并将其归一化到 之间

  3. 将第二步所得输出输入nn.NLLLoss函数中,nn.NNLLLoss的作用就是接受负对数似然值,然后对其求平均。

具体的案例在下一节的CIFAR-10的分类问题中。

实际应用

分类问题

这里我们使用Pytorch自带的CIFAR-10的数据集进行分类,训练的网络为

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

网络结构如下图所示:

使用的loss为nn.CrossEntropyLoss,使用的优化器为SGD优化器,batch size为4,分类的图片类别为10类。

网络的输入为:

网络的输出为

tensor([[ 0.1788,  2.7710, -1.7293, -0.7374, -0.2494, -1.8283,  3.9494, -2.5115,-3.8941,  3.0328],[-2.4240, -2.4709,  2.3784,  2.9654,  1.3043,  2.0047,  0.7166,  1.0876,-1.9398, -1.9151],[ 2.8583, -2.6383,  1.5032,  0.3617,  1.3289, -1.2958,  0.0279, -2.3972,2.3778, -2.0852],[-2.8961, -2.9959,  2.6880,  2.5834,  2.5803,  2.3840,  2.4317,  0.9671,-3.0246, -2.5225]], grad_fn=<AddmmBackward>)

输出是一个4×10的矩阵,对应的label为,

label = tensor([9, 5, 8, 6]) #truck, dog, ship, frog

将网络直接的输出输入到nn.softmax可得,并验证加和结果为

softmax_func=nn.Softmax(dim=1)# tensor([[1.3065e-02, 1.7453e-01, 1.9384e-03, 5.2261e-03, 8.5139e-03, 1.7555e-03,
#         5.6710e-01, 8.8657e-04, 2.2246e-04, 2.2676e-01],
#         [1.8934e-03, 1.8066e-03, 2.3061e-01, 4.1479e-01, 7.8778e-02, 1.5870e-01,
#          4.3770e-02, 6.3433e-02, 3.0727e-03, 3.1494e-03],
#         [4.4119e-01, 1.8093e-03, 1.1379e-01, 3.6341e-02, 9.5598e-02, 6.9270e-03,
#          2.6028e-02, 2.3025e-03, 2.7287e-01, 3.1457e-03],
#         [8.3395e-04, 7.5479e-04, 2.2197e-01, 1.9993e-01, 1.9930e-01, 1.6378e-01,
#          1.7178e-01, 3.9713e-02, 7.3344e-04, 1.2118e-03]],
#       grad_fn=<SoftmaxBackward>)softmax_func(outputs).sum(1)
# tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

接下来,将nn.softmax的输出输入torch.log可以得到

log_outputs=torch.log(soft_output)
# tensor([[-4.3379, -1.7457, -6.2459, -5.2541, -4.7661, -6.3450, -0.5672, -7.0281,
#         -8.4108, -1.4839],
#        [-6.2694, -6.3163, -1.4670, -0.8800, -2.5411, -1.8407, -3.1288, -2.7578,
#         -5.7852, -5.7605],
#        [-0.8183, -6.3148, -2.1734, -3.3148, -2.3476, -4.9723, -3.6486, -6.0738,
#         -1.2988, -5.7617],
#        [-7.0893, -7.1891, -1.5052, -1.6098, -1.6130, -1.8093, -1.7616, -3.2261,
#         -7.2178, -6.7157]], grad_fn=<LogBackward>)

最后将log_outputs通过nn.NLLLoss并与nn.CrossEntropy对比

nllloss_func=nn.NLLLoss()
nlloss_output=nllloss_func(log_outputs,labels)
# tensor(1.5962, grad_fn=<NllLossBackward>)
criterion = nn.CrossEntropyLoss()
criterion(outputs, labels)
# tensor(1.5962, grad_fn=<NllLossBackward>)

可以发现经过组合的loss和直接用nn.CrossEntroyLoss得到的loss是一样的。

下一次推送我们将会解析一下Kaiming大神的Focal Loss。

参考文献

[1] https://gombru.github.io/2018/05/23/cross_entropy_loss/

[2] https://www.baidu.com/link?url=mm7cnRyOERSRY_TPjZ8WbzU3im5Hq1JstcfLngNj4y0P5H4gC9lAhGLWnTBAgoucSnBu-Ek_fwM-RuyWSOfPxv4Idbxr0hm-udxOVd3Yz4rFgPymoQpsOb8_UsSmub-I&wd=&eqid=b0181d0f0002cd51000000045eccec89

[3] https://pytorch.org/docs/master/generated/torch.nn.BCELoss.html

[4] https://pytorch.org/docs/master/generated/torch.nn.NLLLoss.html

我是元峰,互联网+AI领域的创业者,欢迎扫描下方二维码,或者直接在微信搜索“AIZOO”关注我们的公众号AIZOO。您也可以访问我们的网站 AIZOO.com 了解我们。

如果您是有算法需求,例如目标检测、人脸识别、缺陷检测、行人检测的算法需求,欢迎添加我们的微信号AIZOOTech与我们交流,我们团队是一群算法工程师的创业团队,会以高效、稳定、高性价比的产品满足您的需求。

如果您是算法或者开发工程师,也可以添加我们的微信号AIZOOTech,请备注学校or公司名称-研究方向-昵称,例如“浙大-图像算法-元峰”,元峰会拉您进我们的算法交流群,一起交流算法和开发的知识,以及对接项目。

小助手微信号“AIZOOTech”

添加作者元峰微信,邀您进AIZOO技术交流群

让我知道你在看

这篇关于Pytorch中交叉熵Loss趣解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/999453

相关文章

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

MySQL中的交叉连接、自然连接和内连接查询详解

《MySQL中的交叉连接、自然连接和内连接查询详解》:本文主要介绍MySQL中的交叉连接、自然连接和内连接查询,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、引入二、交php叉连接(cross join)三、自然连接(naturalandroid join)四

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确

pytorch+torchvision+python版本对应及环境安装

《pytorch+torchvision+python版本对应及环境安装》本文主要介绍了pytorch+torchvision+python版本对应及环境安装,安装过程中需要注意Numpy版本的降级,... 目录一、版本对应二、安装命令(pip)1. 版本2. 安装全过程3. 命令相关解释参考文章一、版本对

从零教你安装pytorch并在pycharm中使用

《从零教你安装pytorch并在pycharm中使用》本文详细介绍了如何使用Anaconda包管理工具创建虚拟环境,并安装CUDA加速平台和PyTorch库,同时在PyCharm中配置和使用PyTor... 目录背景介绍安装Anaconda安装CUDA安装pytorch报错解决——fbgemm.dll连接p

pycharm远程连接服务器运行pytorch的过程详解

《pycharm远程连接服务器运行pytorch的过程详解》:本文主要介绍在Linux环境下使用Anaconda管理不同版本的Python环境,并通过PyCharm远程连接服务器来运行PyTorc... 目录linux部署pytorch背景介绍Anaconda安装Linux安装pytorch虚拟环境安装cu

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 (debug笔记)

Nn criterions don’t compute the gradient w.r.t. targets error「pytorch」 ##一、 缘由及解决方法 把这个pytorch-ddpg|github搬到jupyter notebook上运行时,出现错误Nn criterions don’t compute the gradient w.r.t. targets error。注:我用