本文主要介绍 DataWhale-(scikit-learn教程)-Task08(可视化总结)-202112,希望对大家解决编程问题具有一定的参考价值,有需要的开发者们可以跟随小编一起学习!
西瓜书代码实战
一、决策树可视化
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import tree
import graphviz
# Load the iris dataset bundled with scikit-learn.
data = load_iris()

# Put the feature matrix into a DataFrame with named columns.
df = pd.DataFrame(data.data, columns=data.feature_names)

# Build the class-id -> class-name mapping.
# data.target_names is indexed by class id, so index it directly. The previous
# np.unique(data.target_names) sorted the names alphabetically, which matches
# the class ids only by coincidence.
target = np.unique(data.target)
target_names = data.target_names[target]
targets = dict(zip(target, target_names))

# Add the species column, holding the class name instead of the numeric id.
df['Species'] = data.target
df['Species'] = df['Species'].map(targets)

# Separate the features from the label column.
X = df.drop(columns="Species")
y = df["Species"]
feature_names = X.columns
# Keep the class names in sorted order so they line up with the order the
# fitted classifier uses for its classes (y.unique() is order of appearance).
labels = np.array(sorted(y.unique()))

# Hold out 40% of the samples; the fixed seed keeps the split reproducible.
X_train, test_x, y_train, test_lab = train_test_split(
    X, y, test_size=0.4, random_state=42
)

# A shallow tree (depth 3) keeps the visualisations below readable.
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, y_train)
1. 文字表示
# Print the fitted tree as indented plain-text rules.
# The column names are passed (as a plain list) so each split shows the real
# feature name instead of a generic placeholder, matching the other plots.
text_representation = tree.export_text(model, feature_names=list(feature_names))
print(text_representation)
2. plot_tree函数
# Draw the fitted tree with matplotlib on a wide, green-backed figure.
plt.figure(figsize=(30, 10), facecolor='g')
# class_names must follow the ascending order of the fitted classes, which is
# what model.classes_ holds. Plain lists are passed because a pandas Index or
# NumPy array is not accepted for these arguments by every scikit-learn release.
_ = tree.plot_tree(
    model,
    feature_names=list(feature_names),
    class_names=[str(name) for name in model.classes_],
    rounded=True,
    filled=True,
    fontsize=14,
)
plt.show()
3. graphviz
# Export the fitted tree as DOT source; with out_file=None the text is
# returned instead of being written to a file.
dot_data = tree.export_graphviz(
    model,
    out_file=None,
    feature_names=data.feature_names,
    class_names=data.target_names,
    filled=True,
)

# Hand the DOT source to Graphviz and render it as a PNG image.
graph = graphviz.Source(dot_data, format="png")
graph.render('lense')
二、xgboost可视化
import xgboost
from xgboost import XGBClassifier
from sklearn.datasets import load_iris

# Load iris and fit an XGBoost classifier with default settings on all samples.
# (The import and the load_iris() call were fused onto one line, which is a
# syntax error; they are separate statements here.)
iris = load_iris()
x, y = iris.data, iris.target
model = XGBClassifier()
model.fit(x, y)
1. 特征重要性
# A model trained on a bare array has no column names, so XGBoost falls back
# to auto-generated ones (f0, f1, ...). Assign the real names for readable
# plots; the tree-structure plot uses the same names.
model.get_booster().feature_names = iris.feature_names
# max_num_features: show only the top-N ranked features.
# height (default 0.2): thickness of each bar in the chart.
# importance_type (default 'weight'): how often a feature is used to split;
# 'gain' and 'cover' are the alternatives.
xgboost.plot_importance(model, max_num_features=5)
# Display the figure explicitly so the chart also appears when this runs as a
# script, as the other plotting sections do.
plt.show()
2. 画树结构
# Build a Graphviz object for a single tree of the fitted booster.
# NOTE(review): num_trees=2 selects a tree by index; if indices start at 0 this
# is the third tree rather than "the 2nd tree" - confirm against the xgboost docs.
xgboost.to_graphviz(model, num_trees=2)
三、lgbm可视化
#LGB树展示
from sklearn.datasets import load_iris
from sklearn import tree
#import pydotplus
import graphviz
import os
import pandas as pd
import lightgbm as lgb
import matplotlib.pyplot as plt

# Load the data here so this section does not rely on a variable left over
# from an earlier section.
iris = load_iris()
model = lgb.LGBMClassifier()
model.fit(iris.data, iris.target)

# Export the tree with index 0 as a Graphviz digraph and render it to PDF.
# render() appends the format extension itself, so the bare stem is passed:
# 'lgb_iris_0.pdf' would leave the DOT source under that name and put the
# actual PDF in 'lgb_iris_0.pdf.pdf'.
dot_data = lgb.create_tree_digraph(model, tree_index=0)
dot_data.format = 'pdf'
dot_data.render('lgb_iris_0')

# Draw the tree with index 1 on a large matplotlib axes.
fig2 = plt.figure(figsize=(20, 20))
ax = fig2.subplots()
lgb.plot_tree(model, tree_index=1, ax=ax)
plt.show()
这篇关于 DataWhale-(scikit-learn教程)-Task08(可视化总结)-202112 的文章就介绍到这儿,希望我们推荐的文章对程序员们有所帮助!