机器学习必修课 - 交叉验证 Cross-Validation

2023-10-10 06:04

本文主要是介绍机器学习必修课 - 交叉验证 Cross-Validation,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

想象一下你有一个包含5000行数据的数据集。通常情况下,你会将约20%的数据保留作为验证数据集,即1000行。但这会在确定模型得分时引入一些随机性。也就是说,一个模型可能在一组1000行数据上表现良好,即使在另一组1000行数据上表现不准确。

运行环境:Google Colab

数据准备和预处理

!git clone https://github.com/JeffereyWu/Housing-prices-data.git
import pandas as pd
from sklearn.model_selection import train_test_split# Read the data
train_data = pd.read_csv('/content/Housing-prices-data/train.csv', index_col='Id')
test_data = pd.read_csv('/content/Housing-prices-data/test.csv', index_col='Id')
# Remove rows with missing target, separate target from predictors
train_data.dropna(axis=0, subset=['SalePrice'], inplace=True)
y = train_data.SalePrice              
train_data.drop(['SalePrice'], axis=1, inplace=True)
  • 删除训练数据中带有缺失目标值的行,并将目标值(SalePrice)分离出来存储在变量y中。
# Select numeric columns only
numeric_cols = [cname for cname in train_data.columns if train_data[cname].dtype in ['int64', 'float64']]
X = train_data[numeric_cols].copy()
X_test = test_data[numeric_cols].copy()
  • 从训练数据中选择了仅包含数值型数据的列,存储在变量numeric_cols中。
  • 创建了训练特征数据集X,其中包含了数值型的特征列。
  • 创建了测试特征数据集X_test,也包含了数值型的特征列。
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputermy_pipeline = Pipeline(steps=[('preprocessor', SimpleImputer()),('model', RandomForestRegressor(n_estimators=50, random_state=0))
])
  • 上面的管道将使用SimpleImputer()来替换数据中的缺失值,然后使用RandomForestRegressor()来训练一个随机森林模型进行预测。我们可以通过n_estimators参数来设置随机森林模型中树的数量,并通过设置random_state参数来确保结果的可重复性。

在交叉验证中,我们对数据的不同子集运行我们的建模过程,以获取模型质量的多个度量值。

例如,我们可以将数据分成5个部分,每个部分占全数据集的20%。在这种情况下,我们称将数据分成了5个“折叠”(fold)。

然后,我们对每个折叠运行一次实验:
在这里插入图片描述

  1. 在实验1中,我们将第一个折叠作为验证(或保留)集,将其他所有部分作为训练数据。这样可以基于一个20%的保留集来度量模型的质量。
  2. 在实验2中,我们保留第二个折叠的数据(并使用除第二个折叠之外的所有数据来训练模型)。然后,保留集用于获取模型质量的第二个估计值。

我们重复这个过程,每个折叠都曾被用作保留集。综合起来,100%的数据都会在某个时刻被用作保留集,最终我们会得到一个基于数据集中所有行的模型质量度量(即使我们不同时使用所有行)。

from sklearn.model_selection import cross_val_score# Multiply by -1 since sklearn calculates *negative* MAE
scores = -1 * cross_val_score(my_pipeline, X, y,cv=5,scoring='neg_mean_absolute_error')print("Average MAE score:", scores.mean())
  • 使用cross_val_score()函数来获取平均绝对误差(MAE),该值是在五个不同的折叠上求平均得到的。我们通过cv参数来设置折叠的数量。

Average MAE score: 18276.410356164386

定义一个函数用于评估随机森林回归模型在不同树的数量(n_estimators)下的性能。

def get_score(n_estimators):my_pipeline = Pipeline(steps=[('preprocessor', SimpleImputer()),('model', RandomForestRegressor(n_estimators, random_state=0))])scores = -1 * cross_val_score(my_pipeline, X, y,cv=3,scoring='neg_mean_absolute_error')return scores.mean()
results = {i: get_score(i) for i in range(50, 450, 50)}
  • 评估随机森林模型在八个不同的树数量下的性能:50、100、150、…、300、350、400。将结果存储在一个字典results中,其中results[i]表示get_score(i)返回的平均MAE。
import matplotlib.pyplot as plt
%matplotlib inlineplt.plot(list(results.keys()), list(results.values()))
plt.show()

在这里插入图片描述
由此可见,n_estimators设为200时,可得到最佳的随机森林模型。

这篇关于机器学习必修课 - 交叉验证 Cross-Validation的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/178471

相关文章

MySQL 主从复制部署及验证(示例详解)

《MySQL主从复制部署及验证(示例详解)》本文介绍MySQL主从复制部署步骤及学校管理数据库创建脚本,包含表结构设计、示例数据插入和查询语句,用于验证主从同步功能,感兴趣的朋友一起看看吧... 目录mysql 主从复制部署指南部署步骤1.环境准备2. 主服务器配置3. 创建复制用户4. 获取主服务器状态5

Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式

《Java通过驱动包(jar包)连接MySQL数据库的步骤总结及验证方式》本文详细介绍如何使用Java通过JDBC连接MySQL数据库,包括下载驱动、配置Eclipse环境、检测数据库连接等关键步骤,... 目录一、下载驱动包二、放jar包三、检测数据库连接JavaJava 如何使用 JDBC 连接 mys

Spring Security中用户名和密码的验证完整流程

《SpringSecurity中用户名和密码的验证完整流程》本文给大家介绍SpringSecurity中用户名和密码的验证完整流程,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定... 首先创建了一个UsernamePasswordAuthenticationTChina编程oken对象,这是S

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示

Spring Validation中9个数据校验工具使用指南

《SpringValidation中9个数据校验工具使用指南》SpringValidation作为Spring生态系统的重要组成部分,提供了一套强大而灵活的数据校验机制,本文给大家介绍了Spring... 目录1. Bean Validation基础注解常用注解示例在控制器中应用2. 自定义约束验证器定义自

Android NDK版本迭代与FFmpeg交叉编译完全指南

《AndroidNDK版本迭代与FFmpeg交叉编译完全指南》在Android开发中,使用NDK进行原生代码开发是一项常见需求,特别是当我们需要集成FFmpeg这样的多媒体处理库时,本文将深入分析A... 目录一、android NDK版本迭代分界线二、FFmpeg交叉编译关键注意事项三、完整编译脚本示例四

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效