回归算法详解

2024-06-20 09:12
文章标签 详解 算法 回归

本文主要是介绍回归算法详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

回归算法详解

回归分析是一类重要的机器学习方法,主要用于预测连续变量。本文将详细讲解几种常见的回归算法,包括线性回归、岭回归、Lasso 回归、弹性网络回归、决策树回归和支持向量回归(SVR),并展示它们的特点、应用场景及其在 Python 中的实现。

一 什么是回归分析?

回归分析是一种统计方法,用于确定因变量(目标变量)和自变量(预测变量)之间的关系。回归分析的目标是建立一个模型,通过自变量预测因变量。

二 常见回归算法

1. 线性回归

线性回归是最基本的回归方法,假设因变量和自变量之间存在线性关系。线性回归的目标是找到一条直线,使得所有数据点到该直线的距离之和最小。
线性回归模型的方程为:
y = β 0 + β 1 x 1 + β 2 x 2 + … + β n x n + ϵ \ y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \ldots + \beta_n x_n + \epsilon \  y=β0+β1x1+β2x2++βnxn+ϵ 
其中, β 0 \ \beta_0  β0 是截距, β i \ \beta_i  βi 是自变量 x i \ x_i  xi 的回归系数, ϵ \ \epsilon  ϵ 是误差项。

损失函数

最小化均方误差(Mean Squared Error, MSE):
MSE = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \ \text{MSE} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 \  MSE=n1i=1n(yiy^i)2 

特点
  • 简单易懂:线性回归是最简单的回归模型,易于解释和实现。
  • 计算速度快:适用于大规模数据集。
  • 易于扩展:可以通过添加多项式项、交互项等扩展为更复杂的模型。
应用场景

线性回归适用于因变量和自变量之间存在线性关系的场景,例如经济学中的供求关系、工程中的温度与压力关系等。

2. 岭回归

岭回归(Ridge Regression)是一种线性回归的变种,通过在损失函数中加入 L 2 \ L2  L2 正则化项来防止过拟合。

岭回归的损失函数为:
Loss = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 + λ ∑ j = 1 n β j 2 \ \text{Loss} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 + \lambda \sum_{j=1}^n \beta_j^2 \  Loss=n1i=1n(yiy^i)2+λj=1nβj2 
其中, λ \lambda λ是正则化参数。

特点
  • 防止过拟合:通过正则化项控制模型的复杂度,防止过拟合。
  • 处理共线性:适用于自变量之间存在较强相关性的情况。
  • 参数选择:需要调优正则化参数 λ \lambda λ
应用场景

岭回归适用于高维数据集和自变量之间存在共线性的场景,如基因表达数据分析、文本数据分类等。

3. Lasso 回归

Lasso 回归(Least Absolute Shrinkage and Selection Operator Regression)通过加入 L 1 L1 L1 正则化项来防止过拟合,并能够进行特征选择。

Lasso 回归的损失函数为:
Loss = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 + λ ∑ j = 1 n ∣ β j ∣ \ \text{Loss} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 + \lambda \sum_{j=1}^n |\beta_j| \  Loss=n1i=1n(yiy^i)2+λj=1nβj 

特点
  • 特征选择:通过 L 1 L1 L1 正则化,将不重要的特征系数收缩为零,从而实现特征选择。
  • 防止过拟合:与岭回归类似,Lasso 回归也能够控制模型的复杂度。
  • 参数选择:需要调优正则化参数 λ \lambda λ
应用场景

Lasso 回归适用于高维数据和特征选择的场景,如基因数据分析、文本分类、图像处理等。

4. 弹性网络回归

弹性网络回归(Elastic Net Regression)结合了岭回归和 Lasso 回归的正则化项,能够同时进行特征选择和防止过拟合。

