scikit-learn中常见的train test split

2024-08-21 10:58

本文主要是介绍scikit-learn中常见的train test split,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. train_test_split

进行一次性划分

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)
"""X: array([[0, 1],[2, 3],[4, 5],[6, 7],[8, 9]])
list(y): [0, 1, 2, 3, 4]
"""X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)""">>> X_trainarray([[4, 5],[0, 1],[6, 7]])>>> y_train[2, 0, 3]>>> X_testarray([[2, 3],[8, 9]])>>> y_test[1, 4]
"""
train_test_split(y, shuffle=False)[[0, 1, 2], [3, 4]]
  • X, y: 可为lists, numpy arrays, scipy-sparse, matrices或者dataframes

2. ShuffleSplit

sklearn.model_selection.ShuffleSplit用来将数据集分为测试集和验证集,可以多次划分

from sklearn.model_selection import ShuffleSplit
import numpy as npX, y = np.arange(20).reshape((10, 2)), range(10)ss = ShuffleSplit(n_splits=10, test_size=0.2, train_size=None, random_state=None)for train_indices, test_indices in ss.split(sample):print(f"train_indices: {train_indices}, test_indices: {test_indices}")

输出:

train_indices: [4 3 0 6 8 1 9 2], test_indices: [7 5]
train_indices: [0 5 3 4 2 6 9 8], test_indices: [1 7]
train_indices: [2 0 4 1 7 6 3 9], test_indices: [5 8]
train_indices: [2 6 9 8 5 3 4 1], test_indices: [0 7]
train_indices: [0 8 7 9 4 5 2 1], test_indices: [6 3]
train_indices: [6 5 2 8 1 0 3 4], test_indices: [9 7]
train_indices: [8 4 9 5 0 3 2 6], test_indices: [1 7]
train_indices: [6 5 2 1 4 3 0 7], test_indices: [8 9]
train_indices: [8 9 1 7 4 6 5 3], test_indices: [0 2]
train_indices: [1 3 9 5 0 2 7 6], test_indices: [4 8]
  • n_splits: int, 划分训练集、测试集的次数,默认为10
  • test_size: float, int, None, default=0.1; 测试集比例或样本数量,该值为[0.0, 1.0]内的浮点数时,表示测试集占总样本的比例;该值为整型值时,表示具体的测试集样本数量。
    方法:
  • get_n_splits:获取分割次数
  • split(X, y=None, groups=None): 进行分割,并返回索引

3. GroupShuffleSplit

