深度学习常见概念解释(四)——损失函数定义,作用与种类(附公式和代码)

本文主要是介绍深度学习常见概念解释(四)——损失函数定义,作用与种类(附公式和代码),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

损失函数

  • 前言
  • 定义
  • 作用
  • 种类
    • 1. 均方误差损失(Mean Squared Error Loss,MSE)
      • 公式
      • 特点和优点
      • 缺点
      • 使用场景
      • 示例代码
      • 在机器学习框架中的使用
      • 总结
    • 2. 交叉熵损失(Cross-Entropy Loss)
      • 公式
      • 特点和优点
      • 使用场景
      • 示例代码
      • 在机器学习框架中的使用
      • 总结
  • 总结

前言

在机器学习和深度学习中,损失函数(Loss Function)起着至关重要的作用。它是模型优化过程中不可或缺的一部分,用于衡量模型预测值与真实值之间的差异。选择合适的损失函数不仅可以帮助模型更好地拟合数据,还能反映任务的特性,提高模型的性能和鲁棒性。本文将详细介绍损失函数的定义、作用及常见种类,并通过具体的示例代码展示如何在实际应用中使用这些损失函数。

定义

损失函数(loss function)是在机器学习和深度学习中用来衡量模型预测值与真实值之间差异的函数。它通常表示为一个标量值,用来评估模型在训练数据上的表现。

作用

  1. 衡量预测值与真实值之间的差异: 损失函数衡量了模型在给定数据上的表现,即模型对于输入数据的预测与实际标签之间的差异程度。通过最小化损失函数,模型可以更好地拟合训练数据,提高预测的准确性。

  2. 指导模型优化: 在训练过程中,损失函数是优化算法的目标函数,模型的参数通过最小化损失函数来调整,使得模型能够更好地拟合训练数据。常见的优化算法包括梯度下降(Gradient Descent)及其变种,它们通过计算损失函数的梯度来更新模型参数。

  3. 反映任务的特性: 不同任务和模型需要选择不同的损失函数。例如,分类任务常用的损失函数包括交叉熵损失(Cross-Entropy Loss),回归任务常用的损失函数包括均方误差损失(Mean Squared Error Loss)。选择合适的损失函数能够更好地反映任务的特性,有助于提高模型的性能。

  4. 处理不平衡数据: 在某些情况下,数据可能存在类别不平衡或者噪声,选择合适的损失函数可以帮助模型更好地处理这些情况,提高模型的鲁棒性。

总的来说,损失函数在机器学习和深度学习中扮演着至关重要的角色,它不仅指导模型的训练过程,还反映了模型对于任务的表现和适应能力。

种类

在机器学习和深度学习中,常见的损失函数包括以下几种:

1. 均方误差损失(Mean Squared Error Loss,MSE)

均方误差损失(Mean Squared Error Loss,简称 MSE)是一种常用的回归模型损失函数,用于衡量预测值与真实值之间的差异。MSE 的计算方式是将每个预测值与真实值之间的差值平方,然后求这些差值平方的平均值。

公式