弹性网络回归的损失函数为:
Loss = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 + λ 1 ∑ j = 1 n ∣ β j ∣ + λ 2 ∑ j = 1 n β j 2 \ \text{Loss} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 + \lambda_1 \sum_{j=1}^n |\beta_j| + \lambda_2 \sum_{j=1}^n \beta_j^2 \  Loss=n1i=1n(yiy^i)2+λ1j=1nβj+λ2j=1nβj2 

特点
  • 特征选择与防止过拟合:结合了 Lasso 和岭回归的优点,能够进行特征选择并防止过拟合。
  • 适用于高维数据:在高维数据集上表现良好。
  • 参数选择:需要调优两个正则化参数 λ 1 \lambda_1 λ1 λ 2 \lambda_2 λ2
应用场景

弹性网络回归适用于高维数据和特征选择的场景,尤其是当自变量之间存在高度相关性时。

5. 决策树回归

决策树回归通过构建决策树来进行回归,能够捕捉非线性关系。

算法步骤
  1. 将数据集划分为若干子集。
  2. 对每个子集,选择一个特征及其取值进行划分,使得划分后的均方误差最小。
  3. 递归地对每个子集进行上述划分,直到满足停止条件。
特点
  • 捕捉非线性关系:决策树能够处理非线性和交互效应。
  • 易于解释:决策树的结构直观易懂,便于解释。
  • 易受过拟合影响:需要剪枝等技术来防止过拟合。
应用场景

决策树回归适用于数据集特征和目标变量之间存在非线性关系的场景,如市场预测、医学诊断等。

6. 支持向量回归(SVR)

支持向量回归(Support Vector Regression, SVR)是支持向量机的扩展,用于回归分析。SVR 寻找一个函数,使得大多数数据点都在一个容忍范围内。

SVR 的优化目标为:
min ⁡ w , b , ξ , ξ ∗ 1 2 ∥ w ∥ 2 + C ∑ i = 1 n ( ξ i + ξ i ∗ ) \ \min_{w,b,\xi,\xi^*} \frac{1}{2} \|w\|^2 + C \sum_{i=1}^n (\xi_i + \xi_i^*) \  w,b,ξ,ξmin21w2+Ci=1n(ξi+ξi) 
约束条件为:
y i − ( w ⋅ x i + b ) ≤ ϵ + ξ i , ( w ⋅ x i + b ) − y i ≤ ϵ + ξ i ∗ ; ξ i , ξ i ∗ ≥ 0 \ y_i - (w \cdot x_i + b) \leq \epsilon + \xi_i \ , \ (w \cdot x_i + b) - y_i \leq \epsilon + \xi_i^* \ ; \ \xi_i, \xi_i^* \geq 0  yi(wxi+b)ϵ+ξi , (wxi+b)yiϵ+ξi ; ξi,ξi0
其中, ϵ \ \epsilon  ϵ 是容忍范围, ξ i \ \xi_i  ξi ξ i ∗ \ \xi_i^*  ξi是松弛变量, C C C 是惩罚参数。

特点
  • 处理非线性关系:通过核方法,SVR 能够处理复杂的非线性关系。
  • 鲁棒性强:对噪声和异常值具有较强的鲁棒性。
  • 参数选择复杂:需要调优核函数和惩罚参数 C C C
应用场景

SVR 适用于处理复杂非线性数据的场景,如股票价格预测、能源消耗预测等。

三 回归算法的 Python 实现

下面通过 Python 代码实现上述回归算法,并以一个示例数据集展示其应用。

导入库

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.tree import DecisionTreeRegressor
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_errorplt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

生成示例数据集

# 生成示例数据集
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)plt.scatter(X, y, color='blue')
plt.title('Sample Data')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

在这里插入图片描述

线性回归

# 线性回归
linear_reg = LinearRegression()
linear_reg.fit(X_train, y_train)
y_pred = linear_reg.predict(X_test)print('Linear Regression MSE:', mean_squared_error(y_test, y_pred))plt.scatter(X_test, y_test, color='blue')
plt.plot(X_test, y_pred, color='red')
plt.title('Linear Regression')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

