画学习曲线的方法

2024-05-10 03:18
文章标签 方法 学习曲线

本文主要是介绍画学习曲线的方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

import matplotlib.pyplot as plt
import numpy as np
import pandas as pdfrom sklearn.model_selection import ShuffleSplit 
from sklearn.model_selection import learning_curve
from sklearn.neighbors import KNeighborsClassifier#加载数据
data=pd.read_csv(r"D:\Desktop\data\59024 scikit-learn机器学习源码_20181031\code\datasets\pima-indians-diabetes\diabetes.csv")#from common.utils import plot_learning_curve
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):plt.figure()plt.title(title)if ylim is not None:plt.ylim(*ylim)plt.xlabel("Training examples")plt.ylabel("Score")train_sizes, train_scores, test_scores = learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)train_scores_mean = np.mean(train_scores, axis=1)train_scores_std = np.std(train_scores, axis=1)test_scores_mean = np.mean(test_scores, axis=1)test_scores_std = np.std(test_scores, axis=1)plt.grid()# 生成网格plt.fill_between(train_sizes, train_scores_mean - train_scores_std,train_scores_mean + train_scores_std, alpha=0.1,color="r")plt.fill_between(train_sizes, test_scores_mean - test_scores_std,test_scores_mean + test_scores_std, alpha=0.1, color="g")plt.plot(train_sizes, train_scores_mean, 'o-', color="r",label="Training score")plt.plot(train_sizes, test_scores_mean, 'o-', color="g",label="Cross-validation score") plt.legend(loc="best")#添加图例return pltX=data.iloc[:,0:8]
y=data.iloc[:,-1]#pandas对象加iloc,数组不用加
knn=KNeighborsClassifier(n_neighbors=2)cv=ShuffleSplit(n_splits=10,test_size=0.2,random_state=0)
plt.figure(figsize=(16,10),dpi=200)
plot_learning_curve(knn,"Learn Curve for KNN Diabetes",X,y,ylim=(0.0,1.01),cv=cv)

 

 学习曲线说明:

 横坐标表示训练样本的数量,纵坐标表示准确率

1:观察左上图,训练集准确率与验证集准确率收敛,但是两者收敛后的准确率远小于我们的期望准确率(上面那条红线),所以由图可得该模型属于欠拟合-高偏差(underfitting)问题。由于欠拟合,所以我们需要增加模型的复杂度,比如,增加特征、增加树的深度、减小正则项等等,此时再增加数据量是不起作用的。

    2:观察右上图,训练集准确率高于期望值,验证集则低于期望值,两者之间有很大的间距,误差很大,对于新的数据集模型适应性较差,所以由图可得该模型属于过拟合-高方差(overfitting)问题。由于过拟合,所以我们降低模型的复杂度,比如减小树的深度、增大分裂节点样本数、增大样本数、减少特征数等等。

    3:一个比较理想的学习曲线图应当是:低偏差、低方差,即收敛且误差小。
 

这篇关于画学习曲线的方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/975286

相关文章

MySQL查询JSON数组字段包含特定字符串的方法

《MySQL查询JSON数组字段包含特定字符串的方法》在MySQL数据库中,当某个字段存储的是JSON数组,需要查询数组中包含特定字符串的记录时传统的LIKE语句无法直接使用,下面小编就为大家介绍两种... 目录问题背景解决方案对比1. 精确匹配方案(推荐)2. 模糊匹配方案参数化查询示例使用场景建议性能优

关于集合与数组转换实现方法

《关于集合与数组转换实现方法》:本文主要介绍关于集合与数组转换实现方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、Arrays.asList()1.1、方法作用1.2、内部实现1.3、修改元素的影响1.4、注意事项2、list.toArray()2.1、方

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四

一文详解Git中分支本地和远程删除的方法

《一文详解Git中分支本地和远程删除的方法》在使用Git进行版本控制的过程中,我们会创建多个分支来进行不同功能的开发,这就容易涉及到如何正确地删除本地分支和远程分支,下面我们就来看看相关的实现方法吧... 目录技术背景实现步骤删除本地分支删除远程www.chinasem.cn分支同步删除信息到其他机器示例步骤

在Golang中实现定时任务的几种高效方法

《在Golang中实现定时任务的几种高效方法》本文将详细介绍在Golang中实现定时任务的几种高效方法,包括time包中的Ticker和Timer、第三方库cron的使用,以及基于channel和go... 目录背景介绍目的和范围预期读者文档结构概述术语表核心概念与联系故事引入核心概念解释核心概念之间的关系

在Linux终端中统计非二进制文件行数的实现方法

《在Linux终端中统计非二进制文件行数的实现方法》在Linux系统中,有时需要统计非二进制文件(如CSV、TXT文件)的行数,而不希望手动打开文件进行查看,例如,在处理大型日志文件、数据文件时,了解... 目录在linux终端中统计非二进制文件的行数技术背景实现步骤1. 使用wc命令2. 使用grep命令

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

XML重复查询一条Sql语句的解决方法

《XML重复查询一条Sql语句的解决方法》文章分析了XML重复查询与日志失效问题,指出因DTO缺少@Data注解导致日志无法格式化、空指针风险及参数穿透,进而引发性能灾难,解决方案为在Controll... 目录一、核心问题:从SQL重复执行到日志失效二、根因剖析:DTO断裂引发的级联故障三、解决方案:修复

C++ 检测文件大小和文件传输的方法示例详解

《C++检测文件大小和文件传输的方法示例详解》文章介绍了在C/C++中获取文件大小的三种方法,推荐使用stat()函数,并详细说明了如何设计一次性发送压缩包的结构体及传输流程,包含CRC校验和自动解... 目录检测文件的大小✅ 方法一:使用 stat() 函数(推荐)✅ 用法示例:✅ 方法二:使用 fsee

Java继承映射的三种使用方法示例

《Java继承映射的三种使用方法示例》继承在Java中扮演着重要的角色,它允许我们创建一个类(子类),该类继承另一个类(父类)的所有属性和方法,:本文主要介绍Java继承映射的三种使用方法示例,需... 目录前言一、单表继承(Single Table Inheritance)1-1、原理1-2、使用方法1-