scikit-learn中常见的train test split

2024-08-21 10:58

本文主要是介绍scikit-learn中常见的train test split,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. train_test_split

进行一次性划分

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)
"""X: array([[0, 1],[2, 3],[4, 5],[6, 7],[8, 9]])
list(y): [0, 1, 2, 3, 4]
"""X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)""">>> X_trainarray([[4, 5],[0, 1],[6, 7]])>>> y_train[2, 0, 3]>>> X_testarray([[2, 3],[8, 9]])>>> y_test[1, 4]
"""
train_test_split(y, shuffle=False)[[0, 1, 2], [3, 4]]
  • X, y: 可为lists, numpy arrays, scipy-sparse, matrices或者dataframes

2. ShuffleSplit

sklearn.model_selection.ShuffleSplit用来将数据集分为测试集和验证集,可以多次划分

from sklearn.model_selection import ShuffleSplit
import numpy as npX, y = np.arange(20).reshape((10, 2)), range(10)ss = ShuffleSplit(n_splits=10, test_size=0.2, train_size=None, random_state=None)for train_indices, test_indices in ss.split(sample):print(f"train_indices: {train_indices}, test_indices: {test_indices}")

输出:

train_indices: [4 3 0 6 8 1 9 2], test_indices: [7 5]
train_indices: [0 5 3 4 2 6 9 8], test_indices: [1 7]
train_indices: [2 0 4 1 7 6 3 9], test_indices: [5 8]
train_indices: [2 6 9 8 5 3 4 1], test_indices: [0 7]
train_indices: [0 8 7 9 4 5 2 1], test_indices: [6 3]
train_indices: [6 5 2 8 1 0 3 4], test_indices: [9 7]
train_indices: [8 4 9 5 0 3 2 6], test_indices: [1 7]
train_indices: [6 5 2 1 4 3 0 7], test_indices: [8 9]
train_indices: [8 9 1 7 4 6 5 3], test_indices: [0 2]
train_indices: [1 3 9 5 0 2 7 6], test_indices: [4 8]
  • n_splits: int, 划分训练集、测试集的次数,默认为10
  • test_size: float, int, None, default=0.1; 测试集比例或样本数量,该值为[0.0, 1.0]内的浮点数时,表示测试集占总样本的比例;该值为整型值时,表示具体的测试集样本数量。
    方法:
  • get_n_splits:获取分割次数
  • split(X, y=None, groups=None): 进行分割,并返回索引

3. GroupShuffleSplit

