机器学习模型五花八门不知道怎么选?这份指南告诉你

2024-04-13 22:48

本文主要是介绍机器学习模型五花八门不知道怎么选?这份指南告诉你,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!


来源:授权自AI科技大本营(ID:rgznai100)

本文约4900字,建议阅读10分钟。

本文我们将探讨不同的机器学习模型,以及每个模型合理的使用场景。 



[ 导读 ] 一般来说,基于树形结构的模型在Kaggle竞赛中是表现最好的,而其它的模型可以用于融合模型。对于计算机视觉领域的挑战,CNNs (Convolutional Neural Network, 卷积神经网络)是最适合不过的。而对于NLP(Natural Language Processing,自然语言处理),LSTMs或GRUs是最好的选择。下面是一个不完全模型细目清单,同时列出了每个模型的一些优缺点。

1. 回归 — 预测连续值

A. 线性回归(Linear Regression)

 

  • I.Vanilla Linear Regressio

  • 优点

  • 善于获取数据集中的线性关系;

  • 适用于在已有了一些预先定义好的变量并且需要一个简单的预测模型的情况下使用;

  • 训练速度和预测速度较快;

  • 在小数据集上表现很好;

  • 结果可解释,并且易于说明;

  • 当新增数据时,易于更新模型;

  • 不需要进行参数调整(下面的正则化线性模型需要调整正则化参数);

  • 不需要特征缩放(下面的正则化线性模型需要特征缩放);

  • 如果数据集具有冗余的特征,那么线性回归可能是不稳定的;

  • 缺点

  • 不适用于非线性数据;

  • 预测精确度较低;

  • 可能会出现过度拟合(下面的正则化模型可以抵消这个影响);

  • 分离信号和噪声的效果不理想,在使用前需要去掉不相关的特征;

  • 不了解数据集中的特征交互;

  •  II. Lasso回归, Ridge回归, Elastic-Net回归

  • 优点

  • 这些模型是正则化的线性回归模型;

  • 有助于防止过度拟合;

  • 这些模型更善于正则化,因为它们更简单;

  • 适用于当我们只关心几个特征的时候;

  • 缺点

  • 需要特征缩放;

  • 需要调整正则化参数;

B. 回归树(Regression Trees)

  •   I.决策树(Decision Tree)

  • 优点

  • 训练速度和预测速度较快;

  • 善于获取数据集中的非线性关系;

  • 了解数据集中的特征交互;

  • 善于处理数据集中出现的异常值;

  • 善于在数据集中找到最重要的特征;

  • 不需要特征缩放;

  • 结果可解释,并易于说明;

  • 缺点

  • 预测精确度较低;

  • 需要一些参数的调整;

  • 不适用于小型数据集;

  • 分离信号和噪声的效果不理想;

  • 当新增数据时,不易更新模型;

  • 在实践中很少使用,而是更多地使用集合树;

  • 可能会出现过度拟合(见下面的融合模型);

  • II.融合模型(RandomForest,XGBoost, CatBoost, LightGBM)

  • 优点

  • 多重树结构整理预测;

  • 具有较高的预测精确度,在实践中表现很好;

  • 是Kaggle竞赛中推荐的算法;

  • 善于处理数据集中出现的异常值;

  • 善于在数据集中获取非线性关系;

  • 善于在数据集中找到最重要的特征;

  • 能够分离信号和噪声;

  • 不需要特征缩放;

  • 特别适用于高维度的数据;

  • 缺点

  • 训练速度较慢;

  • 具有较高的预测速度;

  • 结果不易解释或说明;

  • 当新增数据时,不易更新模型;

  • 需要调整参数,但调整较为复杂;

  • 不适用于小型数据集;

C. 深度学习(Deep Learning)

  • 优点

  • 在实践中表现出较高的预测精确度;

  • 可以获取数据中非常复杂的底层模式;

  • 特别适用于大型数据集和高维度数据集;

  • 当新增数据时,易于更新模型;

  • 网络的隐藏层明显减少了对特征工程的需求;

  • 是适用于计算机视觉、机器翻译、情感分析和语音识别任务的最新技术;

  • 缺点

  • 具有非常低的训练速度;

  • 需要消耗巨大的计算资源;

  • 需要特征缩放;

  • 结果不易解释或说明;

  • 需要大量的训练数据,因为它要学习大量的参数;

  • 在非图像、非文本、非语音的任务中优于Boosting算法;

  • 非常灵活,带有许多不同的体系结构构建块,因此需要专业知识来设计体系结构;

