深度学习代码|MSE损失的代码实现

2024-04-01 04:12

本文主要是介绍深度学习代码|MSE损失的代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、MSE代码手动实现
    • (一)导入相关库
    • (二)计算均方误差损失函数
    • (三)示例使用
  • 二、Pytorch中MSELoss函数的接口
    • (一)参数
    • (二)使用示例
    • (三)反向传播


一、MSE代码手动实现

(一)导入相关库

NumPy 是 Python 语言的一个第三方库,支持大量高维度数组与矩阵运算。此外,NumPy 也针对数组运算提供大量的数学函数。机器学习涉及到大量对数组的变换和运算,NumPy 就成了必不可少的工具之一。

import numpy as np

(二)计算均方误差损失函数

参数:

  • y_true:真实值的数组,可以是一维或多维
  • y_pred:预测值的数组,形状应与y_true相同

返回:

  • loss:计算得到的loss值
def mse_loss(y_true,y_pred):#计算真实值和预测值之间的差异diff=y_true-y_pred#计算差值的平方sq_diff=np.square(diff)#计算均方误差,即平方差的平均值#使用np.mean计算平均值,axis=0表示沿着第一个轴(通常是样本维度)计算loss=np.mean(sq_diff,axis=0)return loss

(三)示例使用

y_true=np.arrray([1,2,3,4])
y_pred=np.array([1.5,2.1,2.9,4.2])loss=mse_loss(y_true,y_pred)
print("MSE Loss:",loss)

二、Pytorch中MSELoss函数的接口

该函数默认用于计算两个输入对应元素差值平方和的均值。具体地,在深度学习中,可以使用该函数用来计算两个特征图的相似性。

torch.nn.MSELoss(size_average=None, reduce=None, reduction=‘mean’)

(一)参数

  • 当reduce=True时,若size_average=True,则返回一个batch中所有样本损失的均值,结果为标量。注意,对于MESLoss函数来说,首先对该batch中的所有样本损失进行逐元素均值操作,然后对得到N个值再进行均值操作即得到返回值(假设批大小为N,即该batch中共有N个样本)
  • 当reduce=True时,若size_average=False,则返回一个batch中所有样本损失的和,结果为标量。注意,对于MESLoss函数来说,首先对该batch中的所有样本损失进行逐元素求和操作,然后对得到N个值再进行求和操作即得到返回值(假设批大小为N,即该batch中共有N个样本)
  • 当reduce=False时,则size_average参数失效,即无论size_average参数为False还是True,效果都是一样的。此时,函数返回的是一个batch中每个样本的损失,结果为向量。
  • reduction参数包含了reduce和size_average参数的双重含义,这也是为什么reduce和size_average参数将在后续版本中被弃用的原因。

(二)使用示例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],[2, 1, 3],[3, 1, 2]])y_label = torch.Tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1]])

当reduction=‘none’时,相当于reduce=False;

criterion1 = nn.MSELoss(reduction="none")
loss1 = criterion1(x, y)
print(loss1)

输出结果为:

tensor([[0., 4., 9.],
[4., 0., 9.],
[9., 1., 1.]])

当reduction=‘sum’时,相当于reduce=True且size_average=False;

criterion2 = nn.MSELoss(reduction="mean")
loss2 = criterion2(x, y)
print(loss2)

输出结果为:

tensor(4.1111)

当reduction=‘mean’时,相当于reduce=True且size_average=True;

criterion3 = nn.MSELoss(reduction="sum")
loss3 = criterion3(x, y)
print(loss3)

输出结果为:

tensor(37.)

(三)反向传播

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。但是只有标量才能执行backward()函数,因此在反向传播中reduction不能设为"none"。

  • 若设置为"sum",则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。
  • 若设置为"mean",则相比"sum"相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, y) / pred.size(0)

参考:
手撕算法面试二,手撕MSE损失
pytorch官网介绍
【PyTorch】MSELoss的详细理解(含源代码)

