NeuralForecast VanillaTransformer MAE损失函数

2024-06-05 22:20

本文主要是介绍NeuralForecast VanillaTransformer MAE损失函数,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

NeuralForecast VanillaTransformer MAE损失函数

flyfish

nn.L1Loss() 和 自定义的class MAE(BasePointLoss): 在本质上都是计算 Mean Absolute Error (MAE),但是它们有一些不同之处,主要在于定制化和功能上的差异。
写一个自定义的MAE完整示例代码


import mathfrom typing import Optional, Union, Tupleimport math
import numpy as np
import torchimport torch.nn as nn
import torch.nn.functional as F
def _divide_no_nan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:"""Auxiliary funtion to handle divide by 0"""div = a / bdiv[div != div] = 0.0div[div == float("inf")] = 0.0return div
def _weighted_mean(losses, weights):"""Compute weighted mean of losses per datapoint."""return _divide_no_nan(torch.sum(losses * weights), torch.sum(weights))
class BasePointLoss(torch.nn.Module):"""Base class for point loss functions.**Parameters:**<br>`horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>`outputsize_multiplier`: Multiplier for the output size. <br>`output_names`: Names of the outputs. <br>"""def __init__(self, horizon_weight, outputsize_multiplier, output_names):super(BasePointLoss, self).__init__()if horizon_weight is not None:horizon_weight = torch.Tensor(horizon_weight.flatten())self.horizon_weight = horizon_weightself.outputsize_multiplier = outputsize_multiplierself.output_names = output_namesself.is_distribution_output = Falsedef domain_map(self, y_hat: torch.Tensor):"""Univariate loss operates in dimension [B,T,H]/[B,H]This changes the network's output from [B,H,1]->[B,H]"""return y_hat.squeeze(-1)def _compute_weights(self, y, mask):"""Compute final weights for each datapoint (based on all weights and all masks)Set horizon_weight to a ones[H] tensor if not set.If set, check that it has the same length as the horizon in x."""if mask is None:mask = torch.ones_like(y, device=y.device)if self.horizon_weight is None:self.horizon_weight = torch.ones(mask.shape[-1])else:assert mask.shape[-1] == len(self.horizon_weight), "horizon_weight must have same length as Y"weights = self.horizon_weight.clone()weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)return weights * maskclass MAE(BasePointLoss):"""Mean Absolute ErrorCalculates Mean Absolute Error between`y` and `y_hat`. MAE measures the relative predictionaccuracy of a forecasting method by calculating thedeviation of the prediction and the truevalue at a given time and averages these devationsover the length of the series.$$ \mathrm{MAE}(\\mathbf{y}_{\\tau}, \\mathbf{\hat{y}}_{\\tau}) = \\frac{1}{H} \\sum^{t+H}_{\\tau=t+1} |y_{\\tau} - \hat{y}_{\\tau}| $$**Parameters:**<br>`horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>"""def __init__(self, horizon_weight=None):super(MAE, self).__init__(horizon_weight=horizon_weight, outputsize_multiplier=1, output_names=[""])def __call__(self,y: torch.Tensor,y_hat: torch.Tensor,mask: Union[torch.Tensor, None] = None,):"""**Parameters:**<br>`y`: tensor, Actual values.<br>`y_hat`: tensor, Predicted values.<br>`mask`: tensor, Specifies datapoints to consider in loss.<br>**Returns:**<br>`mae`: tensor (single value)."""losses = torch.abs(y - y_hat)weights = self._compute_weights(y=y, mask=mask)return _weighted_mean(losses=losses, weights=weights)# 定义简单的线性模型示例
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.linear = nn.Linear(10, 1)  # 10个输入特征,1个输出def forward(self, x):return self.linear(x)# 初始化模型和损失函数
model = SimpleModel()
mae_loss = MAE(horizon_weight=None)# 生成示例数据
# 批次大小为5,时间步长为10,假设预测未来一个时间步的值
batch_size = 5
time_steps = 10
input_features = 10# 随机生成输入数据和真实标签
x = torch.randn(batch_size, input_features)
y = torch.randn(batch_size, 1)  # 真实值# 生成预测值
y_hat = model(x)# 调用 domain_map 函数
y_hat_mapped = mae_loss.domain_map(y_hat)# 调用损失函数
# 这里假设 mask 为 None,表示考虑所有数据点
mae_value = mae_loss(y, y_hat_mapped, mask=None)# 打印MAE值
print("Mean Absolute Error (MAE):", mae_value.item())

M A E = 1 H ∑ i = 1 H ∣ y i − y ^ i ∣ \mathrm{MAE} = \frac{1}{H} \sum_{i=1}^{H} | y_i - \hat{y}_i | MAE=H1i=1Hyiy^i

