深度学习代码|MSE损失的代码实现

2024-04-01 04:12

本文主要是介绍深度学习代码|MSE损失的代码实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、MSE代码手动实现
    • (一)导入相关库
    • (二)计算均方误差损失函数
    • (三)示例使用
  • 二、Pytorch中MSELoss函数的接口
    • (一)参数
    • (二)使用示例
    • (三)反向传播


一、MSE代码手动实现

(一)导入相关库

NumPy 是 Python 语言的一个第三方库,支持大量高维度数组与矩阵运算。此外,NumPy 也针对数组运算提供大量的数学函数。机器学习涉及到大量对数组的变换和运算,NumPy 就成了必不可少的工具之一。

import numpy as np

(二)计算均方误差损失函数

参数:

  • y_true:真实值的数组,可以是一维或多维
  • y_pred:预测值的数组,形状应与y_true相同

返回:

  • loss:计算得到的loss值
def mse_loss(y_true,y_pred):#计算真实值和预测值之间的差异diff=y_true-y_pred#计算差值的平方sq_diff=np.square(diff)#计算均方误差,即平方差的平均值#使用np.mean计算平均值,axis=0表示沿着第一个轴(通常是样本维度)计算loss=np.mean(sq_diff,axis=0)return loss

(三)示例使用

y_true=np.arrray([1,2,3,4])
y_pred=np.array([1.5,2.1,2.9,4.2])loss=mse_loss(y_true,y_pred)
print("MSE Loss:",loss)

二、Pytorch中MSELoss函数的接口

该函数默认用于计算两个输入对应元素差值平方和的均值。具体地,在深度学习中,可以使用该函数用来计算两个特征图的相似性。

torch.nn.MSELoss(size_average=None, reduce=None, reduction=‘mean’)

(一)参数

  • 当reduce=True时,若size_average=True,则返回一个batch中所有样本损失的均值,结果为标量。注意,对于MESLoss函数来说,首先对该batch中的所有样本损失进行逐元素均值操作,然后对得到N个值再进行均值操作即得到返回值(假设批大小为N,即该batch中共有N个样本)
  • 当reduce=True时,若size_average=False,则返回一个batch中所有样本损失的和,结果为标量。注意,对于MESLoss函数来说,首先对该batch中的所有样本损失进行逐元素求和操作,然后对得到N个值再进行求和操作即得到返回值(假设批大小为N,即该batch中共有N个样本)
  • 当reduce=False时,则size_average参数失效,即无论size_average参数为False还是True,效果都是一样的。此时,函数返回的是一个batch中每个样本的损失,结果为向量。
  • reduction参数包含了reduce和size_average参数的双重含义,这也是为什么reduce和size_average参数将在后续版本中被弃用的原因。

(二)使用示例

首先假设有三个数据样本分别经过神经网络运算,得到三个输出与其标签分别是:

y_pre = torch.Tensor([[1, 2, 3],[2, 1, 3],[3, 1, 2]])y_label = torch.Tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1]])

当reduction=‘none’时,相当于reduce=False;

criterion1 = nn.MSELoss(reduction="none")
loss1 = criterion1(x, y)
print(loss1)

输出结果为:

tensor([[0., 4., 9.],
[4., 0., 9.],
[9., 1., 1.]])

当reduction=‘sum’时,相当于reduce=True且size_average=False;

criterion2 = nn.MSELoss(reduction="mean")
loss2 = criterion2(x, y)
print(loss2)

输出结果为:

tensor(4.1111)

当reduction=‘mean’时,相当于reduce=True且size_average=True;

criterion3 = nn.MSELoss(reduction="sum")
loss3 = criterion3(x, y)
print(loss3)

输出结果为:

tensor(37.)

(三)反向传播

一般在反向传播时,都是先求loss,再使用loss.backward()求loss对每个参数 w_ij和b的偏导数(也可以理解为梯度)。但是只有标量才能执行backward()函数,因此在反向传播中reduction不能设为"none"。

  • 若设置为"sum",则有Loss=loss_1+loss_2+loss_3,表示总的Loss由每个实例的loss_i构成,在通过Loss求梯度时,将每个loss_i的梯度也都考虑进去了。
  • 若设置为"mean",则相比"sum"相当于Loss变成了Loss*(1/i),这在参数更新时影响不大,因为有学习率a的存在。

如果只想在batch上做平均,可以这样写:

loss_fn = torch.nn.MSELoss(reduction="sum")
loss = loss_fn(pred, y) / pred.size(0)

参考:
手撕算法面试二,手撕MSE损失
pytorch官网介绍
【PyTorch】MSELoss的详细理解(含源代码)

这篇关于深度学习代码|MSE损失的代码实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/866178

相关文章

MySQL中查找重复值的实现

《MySQL中查找重复值的实现》查找重复值是一项常见需求,比如在数据清理、数据分析、数据质量检查等场景下,我们常常需要找出表中某列或多列的重复值,具有一定的参考价值,感兴趣的可以了解一下... 目录技术背景实现步骤方法一:使用GROUP BY和HAVING子句方法二:仅返回重复值方法三:返回完整记录方法四:

IDEA中新建/切换Git分支的实现步骤

《IDEA中新建/切换Git分支的实现步骤》本文主要介绍了IDEA中新建/切换Git分支的实现步骤,通过菜单创建新分支并选择是否切换,创建后在Git详情或右键Checkout中切换分支,感兴趣的可以了... 前提:项目已被Git托管1、点击上方栏Git->NewBrancjsh...2、输入新的分支的

Python实现对阿里云OSS对象存储的操作详解

《Python实现对阿里云OSS对象存储的操作详解》这篇文章主要为大家详细介绍了Python实现对阿里云OSS对象存储的操作相关知识,包括连接,上传,下载,列举等功能,感兴趣的小伙伴可以了解下... 目录一、直接使用代码二、详细使用1. 环境准备2. 初始化配置3. bucket配置创建4. 文件上传到os

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

java实现docker镜像上传到harbor仓库的方式

《java实现docker镜像上传到harbor仓库的方式》:本文主要介绍java实现docker镜像上传到harbor仓库的方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 前 言2. 编写工具类2.1 引入依赖包2.2 使用当前服务器的docker环境推送镜像2.2

C++20管道运算符的实现示例

《C++20管道运算符的实现示例》本文简要介绍C++20管道运算符的使用与实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录标准库的管道运算符使用自己实现类似的管道运算符我们不打算介绍太多,因为它实际属于c++20最为重要的

Java easyExcel实现导入多sheet的Excel

《JavaeasyExcel实现导入多sheet的Excel》这篇文章主要为大家详细介绍了如何使用JavaeasyExcel实现导入多sheet的Excel,文中的示例代码讲解详细,感兴趣的小伙伴可... 目录1.官网2.Excel样式3.代码1.官网easyExcel官网2.Excel样式3.代码