这篇关于深度学习代码|MSE损失的代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/866178

相关文章

Oracle查询优化之高效实现仅查询前10条记录的方法与实践

《Oracle查询优化之高效实现仅查询前10条记录的方法与实践》:本文主要介绍Oracle查询优化之高效实现仅查询前10条记录的相关资料,包括使用ROWNUM、ROW_NUMBER()函数、FET... 目录1. 使用 ROWNUM 查询2. 使用 ROW_NUMBER() 函数3. 使用 FETCH FI

Python脚本实现自动删除C盘临时文件夹

《Python脚本实现自动删除C盘临时文件夹》在日常使用电脑的过程中,临时文件夹往往会积累大量的无用数据,占用宝贵的磁盘空间,下面我们就来看看Python如何通过脚本实现自动删除C盘临时文件夹吧... 目录一、准备工作二、python脚本编写三、脚本解析四、运行脚本五、案例演示六、注意事项七、总结在日常使用

Java实现Excel与HTML互转

《Java实现Excel与HTML互转》Excel是一种电子表格格式,而HTM则是一种用于创建网页的标记语言,虽然两者在用途上存在差异,但有时我们需要将数据从一种格式转换为另一种格式,下面我们就来看看... Excel是一种电子表格格式,广泛用于数据处理和分析,而HTM则是一种用于创建网页的标记语言。虽然两

Java中Springboot集成Kafka实现消息发送和接收功能

《Java中Springboot集成Kafka实现消息发送和接收功能》Kafka是一个高吞吐量的分布式发布-订阅消息系统,主要用于处理大规模数据流,它由生产者、消费者、主题、分区和代理等组件构成,Ka... 目录一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者一、Kaf

使用Python实现在Word中添加或删除超链接

《使用Python实现在Word中添加或删除超链接》在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能,本文将为大家介绍一下Python如何实现在Word中添加或... 在Word文档中,超链接是一种将文本或图像连接到其他文档、网页或同一文档中不同部分的功能。通过添加超

windos server2022里的DFS配置的实现

《windosserver2022里的DFS配置的实现》DFS是WindowsServer操作系统提供的一种功能,用于在多台服务器上集中管理共享文件夹和文件的分布式存储解决方案,本文就来介绍一下wi... 目录什么是DFS?优势:应用场景:DFS配置步骤什么是DFS?DFS指的是分布式文件系统(Distr

NFS实现多服务器文件的共享的方法步骤

《NFS实现多服务器文件的共享的方法步骤》NFS允许网络中的计算机之间共享资源,客户端可以透明地读写远端NFS服务器上的文件,本文就来介绍一下NFS实现多服务器文件的共享的方法步骤,感兴趣的可以了解一... 目录一、简介二、部署1、准备1、服务端和客户端:安装nfs-utils2、服务端:创建共享目录3、服

C#使用yield关键字实现提升迭代性能与效率

《C#使用yield关键字实现提升迭代性能与效率》yield关键字在C#中简化了数据迭代的方式,实现了按需生成数据,自动维护迭代状态,本文主要来聊聊如何使用yield关键字实现提升迭代性能与效率,感兴... 目录前言传统迭代和yield迭代方式对比yield延迟加载按需获取数据yield break显式示迭

Python实现高效地读写大型文件

《Python实现高效地读写大型文件》Python如何读写的是大型文件,有没有什么方法来提高效率呢,这篇文章就来和大家聊聊如何在Python中高效地读写大型文件,需要的可以了解下... 目录一、逐行读取大型文件二、分块读取大型文件三、使用 mmap 模块进行内存映射文件操作(适用于大文件)四、使用 pand

python实现pdf转word和excel的示例代码

《python实现pdf转word和excel的示例代码》本文主要介绍了python实现pdf转word和excel的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价... 目录一、引言二、python编程1,PDF转Word2,PDF转Excel三、前端页面效果展示总结一