ShuffleSplit基本相同,区别在于先进行分组,然后按照分组来进行训练集和验证集划分。

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit, GroupShuffleSplit
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],'classname':['c5','c0','c1','c5','c0','c0','c1','c1','c2','c6'],'img':['img_41179.jpg','img_50749.jpg','img_53609.jpg','img_52213.jpg','img_72495.jpg', 'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})gss = GroupShuffleSplit(n_splits=4, test_size=0.25, random_state=0)tmp_groups = sample.loc[:, 'subject'].values# 进行一次划分
train_idxs, test_idxs =next(gss.split(X=sample['img'], y=sample['classname'], groups=tmp_groups))# 进行多次划分
for train_indices, test_indices in gss.split(sample.loc[:, "img"], sample.loc[:, "classname"], groups=tmp_groups):print(f"\ntrain_indices: {train_indices}, test_indices: {test_indices}")print(f"train subjects: {sample.loc[train_indices, 'subject']}, test subjects: {sample.loc[test_indices, 'subject']}")

输出:

fold====0=====
train_indices: [0 1 2 3 4 9], test_indices: [5 6 7 8]
train subjects: 0    p012
1    p012
2    p014
3    p014
4    p014
9    p081
Name: subject, dtype: object, test subjects: 5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: objectfold====1=====
train_indices: [2 3 4 5 6 7 8 9], test_indices: [0 1]
train subjects: 2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object, test subjects: 0    p012
1    p012
Name: subject, dtype: objectfold====2=====
train_indices: [0 1 2 3 4 5 6 7 8], test_indices: [9]
train subjects: 0    p012
1    p012
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: object, test subjects: 9    p081
Name: subject, dtype: objectfold====3=====
train_indices: [0 1 5 6 7 8 9], test_indices: [2 3 4]
train subjects: 0    p012
1    p012
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object, test subjects: 2    p014
3    p014
4    p014
Name: subject, dtype: object

可以看出已经进行了分组之后再进行划分。

4. GroupKFold

GroupKFoldGroupShuffleSplit基本相同,区别是GroupShuffleSplit每次都是做独立的划分,不同的划分之家可能会重叠。而GroupKFold则没有重叠,因此没有所谓的test_sizerandom_state参数。

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit, GroupKFold
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],'classname':['c5','c0','c1','c5','c0','c0','c1','c1','c2','c6'],'img':['img_41179.jpg','img_50749.jpg','img_53609.jpg','img_52213.jpg','img_72495.jpg', 'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})gkf = GroupKFold(n_splits=4)tmp_groups = sample.loc[:, 'subject'].values# 进行一次划分
train_idxs, test_idxs =next(gkf.split(X=sample['img'], y=sample['classname'], groups=tmp_groups))# 进行多次划分
for train_indices, test_indices in gkf.split(sample.loc[:, "img"], sample.loc[:, "classname"], groups=tmp_groups):print(f"\ntrain_indices: {train_indices}, test_indices: {test_indices}")print(f"train subjects: \n{sample.loc[train_indices, 'subject']}, \ntest subjects: \n{sample.loc[test_indices, 'subject']}")

输出:

train_indices: [0 1 2 3 4 9], test_indices: [5 6 7 8]
train subjects:
0    p012
1    p012
2    p014
3    p014
4    p014
9    p081
Name: subject, dtype: object,
test subjects:
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: objecttrain_indices: [0 1 5 6 7 8 9], test_indices: [2 3 4]
train subjects:
0    p012
1    p012
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object,
test subjects:
2    p014
3    p014
4    p014
Name: subject, dtype: objecttrain_indices: [2 3 4 5 6 7 8 9], test_indices: [0 1]
train subjects:
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object,
test subjects:
0    p012
1    p012
Name: subject, dtype: objecttrain_indices: [0 1 2 3 4 5 6 7 8], test_indices: [9]
train subjects:
0    p012
1    p012
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: object,
test subjects:
9    p081
Name: subject, dtype: object

其结果按组来划分且没有重复。

这篇关于scikit-learn中常见的train test split的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1092962

相关文章

MySQL 中的 CAST 函数详解及常见用法

《MySQL中的CAST函数详解及常见用法》CAST函数是MySQL中用于数据类型转换的重要函数,它允许你将一个值从一种数据类型转换为另一种数据类型,本文给大家介绍MySQL中的CAST... 目录mysql 中的 CAST 函数详解一、基本语法二、支持的数据类型三、常见用法示例1. 字符串转数字2. 数字

Python中win32包的安装及常见用途介绍

《Python中win32包的安装及常见用途介绍》在Windows环境下,PythonWin32模块通常随Python安装包一起安装,:本文主要介绍Python中win32包的安装及常见用途的相关... 目录前言主要组件安装方法常见用途1. 操作Windows注册表2. 操作Windows服务3. 窗口操作

ModelMapper基本使用和常见场景示例详解

《ModelMapper基本使用和常见场景示例详解》ModelMapper是Java对象映射库,支持自动映射、自定义规则、集合转换及高级配置(如匹配策略、转换器),可集成SpringBoot,减少样板... 目录1. 添加依赖2. 基本用法示例:简单对象映射3. 自定义映射规则4. 集合映射5. 高级配置匹

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

Mysql常见的SQL语句格式及实用技巧

《Mysql常见的SQL语句格式及实用技巧》本文系统梳理MySQL常见SQL语句格式,涵盖数据库与表的创建、删除、修改、查询操作,以及记录增删改查和多表关联等高级查询,同时提供索引优化、事务处理、临时... 目录一、常用语法汇总二、示例1.数据库操作2.表操作3.记录操作 4.高级查询三、实用技巧一、常用语

python 常见数学公式函数使用详解(最新推荐)

《python常见数学公式函数使用详解(最新推荐)》文章介绍了Python的数学计算工具,涵盖内置函数、math/cmath标准库及numpy/scipy/sympy第三方库,支持从基础算术到复杂数... 目录python 数学公式与函数大全1. 基本数学运算1.1 算术运算1.2 分数与小数2. 数学函数

SpringBoot开发中十大常见陷阱深度解析与避坑指南

《SpringBoot开发中十大常见陷阱深度解析与避坑指南》在SpringBoot的开发过程中,即使是经验丰富的开发者也难免会遇到各种棘手的问题,本文将针对SpringBoot开发中十大常见的“坑... 目录引言一、配置总出错?是不是同时用了.properties和.yml?二、换个位置配置就失效?搞清楚加

HTML中meta标签的常见使用案例(示例详解)

《HTML中meta标签的常见使用案例(示例详解)》HTMLmeta标签用于提供文档元数据,涵盖字符编码、SEO优化、社交媒体集成、移动设备适配、浏览器控制及安全隐私设置,优化页面显示与搜索引擎索引... 目录html中meta标签的常见使用案例一、基础功能二、搜索引擎优化(seo)三、社交媒体集成四、移动

python常见环境管理工具超全解析

《python常见环境管理工具超全解析》在Python开发中,管理多个项目及其依赖项通常是一个挑战,下面:本文主要介绍python常见环境管理工具的相关资料,文中通过代码介绍的非常详细,需要的朋友... 目录1. conda2. pip3. uvuv 工具自动创建和管理环境的特点4. setup.py5.

java中long的一些常见用法

《java中long的一些常见用法》在Java中,long是一种基本数据类型,用于表示长整型数值,接下来通过本文给大家介绍java中long的一些常见用法,感兴趣的朋友一起看看吧... 在Java中,long是一种基本数据类型,用于表示长整型数值。它的取值范围比int更大,从-922337203685477