D. 基于距离的K近邻算法(K Nearest Neighbors – Distance Based)

  • 优点

  • 训练速度较快;

  • 不需要太多的参数调整;

  • 结果可解释,并易于说明;

  • 适用于小型数据集(小于10万个训练集)

  • 缺点

  • 预测精确度较低;

  • 不适用于小型数据集;

  • 需要选择合适的距离函数;

  • 需要特征缩放;

  • 预测速度随数据集增大而加快;

  • 分离信号和噪声的效果不理想,在使用前需要去掉不相关的特征;

  • 是内存密集型的算法,因为它可以保存所有的观察结果;

  • 不适用于处理高维度的数据;

2. 分类 — 预测一个或多个类别的概率

A. 逻辑回归算法(Logistic Regression)

  • 优点

  • 善于对线性可分离数据进行分类;

  • 具有较高的训练速度和预测速度;

  • 适用于小型数据集;

  • 结果可解释,并易于说明;

  • 当新增数据时,易于更新模型;

  • 在正则化时可以避免过度拟合;

  • 可以同时进行2个类和多个类的分类任务;

  • 不需要参数调整(除非在正则化的时候,我们需要调整正则化参数);

  • 不需要特征缩放(正则化的时候除外);

  • 如果数据集具有冗余特征,则线性回归可能是不稳定的;

  • 缺点

  • 不适用于非线性可分离数据;

  • 具有较低的预测精确度;

  • 可能会出现过度拟合(见下面的正则化模型)

  • 分离信号和噪声的效果不理想,在使用前需要去掉不相关的特征;

  • 不了解数据集中的特征交互;

B. 基于距离的支持向量机算法(Support Vector Machines – Distance based)

  • 优点

  • 具有较高的预测精确度;

  • 即使在高维度数据集上也不会产生过度拟合,因此它适用于具有多个特征的情况;

  • 适用于小型数据集(小于10万个训练集);

  • 适用于解决文本分类的问题;

  • 缺点

  • 当新增数据时,不易更新模型;

  • 属于内存高度密集型算法;

  • 不适用于大型数据集;

  • 需要选择正确的内核;

  • 线性内核对线性数据建模,运行速度快;

  • 非线性内核可以模拟非线性边界,运行速度慢;

  • 用Boosting代替!

C. 基于概率的朴素贝叶斯算法(Naive Bayes — Probability based)

  • 优点

  • 在文本分类问题上表现极佳;

  • 具有较高的训练速度和预测速度;

  • 在小型数据集上表现良好;

  • 善于分离信号和噪声;

  • 在实践中表现出良好的性能;

  • 操作简单,易于实现;

  • 适用于小型数据集(小于10万个训练集);

  • 关于特征的和潜在分布的独立性避免了过度拟合;

  • 如果这种独立性的条件成立,那么朴素贝叶斯可以在更小的数据集上运行,并且可以以更快的速度进行训练;

  • 不需要特征缩放;

  • 不是内存密集型算法;

  • 结果可解释,并易于说明;

  • 根据数据集的大小易于扩展;

  • 缺点

  • 具有较低的预测精确度;

D. 基于距离的K近邻算法( K Nearest Neighbors — Distance Based)

  • 优点

  • 具有较高的训练速度;

  • 无需太多参数调整;

  • 结果可解释,并易于说明;

  • 适用于小型数据集(小于10万个训练集);

  • 缺点

  • 预测精确度较低;

  • 在小型数据集上表现不好;

  • 需要选择一个合适的距离函数;

  • 需要功能缩放;

  • 预测速度随着数据集增大而加快;

  • 分离信号和噪声的效果不理想,在使用前需要去掉不相关的特征;

  • 是内存密集型算法,因为它可以保存所有的观察结果;

  • 不善于处理高维度的数据;