MSE = 1 2 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE} = \frac{1}{2n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE=2n1i=1n(yiy^i)2
其中:

  • n n n 是数据点的数量。
  • y i y_i yi 是第 i i i 个真实值。
  • y ^ i \hat{y}_i y^i 是第 i i i 个预测值。

特点和优点

  1. 平滑性:MSE 损失函数是连续和可微的,这使得它非常适合用于梯度下降等优化算法。
  2. 凸性:MSE 是一个凸函数,这意味着在大多数情况下,它只有一个全局最小值,这对优化问题非常重要。
  3. 简单性:MSE 的公式简单,计算方便,容易实现。

缺点

  1. 对异常值敏感:由于误差被平方,MSE 对异常值(outliers)特别敏感。如果数据集中存在极端值,这些值会对整体误差有很大影响,导致模型不稳定。
  2. 不适用于分类问题:MSE 主要用于回归问题,对于分类问题,通常使用交叉熵损失等其他损失函数。

使用场景

MSE 广泛用于各种回归问题中,例如:

  • 预测房价
  • 股票价格预测
  • 气温预测
  • 机器学习模型中的损失计算

示例代码

import numpy as np# 定义真实值和预测值
y_true = np.array([1.5, 2.0, 3.5, 4.0, 5.5])
y_pred = np.array([1.4, 2.1, 3.6, 3.9, 5.8])# 计算均方误差
mse = np.mean((y_true - y_pred) ** 2)
print(f"Mean Squared Error: {mse}")

在机器学习框架中的使用

在流行的机器学习框架中,如 TensorFlow 和 PyTorch,均方误差损失通常作为内置函数提供,使用非常方便。

import torch
import torch.nn as nn# 定义真实值和预测值
y_true = torch.tensor([1.5, 2.0, 3.5, 4.0, 5.5])
y_pred = torch.tensor([1.4, 2.1, 3.6, 3.9, 5.8])# 定义 MSE 损失函数
mse_loss = nn.MSELoss()# 计算损失
loss = mse_loss(y_pred, y_true)
print(f"Mean Squared Error Loss: {loss.item()}")

总结

均方误差损失(MSE)是衡量回归模型性能的一种标准方法,通过计算预测值与真实值之间的平方误差平均值来评估模型的准确性。尽管它对异常值敏感,但其简单性和计算效率使其在各种回归任务中广泛应用。

2. 交叉熵损失(Cross-Entropy Loss)

交叉熵损失(Cross-Entropy Loss)是一种常用于分类任务中的损失函数,特别适用于多类别分类问题。交叉熵损失用于衡量预测的概率分布与真实分布之间的差异。它通过计算真实标签和预测概率之间的不确定性来衡量模型的性能。

公式

  1. 对于二分类问题,二分类交叉熵损失(Binary Cross-Entropy Loss, BCE)的公式如下:
    CE = − ( y log ⁡ ( p ) + ( 1 − y ) log ⁡ ( 1 − p ) ) \text{CE} = - \left( y \log(p) + (1 - y) \log(1 - p) \right) CE=(ylog(p)+(1y)log(1p))
    其中:

    • y y y 是真实标签,取值为 0 或 1。
    • p p p 是预测为类别 1 的概率。
  2. 对于多分类问题,多分类交叉熵损失(Categorical Cross-Entropy Loss, CCE)的公式为:
    CE = − ∑ i = 1 n y i log ⁡ ( p i ) \text{CE} = - \sum_{i=1}^{n} y_i \log(p_i) CE=i=1nyilog(pi)
    其中:

    • n n n 是类别的数量。
    • y i y_i yi 是真实标签,如果样本属于第 i i i类,则 y i = 1 y_i = 1 yi=1 ,否则 y i = 0 y_i = 0 yi=0
    • p i p_i pi 是模型预测样本属于第 i i i类的概率。

PS.:二分类交叉熵损失(Binary Cross-Entropy Loss)也被称为对数损失(Log Loss)。
PPS. 注意在正式计算的时候需要把所有的误差值加起来取平均值(具体步骤见下面的示例代码)。

特点和优点

  1. 概率输出:交叉熵损失函数使用预测的概率分布,这使得它特别适用于分类问题。
  2. 敏感性:它对错误分类的惩罚较大,尤其是在预测概率较高但实际类别不匹配的情况下。
  3. 凸性:交叉熵损失通常是凸的,这有助于优化算法找到全局最优解。

使用场景

交叉熵损失广泛用于各种分类问题中,例如:

  • 图像分类
  • 文本分类
  • 语音识别
  • 机器翻译

示例代码

import numpy as np# 二分类问题
def binary_cross_entropy(y_true, y_pred):y_true = np.array(y_true)y_pred = np.array(y_pred)return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))# 示例数据
y_true = [1, 0, 1, 1, 0]
y_pred = [0.9, 0.1, 0.8, 0.7, 0.2]# 计算二分类交叉熵损失
loss = binary_cross_entropy(y_true, y_pred)
print(f"Binary Cross-Entropy Loss: {loss}")# 多分类问题
def categorical_cross_entropy(y_true, y_pred):y_true = np.array(y_true)y_pred = np.array(y_pred)return -np.sum(y_true * np.log(y_pred)) / y_true.shape[0]# 示例数据
y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
y_pred = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.2, 0.6]]# 计算多分类交叉熵损失
loss = categorical_cross_entropy(y_true, y_pred)
print(f"Categorical Cross-Entropy Loss: {loss}")

在机器学习框架中的使用

在流行的机器学习框架中,如 TensorFlow 和 PyTorch,交叉熵损失通常作为内置函数提供,使用非常方便。

import torch
import torch.nn as nn# 定义真实标签和预测概率
y_true = torch.tensor([2, 0, 1])
y_pred = torch.tensor([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1], [0.2, 0.6, 0.2]])# 定义交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 计算损失
loss = criterion(y_pred, y_true)
print(f"Cross-Entropy Loss: {loss.item()}")

总结

交叉熵损失(Cross-Entropy Loss)是分类问题中常用的损失函数,通过衡量预测的概率分布与真实分布之间的差异来评估模型性能。它对错误分类的惩罚较大,并且使用概率输出,非常适合分类任务。流行的深度学习框架通常提供了内置的交叉熵损失函数,方便用户使用。

总结

