PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒

2024-06-24 01:44

本文主要是介绍PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • nn.MSELoss() 均方误差损失函数
    • 参数
    • 数学公式
      • 元素版本
    • 要点
    • 附录
  • 参考链接

nn.MSELoss() 均方误差损失函数

torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input x x x and target y y y.

计算输入和目标之间每个元素的均方误差(平方 L2 范数)。

参数

  • size_average (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失在批次中的每个损失元素上取平均(True);否则(False),在每个小批次中对损失求和。
    • reduceFalse 时忽略该参数。
    • 默认值是 True
  • reduce (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失根据 size_average 参数进行平均或求和。
    • reduceFalse 时,返回每个批次元素的损失,并忽略 size_average 参数。
    • 默认值是 True
  • reduction (str, 可选):
    • 指定应用于输出的归约方式。
    • 可选值为 'none''mean''sum'
      • 'none':不进行归约。
      • 'mean':输出的和除以输出的元素总数。
      • 'sum':输出的元素求和。
    • 注意:size_averagereduce 参数正在被弃用,同时指定这些参数中的任何一个都会覆盖 reduction 参数。
    • 默认值是 'mean'

数学公式

附录部分会验证下述公式和代码的一致性。

假设有 N N N 个样本,每个样本的输入为 x n x_n xn,目标为 y n y_n yn。均方误差损失的计算步骤如下:

  1. 单个样本的损失
    计算每个样本的均方误差:
    l n = ( x n − y n ) 2 l_n = (x_n - y_n)^2 ln=(xnyn)2
    其中 l n l_n ln 是第 n n n 个样本的损失。
  2. 总损失
    计算所有样本的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ n = 1 N l n = 1 N ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \frac{1}{N} \sum_{n=1}^{N} l_n = \frac{1}{N} \sum_{n=1}^{N} (x_n - y_n)^2 L=N1n=1Nln=N1n=1N(xnyn)2
    如果 reduction 参数为 'sum',总损失为所有样本损失的和:
    L = ∑ n = 1 N l n = ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \sum_{n=1}^{N} l_n = \sum_{n=1}^{N} (x_n - y_n)^2 L=n=1Nln=n=1N(xnyn)2
    如果 reduction 参数为 'none',则返回每个样本的损失 l n l_n ln 组成的张量:
    L = [ l 1 , l 2 , … , l N ] = [ ( x 1 − y 1 ) 2 , ( x 2 − y 2 ) 2 , … , ( x N − y N ) 2 ] \mathcal{L} = [l_1, l_2, \ldots, l_N] = [(x_1 - y_1)^2, (x_2 - y_2)^2, \ldots, (x_N - y_N)^2] L=[l1,l2,,lN]=[(x1y1)2,(x2y2)2,,(xNyN)2]

元素版本

假设输入张量 x \mathbf{x} x 和目标张量 y \mathbf{y} y 具有相同的形状,每个张量包含 N N N 个元素。均方误差损失的计算步骤如下:

  1. 单个元素的损失
    计算每个元素的均方误差:
    l i j = ( x i j − y i j ) 2 l_{ij} = (x_{ij} - y_{ij})^2 lij=(xijyij)2
    其中 l i j l_{ij} lij 是输入张量和目标张量在位置 ( i , j ) (i, j) (i,j) 的元素损失。
  2. 总损失
    计算所有元素的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ i , j l i j = 1 N ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \frac{1}{N} \sum_{i,j} l_{ij} = \frac{1}{N} \sum_{i,j} (x_{ij} - y_{ij})^2 L=N1i,jlij=N1i,j(xijyij)2
    如果 reduction 参数为 'sum',总损失为所有元素损失的和:
    L = ∑ i , j l i j = ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \sum_{i,j} l_{ij} = \sum_{i,j} (x_{ij} - y_{ij})^2 L=i,jlij=i,j(xijyij)2
    如果 reduction 参数为 'none',则返回每个元素的损失 l i j l_{ij} lij 组成的张量:
    L = { l i j } = { ( x i j − y i j ) 2 } \mathcal{L} = \{l_{ij}\} = \{(x_{ij} - y_{ij})^2 \} L={lij}={(xijyij)2}

要点

  1. nn.MSELoss() 接受的输入和目标应具有相同的形状和类型。
    使用示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失
    criterion = nn.MSELoss()
    loss = criterion(input, target)print(f"Loss using nn.MSELoss: {loss.item()}")
    
    >>> Loss using nn.MSELoss: 0.25
    
  2. nn.MSELoss()reduction 参数指定了如何归约输出损失。默认值是 'mean',计算的是所有样本的平均损失。
    • 如果 reduction 参数为 'mean',损失是所有样本损失的平均值。
    • 如果 reduction 参数为 'sum',损失是所有样本损失的和。
    • 如果 reduction 参数为 'none',则返回每个样本的损失组成的张量。
      代码示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失(reduction='mean')
    criterion_mean = nn.MSELoss(reduction='mean')
    loss_mean = criterion_mean(input, target)
    print(f"Loss with reduction='mean': {loss_mean.item()}")# 使用 nn.MSELoss 计算损失(reduction='sum')
    criterion_sum = nn.MSELoss(reduction='sum')
    loss_sum = criterion_sum(input, target)
    print(f"Loss with reduction='sum': {loss_sum.item()}")# 使用 nn.MSELoss 计算损失(reduction='none')
    criterion_none = nn.MSELoss(reduction='none')
    loss_none = criterion_none(input, target)
    print(f"Loss with reduction='none': {loss_none}")
    
    >>> Loss with reduction='mean': 0.25
    >>> Loss with reduction='sum': 1.0
    >>> Loss with reduction='none': tensor([[0.2500, 0.2500],[0.2500, 0.2500]], grad_fn=<MseLossBackward0>)
    

附录

用于验证数学公式和函数实际运行的一致性

import torch
import torch.nn.functional as F# 假设有两个样本,每个样本有两个维度
input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 根据公式实现均方误差损失
def mse_loss(input, target):return ((input - target) ** 2).mean()# 使用 nn.MSELoss 计算损失
criterion = torch.nn.MSELoss(reduction='mean')
loss_torch = criterion(input, target)# 使用根据公式实现的均方误差损失
loss_custom = mse_loss(input, target)# 打印结果
print("PyTorch 计算的均方误差损失:", loss_torch.item())
print("根据公式实现的均方误差损失:", loss_custom.item())# 验证结果是否相等
assert torch.isclose(loss_torch, loss_custom), "数学公式验证失败"
>>> PyTorch 计算的均方误差损失: 0.25
>>> 根据公式实现的均方误差损失: 0.25

输出没有抛出 AssertionError,验证通过。

参考链接

MSELoss - Docs

这篇关于PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1088849

相关文章

JAVA系统中Spring Boot应用程序的配置文件application.yml使用详解

《JAVA系统中SpringBoot应用程序的配置文件application.yml使用详解》:本文主要介绍JAVA系统中SpringBoot应用程序的配置文件application.yml的... 目录文件路径文件内容解释1. Server 配置2. Spring 配置3. Logging 配置4. Ma

mac中资源库在哪? macOS资源库文件夹详解

《mac中资源库在哪?macOS资源库文件夹详解》经常使用Mac电脑的用户会发现,找不到Mac电脑的资源库,我们怎么打开资源库并使用呢?下面我们就来看看macOS资源库文件夹详解... 在 MACOS 系统中,「资源库」文件夹是用来存放操作系统和 App 设置的核心位置。虽然平时我们很少直接跟它打交道,但了

关于Maven中pom.xml文件配置详解

《关于Maven中pom.xml文件配置详解》pom.xml是Maven项目的核心配置文件,它描述了项目的结构、依赖关系、构建配置等信息,通过合理配置pom.xml,可以提高项目的可维护性和构建效率... 目录1. POM文件的基本结构1.1 项目基本信息2. 项目属性2.1 引用属性3. 项目依赖4. 构

Rust 数据类型详解

《Rust数据类型详解》本文介绍了Rust编程语言中的标量类型和复合类型,标量类型包括整数、浮点数、布尔和字符,而复合类型则包括元组和数组,标量类型用于表示单个值,具有不同的表示和范围,本文介绍的非... 目录一、标量类型(Scalar Types)1. 整数类型(Integer Types)1.1 整数字

Java操作ElasticSearch的实例详解

《Java操作ElasticSearch的实例详解》Elasticsearch是一个分布式的搜索和分析引擎,广泛用于全文搜索、日志分析等场景,本文将介绍如何在Java应用中使用Elastics... 目录简介环境准备1. 安装 Elasticsearch2. 添加依赖连接 Elasticsearch1. 创

Redis缓存问题与缓存更新机制详解

《Redis缓存问题与缓存更新机制详解》本文主要介绍了缓存问题及其解决方案,包括缓存穿透、缓存击穿、缓存雪崩等问题的成因以及相应的预防和解决方法,同时,还详细探讨了缓存更新机制,包括不同情况下的缓存更... 目录一、缓存问题1.1 缓存穿透1.1.1 问题来源1.1.2 解决方案1.2 缓存击穿1.2.1

PyTorch使用教程之Tensor包详解

《PyTorch使用教程之Tensor包详解》这篇文章介绍了PyTorch中的张量(Tensor)数据结构,包括张量的数据类型、初始化、常用操作、属性等,张量是PyTorch框架中的核心数据结构,支持... 目录1、张量Tensor2、数据类型3、初始化(构造张量)4、常用操作5、常用属性5.1 存储(st

Python 中 requests 与 aiohttp 在实际项目中的选择策略详解

《Python中requests与aiohttp在实际项目中的选择策略详解》本文主要介绍了Python爬虫开发中常用的两个库requests和aiohttp的使用方法及其区别,通过实际项目案... 目录一、requests 库二、aiohttp 库三、requests 和 aiohttp 的比较四、requ

VUE动态绑定class类的三种常用方式及适用场景详解

《VUE动态绑定class类的三种常用方式及适用场景详解》文章介绍了在实际开发中动态绑定class的三种常见情况及其解决方案,包括根据不同的返回值渲染不同的class样式、给模块添加基础样式以及根据设... 目录前言1.动态选择class样式(对象添加:情景一)2.动态添加一个class样式(字符串添加:情

Python在固定文件夹批量创建固定后缀的文件(方法详解)

《Python在固定文件夹批量创建固定后缀的文件(方法详解)》文章讲述了如何使用Python批量创建后缀为.md的文件夹,生成100个,代码中需要修改的路径、前缀和后缀名,并提供了注意事项和代码示例,... 目录1. python需求的任务2. Python代码的实现3. 代码修改的位置4. 运行结果5.