y i y_i yi表示实际值。
y ^ i \hat{y}_i y^i表示预测值。
H H H表示预测的时间步数或样本数量

这两个函数是用来计算加权平均损失的辅助函数。

_divide_no_nan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:

这个函数用来处理两个张量相除时出现除以零的情况。它首先计算两个张量相除的结果 div,然后将结果中的 NaN 值(由除以零导致)替换为 0.0,并将结果中的正无穷值替换为 0.0,最后返回处理后的结果。

如果 b 中包含 0,那么 a / b 的计算会产生除以零的情况,这会导致结果中出现 NaN(“Not a Number”)或正无穷大(inf)值。_divide_no_nan 函数的目的是处理这些情况,确保输出结果中没有 NaN 或无穷大值。

让我们详细说明这一过程:

初始计算:
div = a / b 进行逐元素相除,如果 b 中有 0,结果 div 中相应位置会包含 NaN 或 inf。

替换 NaN 值:
div[div != div] = 0.0 这一行代码使用了一个技巧:由于 NaN 不等于任何值,包括它自己,div != div 会在 NaN 所在的位置返回 True。于是,div[div != div] 会选中所有 NaN 并将其设置为 0.0。

替换 inf 值:
div[div == float(“inf”)] = 0.0 这一行代码将所有正无穷大(inf)值替换为 0.0。

因此,如果 b 包含 0,函数 _divide_no_nan 会确保相应位置的结果是 0.0,而不是 NaN 或 inf。这保证了计算的稳定性和结果的可用性。

def _divide_no_nan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:div = a / bdiv[div != div] = 0.0div[div == float("inf")] = 0.0return diva = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([0.0, 2.0, 0.0])result = _divide_no_nan(a, b)
print(result)  # 输出 tensor([0., 1., 0.])

在这个例子中,a / b 会生成 [inf, 1.0, inf],然后 div[div != div] = 0.0 将 NaN 转换为 0.0(但在这个例子中没有 NaN),div[div == float(“inf”)] = 0.0 将 inf 转换为 0.0,最终结果是 [0.0, 1.0, 0.0]。

_weighted_mean(losses, weights):

这个函数用来计算加权平均损失。它接收两个参数,losses 表示每个数据点的损失值,weights 表示每个数据点的权重。函数首先计算每个损失值乘以相应的权重,然后将所有加权损失值相加,最后除以所有权重的总和。在这个过程中,_divide_no_nan 函数被用来处理除以零的情况,确保计算的稳定性。

BasePointLoss

BasePointLoss 是一个 PyTorch 模块类,用于定义时间序列预测中的基础点损失函数。它提供了一些通用的功能和参数设置,这些功能和设置可以在具体的点损失函数(如均方误差 MSE 或平均绝对误差 MAE)中继承和使用。

主要功能和参数
以下是 BasePointLoss 类的主要功能和参数说明:

horizon_weight:
这是一个大小为 h 的张量,表示预测窗口中每个时间戳的权重。如果没有提供,它将在计算时设置为全 1 的张量。

outputsize_multiplier:
这是一个用于调整输出大小的乘数。

output_names:
这是一个列表,包含输出的名称。

该方法根据 horizon_weight 和 mask 计算每个数据点的权重。如果 horizon_weight 未设置,它将默认为全 1 的张量;否则,它会检查 horizon_weight 的长度是否与 y 的最后一维相同。

import torch# 定义示例 horizon_weight 和 mask
horizon_weight = torch.Tensor([0.1, 0.3, 0.6])
mask = torch.Tensor([[1, 0, 1], [1, 1, 0]])# 定义 y(实际值),这里只是为了展示维度,具体值不影响计算权重
y = torch.Tensor([[2, 3, 4],[1, 2, 3]])# 模拟 BasePointLoss 类的 _compute_weights 方法
def compute_weights(horizon_weight, y, mask):if mask is None:mask = torch.ones_like(y, device=y.device)if horizon_weight is None:horizon_weight = torch.ones(mask.shape[-1])else:assert mask.shape[-1] == len(horizon_weight), "horizon_weight must have same length as Y"weights = horizon_weight.clone()weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)return weights * mask# 计算权重
final_weights = compute_weights(horizon_weight, y, mask)# 打印最终权重
print("Final Weights:")
print(final_weights)

horizon_weight:
定义每个时间点的权重。例如,[0.1, 0.3, 0.6] 表示第一个时间点的权重为 0.1,第二个时间点为 0.3,第三个时间点为 0.6。