损失函数在机器学习和深度学习中扮演着至关重要的角色。它不仅指导模型的训练过程,还反映了模型对于任务的表现和适应能力。选择合适的损失函数是模型优化的重要一步,能够显著提高模型的性能和鲁棒性。希望通过本文的介绍,读者能够对损失函数有一个全面的了解,并在实际项目中选择和应用合适的损失函数,这对于模型的训练和性能至关重要。

这篇关于深度学习常见概念解释(四)——损失函数定义,作用与种类(附公式和代码)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1088229

相关文章

嵌入式软件常见的笔试题(c)

找工作的事情告一段落,现在把一些公司常见的笔试题型整理一下,本人主要是找嵌入式软件方面的工作,笔试的也主要是C语言、数据结构,大体上都比较基础,但是得早作准备,才会占得先机。   1:整型数求反 2:字符串求反,字符串加密,越界问题 3:字符串逆序,两端对调;字符串逆序,指针法 4:递归求n! 5:不用库函数,比较两个字符串的大小 6:求0-3000中含有9和2的全部数之和 7

51单片机学习记录———定时器

文章目录 前言一、定时器介绍二、STC89C52定时器资源三、定时器框图四、定时器模式五、定时器相关寄存器六、定时器练习 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、定时器介绍 定时器介绍:51单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 定时器作用: 1.用于计数系统,可

问题:第一次世界大战的起止时间是 #其他#学习方法#微信

问题:第一次世界大战的起止时间是 A.1913 ~1918 年 B.1913 ~1918 年 C.1914 ~1918 年 D.1914 ~1919 年 参考答案如图所示

[word] word设置上标快捷键 #学习方法#其他#媒体

word设置上标快捷键 办公中,少不了使用word,这个是大家必备的软件,今天给大家分享word设置上标快捷键,希望在办公中能帮到您! 1、添加上标 在录入一些公式,或者是化学产品时,需要添加上标内容,按下快捷键Ctrl+shift++就能将需要的内容设置为上标符号。 word设置上标快捷键的方法就是以上内容了,需要的小伙伴都可以试一试呢!

AssetBundle学习笔记

AssetBundle是unity自定义的资源格式,通过调用引擎的资源打包接口对资源进行打包成.assetbundle格式的资源包。本文介绍了AssetBundle的生成,使用,加载,卸载以及Unity资源更新的一个基本步骤。 目录 1.定义: 2.AssetBundle的生成: 1)设置AssetBundle包的属性——通过编辑器界面 补充:分组策略 2)调用引擎接口API

Javascript高级程序设计(第四版)--学习记录之变量、内存

原始值与引用值 原始值:简单的数据即基础数据类型,按值访问。 引用值:由多个值构成的对象即复杂数据类型,按引用访问。 动态属性 对于引用值而言,可以随时添加、修改和删除其属性和方法。 let person = new Object();person.name = 'Jason';person.age = 42;console.log(person.name,person.age);//'J

大学湖北中医药大学法医学试题及答案,分享几个实用搜题和学习工具 #微信#学习方法#职场发展

今天分享拥有拍照搜题、文字搜题、语音搜题、多重搜题等搜题模式,可以快速查找问题解析,加深对题目答案的理解。 1.快练题 这是一个网站 找题的网站海量题库,在线搜题,快速刷题~为您提供百万优质题库,直接搜索题库名称,支持多种刷题模式:顺序练习、语音听题、本地搜题、顺序阅读、模拟考试、组卷考试、赶快下载吧! 2.彩虹搜题 这是个老公众号了 支持手写输入,截图搜题,详细步骤,解题必备

uniapp接入微信小程序原生代码配置方案(优化版)

uniapp项目需要把微信小程序原生语法的功能代码嵌套过来,无需把原生代码转换为uniapp,可以配置拷贝的方式集成过来 1、拷贝代码包到src目录 2、vue.config.js中配置原生代码包直接拷贝到编译目录中 3、pages.json中配置分包目录,原生入口组件的路径 4、manifest.json中配置分包,使用原生组件 5、需要把原生代码包里的页面修改成组件的方

公共筛选组件(二次封装antd)支持代码提示

如果项目是基于antd组件库为基础搭建,可使用此公共筛选组件 使用到的库 npm i antdnpm i lodash-esnpm i @types/lodash-es -D /components/CommonSearch index.tsx import React from 'react';import { Button, Card, Form } from 'antd'

《offer来了》第二章学习笔记

1.集合 Java四种集合:List、Queue、Set和Map 1.1.List:可重复 有序的Collection ArrayList: 基于数组实现,增删慢,查询快,线程不安全 Vector: 基于数组实现,增删慢,查询快,线程安全 LinkedList: 基于双向链实现,增删快,查询慢,线程不安全 1.2.Queue:队列 ArrayBlockingQueue: