scikit-learn中常见的train test split

2024-08-21 10:58

本文主要是介绍scikit-learn中常见的train test split,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

1. train_test_split

进行一次性划分

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)
"""X: array([[0, 1],[2, 3],[4, 5],[6, 7],[8, 9]])
list(y): [0, 1, 2, 3, 4]
"""X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)""">>> X_trainarray([[4, 5],[0, 1],[6, 7]])>>> y_train[2, 0, 3]>>> X_testarray([[2, 3],[8, 9]])>>> y_test[1, 4]
"""
train_test_split(y, shuffle=False)[[0, 1, 2], [3, 4]]
  • X, y: 可为lists, numpy arrays, scipy-sparse, matrices或者dataframes

2. ShuffleSplit

sklearn.model_selection.ShuffleSplit用来将数据集分为测试集和验证集,可以多次划分

from sklearn.model_selection import ShuffleSplit
import numpy as npX, y = np.arange(20).reshape((10, 2)), range(10)ss = ShuffleSplit(n_splits=10, test_size=0.2, train_size=None, random_state=None)for train_indices, test_indices in ss.split(sample):print(f"train_indices: {train_indices}, test_indices: {test_indices}")

输出:

train_indices: [4 3 0 6 8 1 9 2], test_indices: [7 5]
train_indices: [0 5 3 4 2 6 9 8], test_indices: [1 7]
train_indices: [2 0 4 1 7 6 3 9], test_indices: [5 8]
train_indices: [2 6 9 8 5 3 4 1], test_indices: [0 7]
train_indices: [0 8 7 9 4 5 2 1], test_indices: [6 3]
train_indices: [6 5 2 8 1 0 3 4], test_indices: [9 7]
train_indices: [8 4 9 5 0 3 2 6], test_indices: [1 7]
train_indices: [6 5 2 1 4 3 0 7], test_indices: [8 9]
train_indices: [8 9 1 7 4 6 5 3], test_indices: [0 2]
train_indices: [1 3 9 5 0 2 7 6], test_indices: [4 8]
  • n_splits: int, 划分训练集、测试集的次数,默认为10
  • test_size: float, int, None, default=0.1; 测试集比例或样本数量,该值为[0.0, 1.0]内的浮点数时,表示测试集占总样本的比例;该值为整型值时,表示具体的测试集样本数量。
    方法:
  • get_n_splits:获取分割次数
  • split(X, y=None, groups=None): 进行分割,并返回索引

3. GroupShuffleSplit

ShuffleSplit基本相同,区别在于先进行分组,然后按照分组来进行训练集和验证集划分。

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit, GroupShuffleSplit
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],'classname':['c5','c0','c1','c5','c0','c0','c1','c1','c2','c6'],'img':['img_41179.jpg','img_50749.jpg','img_53609.jpg','img_52213.jpg','img_72495.jpg', 'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})gss = GroupShuffleSplit(n_splits=4, test_size=0.25, random_state=0)tmp_groups = sample.loc[:, 'subject'].values# 进行一次划分
train_idxs, test_idxs =next(gss.split(X=sample['img'], y=sample['classname'], groups=tmp_groups))# 进行多次划分
for train_indices, test_indices in gss.split(sample.loc[:, "img"], sample.loc[:, "classname"], groups=tmp_groups):print(f"\ntrain_indices: {train_indices}, test_indices: {test_indices}")print(f"train subjects: {sample.loc[train_indices, 'subject']}, test subjects: {sample.loc[test_indices, 'subject']}")

输出:

fold====0=====
train_indices: [0 1 2 3 4 9], test_indices: [5 6 7 8]
train subjects: 0    p012
1    p012
2    p014
3    p014
4    p014
9    p081
Name: subject, dtype: object, test subjects: 5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: objectfold====1=====
train_indices: [2 3 4 5 6 7 8 9], test_indices: [0 1]
train subjects: 2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object, test subjects: 0    p012
1    p012
Name: subject, dtype: objectfold====2=====
train_indices: [0 1 2 3 4 5 6 7 8], test_indices: [9]
train subjects: 0    p012
1    p012
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: object, test subjects: 9    p081
Name: subject, dtype: objectfold====3=====
train_indices: [0 1 5 6 7 8 9], test_indices: [2 3 4]
train subjects: 0    p012
1    p012
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object, test subjects: 2    p014
3    p014
4    p014
Name: subject, dtype: object

可以看出已经进行了分组之后再进行划分。

4. GroupKFold

GroupKFoldGroupShuffleSplit基本相同,区别是GroupShuffleSplit每次都是做独立的划分,不同的划分之家可能会重叠。而GroupKFold则没有重叠,因此没有所谓的test_sizerandom_state参数。

import pandas as pd
import numpy as np
from sklearn.model_selection import ShuffleSplit, GroupKFold
sample = pd.DataFrame({'subject':['p012', 'p012', 'p014', 'p014', 'p014', 'p024', 'p024', 'p024', 'p024', 'p081'],'classname':['c5','c0','c1','c5','c0','c0','c1','c1','c2','c6'],'img':['img_41179.jpg','img_50749.jpg','img_53609.jpg','img_52213.jpg','img_72495.jpg', 'img_66836.jpg','img_32639.jpg','img_31777.jpg','img_97535.jpg','img_1399.jpg']})gkf = GroupKFold(n_splits=4)tmp_groups = sample.loc[:, 'subject'].values# 进行一次划分
train_idxs, test_idxs =next(gkf.split(X=sample['img'], y=sample['classname'], groups=tmp_groups))# 进行多次划分
for train_indices, test_indices in gkf.split(sample.loc[:, "img"], sample.loc[:, "classname"], groups=tmp_groups):print(f"\ntrain_indices: {train_indices}, test_indices: {test_indices}")print(f"train subjects: \n{sample.loc[train_indices, 'subject']}, \ntest subjects: \n{sample.loc[test_indices, 'subject']}")

输出:

train_indices: [0 1 2 3 4 9], test_indices: [5 6 7 8]
train subjects:
0    p012
1    p012
2    p014
3    p014
4    p014
9    p081
Name: subject, dtype: object,
test subjects:
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: objecttrain_indices: [0 1 5 6 7 8 9], test_indices: [2 3 4]
train subjects:
0    p012
1    p012
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object,
test subjects:
2    p014
3    p014
4    p014
Name: subject, dtype: objecttrain_indices: [2 3 4 5 6 7 8 9], test_indices: [0 1]
train subjects:
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
9    p081
Name: subject, dtype: object,
test subjects:
0    p012
1    p012
Name: subject, dtype: objecttrain_indices: [0 1 2 3 4 5 6 7 8], test_indices: [9]
train subjects:
0    p012
1    p012
2    p014
3    p014
4    p014
5    p024
6    p024
7    p024
8    p024
Name: subject, dtype: object,
test subjects:
9    p081
Name: subject, dtype: object

其结果按组来划分且没有重复。

这篇关于scikit-learn中常见的train test split的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1092962

相关文章

Spring常见错误之Web嵌套对象校验失效解决办法

《Spring常见错误之Web嵌套对象校验失效解决办法》:本文主要介绍Spring常见错误之Web嵌套对象校验失效解决的相关资料,通过在Phone对象上添加@Valid注解,问题得以解决,需要的朋... 目录问题复现案例解析问题修正总结  问题复现当开发一个学籍管理系统时,我们会提供了一个 API 接口去

C语言线程池的常见实现方式详解

《C语言线程池的常见实现方式详解》本文介绍了如何使用C语言实现一个基本的线程池,线程池的实现包括工作线程、任务队列、任务调度、线程池的初始化、任务添加、销毁等步骤,感兴趣的朋友跟随小编一起看看吧... 目录1. 线程池的基本结构2. 线程池的实现步骤3. 线程池的核心数据结构4. 线程池的详细实现4.1 初

bytes.split的用法和注意事项

当然,我很乐意详细介绍 bytes.Split 的用法和注意事项。这个函数是 Go 标准库中 bytes 包的一个重要组成部分,用于分割字节切片。 基本用法 bytes.Split 的函数签名如下: func Split(s, sep []byte) [][]byte s 是要分割的字节切片sep 是用作分隔符的字节切片返回值是一个二维字节切片,包含分割后的结果 基本使用示例: pa

JVM 常见异常及内存诊断

栈内存溢出 栈内存大小设置:-Xss size 默认除了window以外的所有操作系统默认情况大小为 1MB,window 的默认大小依赖于虚拟机内存。 栈帧过多导致栈内存溢出 下述示例代码,由于递归深度没有限制且没有设置出口,每次方法的调用都会产生一个栈帧导致了创建的栈帧过多,而导致内存溢出(StackOverflowError)。 示例代码: 运行结果: 栈帧过大导致栈内存

论文翻译:ICLR-2024 PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS

PROVING TEST SET CONTAMINATION IN BLACK BOX LANGUAGE MODELS https://openreview.net/forum?id=KS8mIvetg2 验证测试集污染在黑盒语言模型中 文章目录 验证测试集污染在黑盒语言模型中摘要1 引言 摘要 大型语言模型是在大量互联网数据上训练的,这引发了人们的担忧和猜测,即它们可能已

模拟实现vector中的常见接口

insert void insert(iterator pos, const T& x){if (_finish == _endofstorage){int n = pos - _start;size_t newcapacity = capacity() == 0 ? 2 : capacity() * 2;reserve(newcapacity);pos = _start + n;//防止迭代

【Kubernetes】常见面试题汇总(三)

目录 9.简述 Kubernetes 的缺点或当前的不足之处? 10.简述 Kubernetes 相关基础概念? 9.简述 Kubernetes 的缺点或当前的不足之处? Kubernetes 当前存在的缺点(不足)如下: ① 安装过程和配置相对困难复杂; ② 管理服务相对繁琐; ③ 运行和编译需要很多时间; ④ 它比其他替代品更昂贵; ⑤ 对于简单的应用程序来说,可能不

【附答案】C/C++ 最常见50道面试题

文章目录 面试题 1:深入探讨变量的声明与定义的区别面试题 2:编写比较“零值”的`if`语句面试题 3:深入理解`sizeof`与`strlen`的差异面试题 4:解析C与C++中`static`关键字的不同用途面试题 5:比较C语言的`malloc`与C++的`new`面试题 6:实现一个“标准”的`MIN`宏面试题 7:指针是否可以是`volatile`面试题 8:探讨`a`和`&a`

Golang test编译使用

创建文件my_test.go package testsimport "testing"func TestMy(t *testing.T) {t.Log("TestMy")} 通常用法: $ go test -v -run TestMy my_test.go=== RUN TestMyTestMy: my_test.go:6: TestMy--- PASS: TestMy (0.

常见的服务器

常见的Web服务器 1、Tomcat:Tomcat和Java结合得最好,是Oracle官方推荐的JSP服务器。Tomcat是开源的Web服务器,经过长时间的发展,性能、稳定性等方面都非常优秀。 2、Jetty:另一个优秀的Web服务器。Jetty有个更大的优点是,Jetty可作为一个嵌入式服务器,即:如果在应用中加入Jetty的JAR文件,应用可在代码中对外提供Web服务。 3、Resin: