【损失函数】Quantile Loss 分位数损失

2024-01-04 07:52

本文主要是介绍【损失函数】Quantile Loss 分位数损失,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1、介绍

Quantile Loss(分位数损失)是用于回归问题的一种损失函数,它允许我们对不同分位数的预测误差赋予不同的权重。这对于处理不同置信水平的预测非常有用,例如在风险管理等领域。

当我们需要对区间预测而不单是点预测时 分位数损失函数可以发挥很大作用

2、公式

$J_{\text {quant }}=\frac{1}{N} \sum_{i=1}^N \mathbb{I}_{\hat{y}_i \geq y_i}(1-\gamma)\left|y_i-\hat{y}_i\right|+\mathbb{I}_{\hat{y}_i<y_i} \gamma\left|y_i-\hat{y}_i\right|$

其中,{y}_i是实际目标值,\hat{y}_i 是模型的预测值,\gamma 是分位数水平,通常取值在 0 和 1 之间。

        我们如何理解这个损失函数呢?这个损失函数是一个分段的函数 ,将  \hat{y}_i \geq y_i(高估) 和  \hat{y}_i<y_i(低估) 两种情况分开来,并分别给予不同的系数。当 \gamma > 0.5 时,低估的损失要比高估的损失更大,反过来当 \gamma < 0.5 时,高估的损失比低估的损失大;分位数损失实现了分别用不同的系数控制高估和低估的损失,进而实现分位数回归。特别地,当 \gamma = 0.5 时,分位数损失退化为 MAE 损失,从这里可以看出 MAE 损失实际上是分位数损失的一个特例 — 中位数回归(这也可以解释为什么 MAE 损失对 outlier 更鲁棒:MSE 回归期望值,MAE 回归中位数,通常 outlier 对中位数的影响比对期望值的影响小)。      

        简单的总结下,分位数损失通过 \gamma 的不同取值来避免过拟合和欠拟合,实现分位数回归。

        分位数值的选择基于在实际中需要误差如何发挥作用,即在过程中误差为正时发挥更多作用还是在误差为负时发挥更大作用。

3、图像

        上图是分位数损失(Quantile Loss)在分位数为 0.3、0.5、0.7 时的图像。图中显示了预测值(f)与分位数损失之间的关系,可以看到 0.3 和 0.8 在高估和低估两种情况下损失是不同的,而 0.5 实际上就是 MAE。

4、实例

假设我们有以下情况:我们正在训练一个模型来预测房价涨幅区间。我们有以下目标值(真实值)和预测值:

  • 目标(真实值): [2.0, 1.0, 4.0, 3.5, 5.0]
  • 预测: [1.8, 0.9, 3.5, 3.0, 4.8]

我们使用 Quantile Loss作为损失函数:

import torch
import torch.nn as nnclass QuantileLoss(nn.Module):def __init__(self, quantile):super(QuantileLoss, self).__init__()self.quantile = quantiledef forward(self, y, y_pred):residual = y_pred - yloss = torch.max((self.quantile - 1) * residual, self.quantile * residual)return torch.mean(loss)
# 示例数据
y_true = torch.tensor([2.0, 1.0, 4.0, 3.5, 5.0], dtype=torch.float32)
y_pred = torch.tensor([1.8, 0.9, 3.5, 3.0, 4.8], dtype=torch.float32)
# 定义分位数水平 当分位数为 0.5 时,分位数损失退化为 MAE 损失
quantile = 0.5
# 初始化损失函数
quantile_loss = QuantileLoss(quantile)
# 计算损失
loss = quantile_loss(y_true, y_pred)
# Quantile Loss: 0.14999999105930328
print(f'Quantile Loss: {loss.item()}')

       在上述示例中,我们使用了一个简单的自定义 PyTorch 模块 `QuantileLoss`,它采用分位数水平作为参数,并计算相应的 Quantile Loss。这个例子中使用的分位数是 0.5,即中位数。此时分位数损失退化为 MAE 损失,实际应用中根据不同需求设定不同的分位数水平。

5、参考

损失函数 Loss Function 之 分位数损失 Quantile Loss - 知乎 (zhihu.com)

