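This article shows how to tune a scikit-learn decision tree with GridSearchCV. Strictly speaking, scikit-learn's DecisionTreeClassifier implements an optimized version of CART rather than ID3 itself. Setting criterion='entropy' makes it choose splits by information gain, the measure ID3 uses.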
Environment: Python 3, scikit-learn 0.18
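Conceptually, GridSearchCV enumerates every combination in the parameter grid, scores each one with k-fold cross-validation, and keeps the combination with the highest mean score. The sketch below reproduces that loop by hand with ParameterGrid and cross_val_score and compares the result with GridSearchCV. It makes two assumptions for illustration only: the small bundled load_digits dataset (8x8 digit images) stands in for MNIST, and cv=5 is used to keep the run short. It should run on scikit-learn 0.18 and later.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, ParameterGrid, cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Small bundled dataset used as a stand-in for MNIST (assumption for this demo).
X, y = load_digits(return_X_y=True)
param_grid = {'criterion': ['gini', 'entropy'], 'splitter': ['best', 'random']}

# Manual grid search: one mean cross-validated accuracy per parameter combination.
manual_scores = {}
for params in ParameterGrid(param_grid):
    model = DecisionTreeClassifier(random_state=0, **params)
    scores = cross_val_score(model, X, y, cv=5, scoring='accuracy')
    manual_scores[(params['criterion'], params['splitter'])] = scores.mean()

best_manual = max(manual_scores, key=manual_scores.get)

# GridSearchCV runs the same loop, then refits the winner on all of X, y.
grid = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid,
                    cv=5, scoring='accuracy')
grid.fit(X, y)

print('manual best:', best_manual, manual_scores[best_manual])
print('GridSearchCV best:', grid.best_params_, grid.best_score_)
```

Both approaches should select the same setting. In older releases such as 0.18, best_score_ weighted folds by their size by default (iid=True), so it can differ slightly from the plain mean computed by hand.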
#coding:utf-8
"""
python 3
scikit-learn 0.18
"""
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score,confusion_matrix,classification_report
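import input_data  # input_data.py from the TensorFlow 1.x MNIST tutorial, placed next to this script
import numpy as np

# Read MNIST from mnist/ (downloading it if needed); one_hot=False keeps integer labels 0-9
mnist = input_data.read_data_sets('mnist/',one_hot=False)
x = mnist.train.images
y = mnist.train.labels

# Hold out 20% of the training images as a validation set
train_data,validation_data,train_labels,validation_labels = train_test_split(x,y,test_size=0.2)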
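# Use GridSearchCV to find the best parameters
dtree = DecisionTreeClassifier(random_state=0)
# 'gini': split internal nodes according to the Gini impurity (purity) of the resulting children
# 'entropy': judge each split by its information gain (reduction in entropy)
criterion_options = ['gini','entropy']
# 'best': choose the best split at each node; 'random': choose the best of randomly drawn splits
splitter_options = ['best','random']
# 2 criteria x 2 splitters = 4 candidate settings
param_griddtree = dict(criterion=criterion_options,splitter=splitter_options)
# Score each setting with 10-fold cross-validation on the training split; verbose=1 prints progress
griddtree = GridSearchCV(dtree,param_griddtree,cv=10,scoring='accuracy',verbose=1)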
griddtree.fit(train_data,train_labels)
print('best score is:',str(griddtree.best_score_))
print('best params are :',str(griddtree.best_params_))
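The comments in the script contrast the two values of criterion. The following minimal pure-Python sketch computes both impurity measures and the impurity decrease of a split. With entropy, that decrease is the information gain ID3 maximizes. The helper names (class_proportions, gini, entropy, impurity_decrease) are hypothetical and are not scikit-learn functions.

```python
from collections import Counter
from math import log2


def class_proportions(labels):
    """Fraction of samples belonging to each class present in `labels`."""
    n = len(labels)
    return [count / n for count in Counter(labels).values()]


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 (0.0 for a pure node)."""
    return 1.0 - sum(p * p for p in class_proportions(labels))


def entropy(labels):
    """Shannon entropy in bits: -sum_k p_k * log2(p_k) (0.0 for a pure node)."""
    return sum(-p * log2(p) for p in class_proportions(labels))


def impurity_decrease(parent, left, right, measure):
    """Parent impurity minus the size-weighted impurity of the two children.

    With measure=entropy this is the information gain used by ID3.
    """
    n = len(parent)
    children = len(left) / n * measure(left) + len(right) / n * measure(right)
    return measure(parent) - children


# Toy split: a balanced two-class node divided into two 3:1 children.
parent = [0, 0, 0, 0, 1, 1, 1, 1]
left, right = [0, 0, 0, 1], [0, 1, 1, 1]

print(gini(parent), entropy(parent))                               # 0.5 1.0
print(impurity_decrease(parent, left, right, gini))                # 0.125
print(round(impurity_decrease(parent, left, right, entropy), 4))   # 0.1887
```

A tree grower evaluates this decrease for every candidate split and keeps the largest one. Changing criterion only changes which impurity function is plugged in.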
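Results
Running the script on the MNIST training split took about 4.5 minutes to find the best parameters. The grid has 2 × 2 = 4 candidate settings, so with cv=10 GridSearchCV trains 40 trees during the search. It then performs one final refit of the best setting on the whole training split, because refit=True is the default.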
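The script imports accuracy_score, confusion_matrix and classification_report but never uses them. It also never scores the held-out validation split. The following sketch shows one way to do that with the refit best_estimator_. It assumes load_digits as a quick stand-in for MNIST and random_state=0 for a reproducible split; neither comes from the original script.

```python
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# load_digits stands in for the MNIST loader used in the article (assumption).
X, y = load_digits(return_X_y=True)
train_data, validation_data, train_labels, validation_labels = train_test_split(
    X, y, test_size=0.2, random_state=0)

param_grid = {'criterion': ['gini', 'entropy'], 'splitter': ['best', 'random']}
grid = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid,
                    cv=10, scoring='accuracy')
grid.fit(train_data, train_labels)

# One entry per parameter combination that was cross-validated.
n_candidates = len(grid.cv_results_['params'])
print('cross-validation fits:', n_candidates * 10)  # 40

# With refit=True (the default), best_estimator_ is already retrained on all of train_data.
predictions = grid.best_estimator_.predict(validation_data)
val_accuracy = accuracy_score(validation_labels, predictions)
cm = confusion_matrix(validation_labels, predictions)

print('best params:', grid.best_params_)
print('validation accuracy:', val_accuracy)
print(cm)
print(classification_report(validation_labels, predictions))
```

Scoring on data that played no part in the search gives a less optimistic estimate than best_score_. best_score_ is the same cross-validation score that was used to pick the winner.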