ShuffleSplit基本相同,区别在于先进行分组,然后按照分组来进行训练集和验证集划分。

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit, GroupShuffleSplit
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],'classname':['c5','c0','c1','c5','c0','c0','c1','c1','c2','c6'],'img':['img_41179.jpg','img_50749.jpg','img_53609.jpg','img_52213.jpg','img_72495.jpg', 'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})gss = GroupShuffleSplit(n_splits=4, test_size=0.25, random_state=0)tmp_groups = sample.loc[:, 'subject'].values# 进行一次划分
train_idxs, test_idxs =next(gss.split(X=sample['img'], y=sample['classname'], groups=tmp_groups))# 进行多次划分
for train_indices, test_indices in gss.split(sample.loc[:, "img"], sample.loc[:, "classname"], groups=tmp_groups):print(f"\ntrain_indices: {train_indices}, test_indices: {test_indices}")print(f"train subjects: {sample.loc[train_indices, 'subject']}, test subjects: {sample.loc[test_indices, 'subject']}")

输出:

fold====0=====
train_indices: [0 1 2 3 4 9], test_indices: [5 6 7 8]
train subjects: 0    p012
1    p012
2    p014
3    p014
4    p014
9    p081
Name: subject, dtype: object, test subjects: 5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: objectfold====1=====
train_indices: [2 3 4 5 6 7 8 9], test_indices: [0 1]
train subjects: 2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object, test subjects: 0    p012
1    p012
Name: subject, dtype: objectfold====2=====
train_indices: [0 1 2 3 4 5 6 7 8], test_indices: [9]
train subjects: 0    p012
1    p012
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: object, test subjects: 9    p081
Name: subject, dtype: objectfold====3=====
train_indices: [0 1 5 6 7 8 9], test_indices: [2 3 4]
train subjects: 0    p012
1    p012
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object, test subjects: 2    p014
3    p014
4    p014
Name: subject, dtype: object

可以看出已经进行了分组之后再进行划分。

4. GroupKFold

GroupKFoldGroupShuffleSplit基本相同,区别是GroupShuffleSplit每次都是做独立的划分,不同的划分之家可能会重叠。而GroupKFold则没有重叠,因此没有所谓的test_sizerandom_state参数。

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit, GroupKFold
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],'classname':['c5','c0','c1','c5','c0','c0','c1','c1','c2','c6'],'img':['img_41179.jpg','img_50749.jpg','img_53609.jpg','img_52213.jpg','img_72495.jpg', 'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})gkf = GroupKFold(n_splits=4)tmp_groups = sample.loc[:, 'subject'].values# 进行一次划分
train_idxs, test_idxs =next(gkf.split(X=sample['img'], y=sample['classname'], groups=tmp_groups))# 进行多次划分
for train_indices, test_indices in gkf.split(sample.loc[:, "img"], sample.loc[:, "classname"], groups=tmp_groups):print(f"\ntrain_indices: {train_indices}, test_indices: {test_indices}")print(f"train subjects: \n{sample.loc[train_indices, 'subject']}, \ntest subjects: \n{sample.loc[test_indices, 'subject']}")

输出:

train_indices: [0 1 2 3 4 9], test_indices: [5 6 7 8]
train subjects:
0    p012
1    p012
2    p014
3    p014
4    p014
9    p081
Name: subject, dtype: object,
test subjects:
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: objecttrain_indices: [0 1 5 6 7 8 9], test_indices: [2 3 4]
train subjects:
0    p012
1    p012
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object,
test subjects:
2    p014
3    p014
4    p014
Name: subject, dtype: objecttrain_indices: [2 3 4 5 6 7 8 9], test_indices: [0 1]
train subjects:
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object,
test subjects:
0    p012
1    p012
Name: subject, dtype: objecttrain_indices: [0 1 2 3 4 5 6 7 8], test_indices: [9]
train subjects:
0    p012
1    p012
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: object,
test subjects:
9    p081
Name: subject, dtype: object

其结果按组来划分且没有重复。

这篇关于scikit-learn中常见的train test split的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1092962

相关文章

python中各种常见文件的读写操作与类型转换详细指南

《python中各种常见文件的读写操作与类型转换详细指南》这篇文章主要为大家详细介绍了python中各种常见文件(txt,xls,csv,sql,二进制文件)的读写操作与类型转换,感兴趣的小伙伴可以跟... 目录1.文件txt读写标准用法1.1写入文件1.2读取文件2. 二进制文件读取3. 大文件读取3.1

C++中初始化二维数组的几种常见方法

《C++中初始化二维数组的几种常见方法》本文详细介绍了在C++中初始化二维数组的不同方式,包括静态初始化、循环、全部为零、部分初始化、std::array和std::vector,以及std::vec... 目录1. 静态初始化2. 使用循环初始化3. 全部初始化为零4. 部分初始化5. 使用 std::a

前端下载文件时如何后端返回的文件流一些常见方法

《前端下载文件时如何后端返回的文件流一些常见方法》:本文主要介绍前端下载文件时如何后端返回的文件流一些常见方法,包括使用Blob和URL.createObjectURL创建下载链接,以及处理带有C... 目录1. 使用 Blob 和 URL.createObjectURL 创建下载链接例子:使用 Blob

C++ vector的常见用法超详细讲解

《C++vector的常见用法超详细讲解》:本文主要介绍C++vector的常见用法,包括C++中vector容器的定义、初始化方法、访问元素、常用函数及其时间复杂度,通过代码介绍的非常详细,... 目录1、vector的定义2、vector常用初始化方法1、使编程用花括号直接赋值2、使用圆括号赋值3、ve

Pytest多环境切换的常见方法介绍

《Pytest多环境切换的常见方法介绍》Pytest作为自动化测试的主力框架,如何实现本地、测试、预发、生产环境的灵活切换,本文总结了通过pytest框架实现自由环境切换的几种方法,大家可以根据需要进... 目录1.pytest-base-url2.hooks函数3.yml和fixture结论你是否也遇到过

C/C++错误信息处理的常见方法及函数

《C/C++错误信息处理的常见方法及函数》C/C++是两种广泛使用的编程语言,特别是在系统编程、嵌入式开发以及高性能计算领域,:本文主要介绍C/C++错误信息处理的常见方法及函数,文中通过代码介绍... 目录前言1. errno 和 perror()示例:2. strerror()示例:3. perror(

Go标准库常见错误分析和解决办法

《Go标准库常见错误分析和解决办法》Go语言的标准库为开发者提供了丰富且高效的工具,涵盖了从网络编程到文件操作等各个方面,然而,标准库虽好,使用不当却可能适得其反,正所谓工欲善其事,必先利其器,本文将... 目录1. 使用了错误的time.Duration2. time.After导致的内存泄漏3. jsO

MyBatis 动态 SQL 优化之标签的实战与技巧(常见用法)

《MyBatis动态SQL优化之标签的实战与技巧(常见用法)》本文通过详细的示例和实际应用场景,介绍了如何有效利用这些标签来优化MyBatis配置,提升开发效率,确保SQL的高效执行和安全性,感... 目录动态SQL详解一、动态SQL的核心概念1.1 什么是动态SQL?1.2 动态SQL的优点1.3 动态S

java常见报错及解决方案总结

《java常见报错及解决方案总结》:本文主要介绍Java编程中常见错误类型及示例,包括语法错误、空指针异常、数组下标越界、类型转换异常、文件未找到异常、除以零异常、非法线程操作异常、方法未定义异常... 目录1. 语法错误 (Syntax Errors)示例 1:解决方案:2. 空指针异常 (NullPoi

C++常见容器获取头元素的方法大全

《C++常见容器获取头元素的方法大全》在C++编程中,容器是存储和管理数据集合的重要工具,不同的容器提供了不同的接口来访问和操作其中的元素,获取容器的头元素(即第一个元素)是常见的操作之一,本文将详细... 目录一、std::vector二、std::list三、std::deque四、std::forwa