E. 分类树(Classification Tree)

  • I. 决策树(Decision Tree)

  • 优点

  • 具有较高的训练速度和预测速度;

  • 善于获取数据集中的非线性关系;

  • 了解数据集中的特征交互;

  • 善于处理数据集中出现的异常值;

  • 善于在数据集中找到最重要的特征;

  • 可以同时进行2个类和多个类的分类任务;

  • 不需要特征缩放;

  • 结果可解释,并易于说明;

  • 缺点

  • 预测速度较慢;

  • 需要进行参数的调整;

  • 在小型数据集上表现不好;

  • 分离信号和噪声的效果不理想;

  • 在实践中很少使用,而是更多地使用集合树;

  • 当新增数据时,不易更新模型;

  • 可能会出现过度拟合(见下面的融合模型)

  • II.融合(RandomForest, XGBoost, CatBoost, LightGBM)

  • 优点

  • 多重树结构整理预测;

  • 具有较高的预测精确度,在实践中表现很好;

  • 是Kaggle竞赛中推荐的算法;

  • 善于获取数据集中的非线性关系;

  • 善于处理数据集中出现的异常值;

  • 善于在数据集中找到最重要的特征;

  • 能够分离信号和噪声;

  • 无需特征缩放;

  • 特别适用于高维度的数据;

  • 缺点

  • 训练速度较慢;

  • 预测速度较快;

  • 结果不易解释或说明;

  • 当新增数据时,不易更新模型;

  • 需要调整参数,但调整较为复杂;

  • 在小型数据集上表现不好;

F. 深度学习(Deep Learning)

  • 优点

  • 预测精确度较高,在实践中表现良好;

  • 可以获取数据中非常复杂的底层模式;

  • 适用于大型数据集和高维度数据集;

  • 当新增数据时,易于更新模型;

  • 网络的隐藏层明显减少了对特征工程的需求;

  • 是适用于计算机视觉、机器翻译、情感分析和语音识别任务的最新技术;

  • 缺点

  • 训练速度较慢;

  • 结果不易解释或说明;

  • 需要消耗巨大的计算资源;

  • 需要特征缩放;

  • 需要大量的训练数据,因为它要学习大量的参数;

  • 在非图像、非文本、非语音的任务中优于Boosting算法;

  • 非常灵活,带有许多不同的体系结构构建块,因此需要专业知识来设计体系结构;

3. 聚类 — 将数据分类以便最大化相似性

A. DBSCAN聚类算法(Density-Based Spatial Clustering of Applications with Noise)

  • 优点

  • 可扩展到大型数据集上;

  • 善于噪声检测;

  • 无需预先知道聚类的数量;

  • 可以发现任意形状的聚类,不会假设聚类的形状是球状的;

  • 缺点

  • 如果整个数据集都是高密度区域,那么该算法不总是有效的;

  • 需要调整密度参数epsilon和min_samples为正确的值,以便获得好的效果;

B. Kmeans算法

  • 优点

  • 特别适于获取底层数据集的结构;

  • 算法简单,易于解释;

  • 适于预先知道聚类的数量;

  • 缺点

  • 如果聚类不是球状的,并且大小相似,那么该算法不总是有效的;

  • 需要预先知道聚类的数量,并需要调整k聚类的选择以便获得好的结果;

  • 属于内存密集型的算法;

  • 无法扩展到大型数据集上;

4. Misc — 本文中未包含的模型

  • 降维算法(Dimensionality Reduction Algorithms);

  • 聚类算法(Clustering algorithms);

  • 高斯混合模型(Gaussian Mixture Model);

  • 分层聚类(Hierarchical clustering);

  • 计算机视觉(CV);

  • 卷积神经网络(Convolutional Neural Networks);

  • 图像分类(Image classification);

  • 对象检测(Object Detection)

  • 图像分割(Image segmentation)

  • 自然语言处理(Natural Language Processing,NLP)

  • 循环神经网络(Recurrent Neural Network,RNNs,包括LSTM 和 GRUs)

  • 强化学习(Reinforcement Learning)

融合模型

融合模型是一种非常强大的技术,有助于减少过度拟合,并通过组合来自不同模型的输出以做出更稳定的预测。融合模型是赢得Kaggle竞赛的一个重要工具,在选择模型进行融合时,我们希望选择不同类型的模型,以确保它们具有不同的优势和劣势,从而在数据集中获取不同的模式。这种更明显的多样性特点使得偏差降低。我们还希望确保它们的性能是可以对比的,这样就能确保预测的稳定性。

 