深度学习常用损失函数总览:基本形式、原理、特点 (qq.com)

这篇关于【损失函数】Quantile Loss 分位数损失的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/568649

相关文章

hdu1171(母函数或多重背包)

题意:把物品分成两份,使得价值最接近 可以用背包,或者是母函数来解,母函数(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v)(1 + x^v+x^2v+.....+x^num*v) 其中指数为价值,每一项的数目为(该物品数+1)个 代码如下: #include<iostream>#include<algorithm>

C++操作符重载实例(独立函数)

C++操作符重载实例,我们把坐标值CVector的加法进行重载,计算c3=c1+c2时,也就是计算x3=x1+x2,y3=y1+y2,今天我们以独立函数的方式重载操作符+(加号),以下是C++代码: c1802.cpp源代码: D:\YcjWork\CppTour>vim c1802.cpp #include <iostream>using namespace std;/*** 以独立函数

函数式编程思想

我们经常会用到各种各样的编程思想,例如面向过程、面向对象。不过笔者在该博客简单介绍一下函数式编程思想. 如果对函数式编程思想进行概括,就是f(x) = na(x) , y=uf(x)…至于其他的编程思想,可能是y=a(x)+b(x)+c(x)…,也有可能是y=f(x)=f(x)/a + f(x)/b+f(x)/c… 面向过程的指令式编程 面向过程,简单理解就是y=a(x)+b(x)+c(x)

利用matlab bar函数绘制较为复杂的柱状图,并在图中进行适当标注

示例代码和结果如下:小疑问:如何自动选择合适的坐标位置对柱状图的数值大小进行标注?😂 clear; close all;x = 1:3;aa=[28.6321521955954 26.2453660695847 21.69102348512086.93747104431360 6.25442246899816 3.342835958564245.51365061796319 4.87

OpenCV结构分析与形状描述符(11)椭圆拟合函数fitEllipse()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C++11 算法描述 围绕一组2D点拟合一个椭圆。 该函数计算出一个椭圆,该椭圆在最小二乘意义上最好地拟合一组2D点。它返回一个内切椭圆的旋转矩形。使用了由[90]描述的第一个算法。开发者应该注意,由于数据点靠近包含的 Mat 元素的边界,返回的椭圆/旋转矩形数据

Unity3D 运动之Move函数和translate

CharacterController.Move 移动 function Move (motion : Vector3) : CollisionFlags Description描述 A more complex move function taking absolute movement deltas. 一个更加复杂的运动函数,每次都绝对运动。 Attempts to

SigLIP——采用sigmoid损失的图文预训练方式

SigLIP——采用sigmoid损失的图文预训练方式 FesianXu 20240825 at Wechat Search Team 前言 CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注

✨机器学习笔记(二)—— 线性回归、代价函数、梯度下降

1️⃣线性回归(linear regression) f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b​(x)=wx+b 🎈A linear regression model predicting house prices: 如图是机器学习通过监督学习运用线性回归模型来预测房价的例子,当房屋大小为1250 f e e t 2 feet^

JavaSE(十三)——函数式编程(Lambda表达式、方法引用、Stream流)

函数式编程 函数式编程 是 Java 8 引入的一个重要特性,它允许开发者以函数作为一等公民(first-class citizens)的方式编程,即函数可以作为参数传递给其他函数,也可以作为返回值。 这极大地提高了代码的可读性、可维护性和复用性。函数式编程的核心概念包括高阶函数、Lambda 表达式、函数式接口、流(Streams)和 Optional 类等。 函数式编程的核心是Lambda

PHP APC缓存函数使用教程

APC,全称是Alternative PHP Cache,官方翻译叫”可选PHP缓存”。它为我们提供了缓存和优化PHP的中间代码的框架。 APC的缓存分两部分:系统缓存和用户数据缓存。(Linux APC扩展安装) 系统缓存 它是指APC把PHP文件源码的编译结果缓存起来,然后在每次调用时先对比时间标记。如果未过期,则使用缓存的中间代码运行。默认缓存 3600s(一小时)。但是这样仍会浪费大量C