在这里插入图片描述

岭回归

# 岭回归
ridge_reg = Ridge(alpha=1.0)
ridge_reg.fit(X_train, y_train)
y_pred = ridge_reg.predict(X_test)print('Ridge Regression MSE:', mean_squared_error(y_test, y_pred))plt.scatter(X_test, y_test, color='blue')
plt.plot(X_test, y_pred, color='red')
plt.title('Ridge Regression')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

在这里插入图片描述

Lasso 回归

# Lasso 回归
lasso_reg = Lasso(alpha=0.1)
lasso_reg.fit(X_train, y_train)
y_pred = lasso_reg.predict(X_test)print('Lasso Regression MSE:', mean_squared_error(y_test, y_pred))plt.scatter(X_test, y_test, color='blue')
plt.plot(X_test, y_pred, color='red')
plt.title('Lasso Regression')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

在这里插入图片描述

弹性网络回归

# 弹性网络回归
elastic_net_reg = ElasticNet(alpha=0.1, l1_ratio=0.5)
elastic_net_reg.fit(X_train, y_train)
y_pred = elastic_net_reg.predict(X_test)print('Elastic Net Regression MSE:', mean_squared_error(y_test, y_pred))plt.scatter(X_test, y_test, color='blue')
plt.plot(X_test, y_pred, color='red')
plt.title('Elastic Net Regression')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

在这里插入图片描述

决策树回归

# 决策树回归
tree_reg = DecisionTreeRegressor()
tree_reg.fit(X_train, y_train)
y_pred = tree_reg.predict(X_test)print('Decision Tree Regression MSE:', mean_squared_error(y_test, y_pred))plt.scatter(X_test, y_test, color='blue')
plt.scatter(X_test, y_pred, color='red')
plt.title('Decision Tree Regression')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

在这里插入图片描述

支持向量回归(SVR)

# 支持向量回归
svr_reg = SVR(kernel='rbf', C=100, epsilon=0.1)
svr_reg.fit(X_train, y_train)
y_pred = svr_reg.predict(X_test)print('SVR MSE:', mean_squared_error(y_test, y_pred))plt.scatter(X_test, y_test, color='blue')
plt.scatter(X_test, y_pred, color='red')
plt.title('Support Vector Regression')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.show()

在这里插入图片描述

总结

回归分析是机器学习中的一类重要方法,用于预测连续变量。本文介绍了几种常见的回归算法,包括线性回归、岭回归、Lasso 回归、弹性网络回归、决策树回归和支持向量回归,并展示了它们的数学公式、特点、应用场景及其在 Python 中的实现。不同的回归算法适用于不同的应用场景,通过合理选择算法,可以在实际应用中取得良好的预测效果。希望本文能帮助你更好地理解和应用回归算法。

这篇关于回归算法详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1077708

相关文章

十四、观察者模式与访问者模式详解