我们在这里可以看到,这些模型的融合实际上比任何单一的模型生成的损失都要低得多。部分的原因是,尽管所有的这些模型都非常擅长预测,但它们都能得到不同的正确预测结果,通过把它们组合在一起,我们能够根据它们所有不同的优势组合成一个超级模型。

               

# in order to make the final predictions more robust to overfittingdef blended_predictions(X):    return ((0.1 * ridge_model_full_data.predict(X)) + \\            (0.2 * svr_model_full_data.predict(X)) + \\            (0.1 * gbr_model_full_data.predict(X)) + \\            (0.1 * xgb_model_full_data.predict(X)) + \\            (0.1 * lgb_model_full_data.predict(X)) + \\            (0.05 * rf_model_full_data.predict(X)) + \\            (0.35 * stack_gen_model.predict(np.array(X))))

融合模型分为四种类型(包括混合型):

  • Bagging:使用随机选择的不同数据子集训练多个基础模型,并进行替换。让基础模型对最终的预测进行投票。常用于随机森林算法(RandomForests);

  • Boosting:迭代地训练模型,并且在每次迭代之后更新获得每个训练示例的重要程度。常用于梯度增强算法(GradientBoosting);

  • Blending:训练许多不同类型的基础模型,并在一个holdout set上进行预测。从它们的预测结果中再训练一个新的模型,并在测试集上进行预测(用一个holdout set堆叠);

  • Stacking:训练多种不同类型的基础模型,并对数据集的k-folds进行预测。从它们的预测结果中再训练一个新的模型,并在测试集上进行预测;

模型对比

权重和偏差让我们可以用一行代码来跟踪和比较模型的性能表现。选择要测试的模型后,对其进行训练并添加wandb.log({‘score’: cv_score})来记录模型的运行状态。完成训练之后,你就可以在一个简单的控制台中对比模型的性能了!

 

# WandBimport wandbimport tensorflow.kerasfrom wandb.keras import WandbCallbackfrom sklearn.model_selection import cross_val_score# Import models (Step 1: add your models here)from sklearn import svmfrom sklearn.linear_model import Ridge, RidgeCVfrom xgboost import XGBRegressor
# Model 1# Initialize wandb run# You can change your project name here. For more config options, see https://docs.wandb.com/docs/init.htmlwandb.init(anonymous='allow', project="pick-a-model")
# Initialize model (Step 2: add your classifier here)clf = svm.SVR(C= 20, epsilon= 0.008, gamma=0.0003)
# Get CV scorescv_scores = cross_val_score(clf, X_train, train_labels, cv=5)
# Log scoresfor cv_score in cv_scores:    wandb.log({'score': cv_score})
# Model 2# Initialize wandb run# You can change your project name here. For more config options, see https://docs.wandb.com/docs/init.htmlwandb.init(anonymous='allow', project="pick-a-model")
# Initialize model (Step 2: add your classifier here)clf = XGBRegressor(learning_rate=0.01,                       n_estimators=6000,                       max_depth=4,                       min_child_weight=0,                       gamma=0.6,                       subsample=0.7,                       colsample_bytree=0.7,                       objective='reg:linear',                       nthread=-1,                       scale_pos_weight=1,                       seed=27,                       reg_alpha=0.00006,                       random_state=42)
# Get CV scorescv_scores = cross_val_score(clf, X_train, train_labels, cv=5)
# Log scoresfor cv_score in cv_scores:    wandb.log({'score': cv_score})
# Model 3# Initialize wandb run# You can change your project name here. For more config options, see https://docs.wandb.com/docs/init.htmlwandb.init(anonymous='allow', project="pick-a-model")
# Initialize model (Step 2: add your classifier here)ridge_alphas = [1e-15, 1e-10, 1e-8, 9e-4, 7e-4, 5e-4, 3e-4, 1e-4, 1e-3, 5e-2, 1e-2, 0.1, 0.3, 1, 3, 5, 10, 15, 18, 20, 30, 50, 75, 100]clf = Ridge(alphas=ridge_alphas)
# Get CV scorescv_scores = cross_val_score(clf, X_train, train_labels, cv=5)
# Log scoresfor cv_score in cv_scores:    wandb.log({'score': cv_score})

就这样,在有了所有的工具和算法之后,就可以为你的问题选择正确的模型了!

 

模型的选择可能是非常复杂的,但我希望本指南能给你带来一些启发,让你找到模型选择的好方法。

 

原文链接:

https://lavanya.ai/2019/09/18/part-ii-whirlwind-tour-of-machine-learning-models/

编辑:于腾凯

校对:林亦霖

这篇关于机器学习模型五花八门不知道怎么选?这份指南告诉你的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/901433

相关文章

SpringBoot快速接入OpenAI大模型的方法(JDK8)

《SpringBoot快速接入OpenAI大模型的方法(JDK8)》本文介绍了如何使用AI4J快速接入OpenAI大模型,并展示了如何实现流式与非流式的输出,以及对函数调用的使用,AI4J支持JDK8... 目录使用AI4J快速接入OpenAI大模型介绍AI4J-github快速使用创建SpringBoot

Nginx实现动态封禁IP的步骤指南

《Nginx实现动态封禁IP的步骤指南》在日常的生产环境中,网站可能会遭遇恶意请求、DDoS攻击或其他有害的访问行为,为了应对这些情况,动态封禁IP是一项十分重要的安全策略,本篇博客将介绍如何通过NG... 目录1、简述2、实现方式3、使用 fail2ban 动态封禁3.1 安装 fail2ban3.2 配

Java中String字符串使用避坑指南

《Java中String字符串使用避坑指南》Java中的String字符串是我们日常编程中用得最多的类之一,看似简单的String使用,却隐藏着不少“坑”,如果不注意,可能会导致性能问题、意外的错误容... 目录8个避坑点如下:1. 字符串的不可变性:每次修改都创建新对象2. 使用 == 比较字符串,陷阱满

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

使用 sql-research-assistant进行 SQL 数据库研究的实战指南(代码实现演示)

《使用sql-research-assistant进行SQL数据库研究的实战指南(代码实现演示)》本文介绍了sql-research-assistant工具,该工具基于LangChain框架,集... 目录技术背景介绍核心原理解析代码实现演示安装和配置项目集成LangSmith 配置(可选)启动服务应用场景

0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型的操作流程

《0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeekR1模型的操作流程》DeepSeekR1模型凭借其强大的自然语言处理能力,在未来具有广阔的应用前景,有望在多个领域发... 目录0基础租个硬件玩deepseek,蓝耘元生代智算云|本地部署DeepSeek R1模型,3步搞定一个应

Deepseek R1模型本地化部署+API接口调用详细教程(释放AI生产力)

《DeepseekR1模型本地化部署+API接口调用详细教程(释放AI生产力)》本文介绍了本地部署DeepSeekR1模型和通过API调用将其集成到VSCode中的过程,作者详细步骤展示了如何下载和... 目录前言一、deepseek R1模型与chatGPT o1系列模型对比二、本地部署步骤1.安装oll

Java深度学习库DJL实现Python的NumPy方式

《Java深度学习库DJL实现Python的NumPy方式》本文介绍了DJL库的背景和基本功能,包括NDArray的创建、数学运算、数据获取和设置等,同时,还展示了如何使用NDArray进行数据预处理... 目录1 NDArray 的背景介绍1.1 架构2 JavaDJL使用2.1 安装DJL2.2 基本操

Spring AI Alibaba接入大模型时的依赖问题小结

《SpringAIAlibaba接入大模型时的依赖问题小结》文章介绍了如何在pom.xml文件中配置SpringAIAlibaba依赖,并提供了一个示例pom.xml文件,同时,建议将Maven仓... 目录(一)pom.XML文件:(二)application.yml配置文件(一)pom.xml文件:首

MySql死锁怎么排查的方法实现

《MySql死锁怎么排查的方法实现》本文主要介绍了MySql死锁怎么排查的方法实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录前言一、死锁排查方法1. 查看死锁日志方法 1:启用死锁日志输出方法 2:检查 mysql 错误