mask:
定义哪些数据点应被考虑。例如,[[1, 0, 1], [1, 1, 0]] 表示第一个样本的第二个时间点和第二个样本的第三个时间点不被考虑。

y:
实际值,仅用于展示维度。在这个例子中,假设每个样本在时间维度上有 3 个点。

compute_weights:
计算最终的权重。如果 mask 为 None,则默认为全 1。如果 horizon_weight 为 None,则默认为全 1。最终权重是 horizon_weight 和 mask 的逐元素乘积。

输出结果

Final Weights:
tensor([[0.1000, 0.0000, 0.6000],[0.1000, 0.3000, 0.0000]])

这个结果表明,权重和掩码的结合使得某些数据点被赋予了相应的权重,而被掩盖的点(即掩码为 0 的点)的权重为 0。

使用MAE 损失的原因

鲁棒性:
MAE 对于异常值的影响比均方误差 (MSE) 小,因为它计算的是绝对误差,而不是平方误差。这使得 MAE 在存在异常值或噪声的时间序列中表现更加稳健。

简单易解释:
MAE 直接衡量预测值与真实值之间的平均绝对差异,这使得其结果容易解释。它表示的是预测值与真实值之间的平均距离,这在实际应用中非常直观。

公平的误差惩罚:
MAE 对每个数据点的误差惩罚是线性的,这意味着每个预测误差都会被同等对待。相比之下,MSE 会对较大的误差赋予更高的惩罚,这在某些应用场景下可能会导致不必要的偏差。

nn.L1Loss和自定义的MAE比较下,体现一下自定义的功能

nn.L1Loss 的调用

PyTorch 内置的损失函数,用于计算预测值和真实值之间的平均绝对误差。其使用非常简单,默认情况下对每个数据点给予相同的权重,没有其他附加功能。

import torch
import torch.nn as nn# Example usage of nn.L1Loss
loss_fn = nn.L1Loss()
y = torch.tensor([1.0, 2.0, 3.0])
y_hat = torch.tensor([1.5, 2.5, 3.5])
loss = loss_fn(y_hat, y)
print(loss.item())  # Output: 0.5

class MAE(BasePointLoss) 的调用

这个自定义的 MAE 类继承自 BasePointLoss,是一个更复杂和定制化的实现。它具有以下特点:

可选的时间权重(horizon_weight):
MAE 类可以接受一个时间权重向量 horizon_weight,对预测窗口内的每个时间点赋予不同的权重。这在一些应用场景中非常有用,例如希望对特定时间点的预测误差给予更多的关注。

掩码(mask):
该类可以接受一个掩码 mask,指定哪些数据点应被纳入损失计算。这在处理缺失数据或不完整数据集时非常有用。

自定义功能:
由于继承自 BasePointLoss,这个 MAE 类可以进一步扩展和定制,以满足特定的需求。
示例实现中 _compute_weights 方法展示了如何计算权重,并在计算损失时使用这些权重。

具体示例比较

from typing import Optional, Union, Tuple
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fdef _divide_no_nan(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:"""Auxiliary funtion to handle divide by 0"""div = a / bdiv[div != div] = 0.0div[div == float("inf")] = 0.0return div
def _weighted_mean(losses, weights):"""Compute weighted mean of losses per datapoint."""return _divide_no_nan(torch.sum(losses * weights), torch.sum(weights))
class BasePointLoss(torch.nn.Module):def __init__(self, horizon_weight, outputsize_multiplier, output_names):super(BasePointLoss, self).__init__()if horizon_weight is not None:horizon_weight = torch.Tensor(horizon_weight.flatten())self.horizon_weight = horizon_weightself.outputsize_multiplier = outputsize_multiplierself.output_names = output_namesself.is_distribution_output = Falsedef domain_map(self, y_hat: torch.Tensor):return y_hat.squeeze(-1)def _compute_weights(self, y, mask):if mask is None:mask = torch.ones_like(y, device=y.device)if self.horizon_weight is None:self.horizon_weight = torch.ones(mask.shape[-1])else:assert mask.shape[-1] == len(self.horizon_weight), "horizon_weight must have same length as Y"weights = self.horizon_weight.clone()weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)return weights * maskclass MAE(BasePointLoss):"""Mean Absolute Error"""def __init__(self, horizon_weight=None):super(MAE, self).__init__(horizon_weight=horizon_weight, outputsize_multiplier=1, output_names=[""])def __call__(self,y: torch.Tensor,y_hat: torch.Tensor,mask: Union[torch.Tensor, None] = None,):losses = torch.abs(y - y_hat)weights = self._compute_weights(y=y, mask=mask)return _weighted_mean(losses=losses, weights=weights)# Example usage of custom MAE
mae_loss = MAE(horizon_weight=torch.tensor([1, 2, 3]))  # Custom weights for a 3-step horizon# 生成示例数据
batch_size = 2
horizon = 3# 随机生成输入数据和真实标签
y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y_hat = torch.tensor([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]])
mask = torch.tensor([[1, 0, 1], [1, 1, 1]])  # Consider some elements# 调用损失函数
loss = mae_loss(y, y_hat, mask=mask)# 打印 MAE 值
print("Mean Absolute Error (MAE):", loss.item())#Mean Absolute Error (MAE): 0.5

如果只是需要简单的 MAE,nn.L1Loss 就足够了;如果需要更多的控制和定制化,使用自定义的MAE(BasePointLoss)

这篇关于NeuralForecast VanillaTransformer MAE损失函数的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1034330

相关文章

MySQL中COALESCE函数示例详解

《MySQL中COALESCE函数示例详解》COALESCE是一个功能强大且常用的SQL函数,主要用来处理NULL值和实现灵活的值选择策略,能够使查询逻辑更清晰、简洁,:本文主要介绍MySQL中C... 目录语法示例1. 替换 NULL 值2. 用于字段默认值3. 多列优先级4. 结合聚合函数注意事项总结C

Java8需要知道的4个函数式接口简单教程

《Java8需要知道的4个函数式接口简单教程》:本文主要介绍Java8中引入的函数式接口,包括Consumer、Supplier、Predicate和Function,以及它们的用法和特点,文中... 目录什么是函数是接口?Consumer接口定义核心特点注意事项常见用法1.基本用法2.结合andThen链

MySQL 日期时间格式化函数 DATE_FORMAT() 的使用示例详解

《MySQL日期时间格式化函数DATE_FORMAT()的使用示例详解》`DATE_FORMAT()`是MySQL中用于格式化日期时间的函数,本文详细介绍了其语法、格式化字符串的含义以及常见日期... 目录一、DATE_FORMAT()语法二、格式化字符串详解三、常见日期时间格式组合四、业务场景五、总结一、

golang panic 函数用法示例详解

《golangpanic函数用法示例详解》在Go语言中,panic用于触发不可恢复的错误,终止函数执行并逐层向上触发defer,最终若未被recover捕获,程序会崩溃,recover用于在def... 目录1. panic 的作用2. 基本用法3. recover 的使用规则4. 错误处理建议5. 常见错

Python itertools中accumulate函数用法及使用运用详细讲解

《Pythonitertools中accumulate函数用法及使用运用详细讲解》:本文主要介绍Python的itertools库中的accumulate函数,该函数可以计算累积和或通过指定函数... 目录1.1前言:1.2定义:1.3衍生用法:1.3Leetcode的实际运用:总结 1.1前言:本文将详

轻松上手MYSQL之JSON函数实现高效数据查询与操作

《轻松上手MYSQL之JSON函数实现高效数据查询与操作》:本文主要介绍轻松上手MYSQL之JSON函数实现高效数据查询与操作的相关资料,MySQL提供了多个JSON函数,用于处理和查询JSON数... 目录一、jsON_EXTRACT 提取指定数据二、JSON_UNQUOTE 取消双引号三、JSON_KE

MySQL数据库函数之JSON_EXTRACT示例代码

《MySQL数据库函数之JSON_EXTRACT示例代码》:本文主要介绍MySQL数据库函数之JSON_EXTRACT的相关资料,JSON_EXTRACT()函数用于从JSON文档中提取值,支持对... 目录前言基本语法路径表达式示例示例 1: 提取简单值示例 2: 提取嵌套值示例 3: 提取数组中的值注意

Java function函数式接口的使用方法与实例

《Javafunction函数式接口的使用方法与实例》:本文主要介绍Javafunction函数式接口的使用方法与实例,函数式接口如一支未完成的诗篇,用Lambda表达式作韵脚,将代码的机械美感... 目录引言-当代码遇见诗性一、函数式接口的生物学解构1.1 函数式接口的基因密码1.2 六大核心接口的形态学

Oracle的to_date()函数详解

《Oracle的to_date()函数详解》Oracle的to_date()函数用于日期格式转换,需要注意Oracle中不区分大小写的MM和mm格式代码,应使用mi代替分钟,此外,Oracle还支持毫... 目录oracle的to_date()函数一.在使用Oracle的to_date函数来做日期转换二.日

C++11的函数包装器std::function使用示例

《C++11的函数包装器std::function使用示例》C++11引入的std::function是最常用的函数包装器,它可以存储任何可调用对象并提供统一的调用接口,以下是关于函数包装器的详细讲解... 目录一、std::function 的基本用法1. 基本语法二、如何使用 std::function