21.观察者模式 21.1.课程目标 1、 掌握观察者模式和访问者模式的应用场景。 2、 掌握观察者模式在具体业务场景中的应用。 3、 了解访问者模式的双分派。 4、 观察者模式和访问者模式的优、缺点。 21.2.内容定位 1、 有 Swing开发经验的人群更容易理解观察者模式。 2、 访问者模式被称为最复杂的设计模式。 21.3.观察者模式 观 察 者 模 式 ( Obser

【操作系统】信号Signal超详解|捕捉函数

🔥博客主页: 我要成为C++领域大神🎥系列专栏:【C++核心编程】 【计算机网络】 【Linux编程】 【操作系统】 ❤️感谢大家点赞👍收藏⭐评论✍️ 本博客致力于知识分享,与更多的人进行学习交流 ​ 如何触发信号 信号是Linux下的经典技术,一般操作系统利用信号杀死违规进程,典型进程干预手段,信号除了杀死进程外也可以挂起进程 kill -l 查看系统支持的信号

Jitter Injection详解

一、定义与作用 Jitter Injection,即抖动注入,是一种在通信系统中人为地添加抖动的技术。该技术通过在发送端对数据包进行延迟和抖动调整,以实现对整个通信系统的时延和抖动的控制。其主要作用包括: 改善传输质量:通过调整数据包的时延和抖动,可以有效地降低误码率,提高数据传输的可靠性。均衡网络负载:通过对不同的数据流进行不同程度的抖动注入,可以实现网络资源的合理分配,提高整体传输效率。增

代码随想录算法训练营:12/60

非科班学习算法day12 | LeetCode150:逆波兰表达式 ,Leetcode239: 滑动窗口最大值  目录 介绍 一、基础概念补充: 1.c++字符串转为数字 1. std::stoi, std::stol, std::stoll, std::stoul, std::stoull(最常用) 2. std::stringstream 3. std::atoi, std

人工智能机器学习算法总结神经网络算法(前向及反向传播)

1.定义,意义和优缺点 定义: 神经网络算法是一种模仿人类大脑神经元之间连接方式的机器学习算法。通过多层神经元的组合和激活函数的非线性转换,神经网络能够学习数据的特征和模式,实现对复杂数据的建模和预测。(我们可以借助人类的神经元模型来更好的帮助我们理解该算法的本质,不过这里需要说明的是,虽然名字是神经网络,并且结构等等也是借鉴了神经网络,但其原型以及算法本质上还和生物层面的神经网络运行原理存在

Steam邮件推送内容有哪些?配置教程详解!

Steam邮件推送功能是否安全?如何个性化邮件推送内容? Steam作为全球最大的数字游戏分发平台之一,不仅提供了海量的游戏资源,还通过邮件推送为用户提供最新的游戏信息、促销活动和个性化推荐。AokSend将详细介绍Steam邮件推送的主要内容。 Steam邮件推送:促销优惠 每当平台举办大型促销活动,如夏季促销、冬季促销、黑色星期五等,用户都会收到邮件通知。这些邮件详细列出了打折游戏、

探索Elastic Search:强大的开源搜索引擎,详解及使用

🎬 鸽芷咕:个人主页  🔥 个人专栏: 《C++干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 引入 全文搜索属于最常见的需求,开源的 Elasticsearch (以下简称 Elastic)是目前全文搜索引擎的首选,相信大家多多少少的都听说过它。它可以快速地储存、搜索和分析海量数据。就连维基百科、Stack Overflow、

大林 PID 算法

Dahlin PID算法是一种用于控制和调节系统的比例积分延迟算法。以下是一个简单的C语言实现示例: #include <stdio.h>// DALIN PID 结构体定义typedef struct {float SetPoint; // 设定点float Proportion; // 比例float Integral; // 积分float Derivative; // 微分flo

常用MQ消息中间件Kafka、ZeroMQ和RabbitMQ对比及RabbitMQ详解

1、概述   在现代的分布式系统和实时数据处理领域,消息中间件扮演着关键的角色,用于解决应用程序之间的通信和数据传递的挑战。在众多的消息中间件解决方案中,Kafka、ZeroMQ和RabbitMQ 是备受关注和广泛应用的代表性系统。它们各自具有独特的特点和优势,适用于不同的应用场景和需求。   Kafka 是一个高性能、可扩展的分布式消息队列系统,被设计用于处理大规模的数据流和实时数据传输。它

Linux中拷贝 cp命令中拷贝所有的写法详解

This text from: http://www.jb51.net/article/101641.htm 一、预备  cp就是拷贝,最简单的使用方式就是: cp oldfile newfile 但这样只能拷贝文件,不能拷贝目录,所以通常用: cp -r old/ new/ 那就会把old目录整个拷贝到new目录下。注意,不是把old目录里面的文件拷贝到new目录,