ID3算法 决策树学习 Python实现

2024-01-30 14:36

本文主要是介绍ID3算法 决策树学习 Python实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

算法流程

输入:约束决策树生长参数(最大深度,节点最小样本数,可选),训练集(特征值离散或连续,标签离散)。
输出:决策树。
过程:每次选择信息增益最大的属性决策分类,直到当前节点样本均为同一类,或者信息增益过小。

信息增益

设样本需分为 K K K 类,当前节点待分类样本中每类样本的个数分别为 n 1 , n 2 , … , n K n_1, n_2, …, n_K n1,n2,,nK,则该节点信息熵为
I ( n 1 , n 2 , … , n K ) = − ∑ i = 1 K n i ∑ j = 1 K n j log ⁡ 2 n i ∑ j = 1 K n j I(n_1, n_2, …, n_K) = -\sum_{i=1}^K \frac{n_i}{\sum_{j=1}^K n_j} \log_2 \frac{n_i}{\sum_{j=1}^K n_j} I(n1,n2,,nK)=i=1Kj=1Knjnilog2j=1Knjni
设属性 A A A v v v 种取值,当前节点样本按属性 A A A 决策分类为 v v v 个子节点,第 i i i 个子节点待分类样本中每类样本的个数分别为 n i 1 , n i 2 , … , n i K n_{i1}, n_{i2}, …, n_{iK} ni1,ni2,,niK,则父节点按属性 A A A 决策分类的类信息熵为
E ( A ) = ∑ i = 1 v ∑ j = 1 K n i j ∑ j = 1 K n j I ( n i 1 , n i 2 , … , n i K ) E(A) = \sum_{i=1}^v \frac{\sum_{j=1}^K n_{ij}}{\sum_{j=1}^K n_j} I(n_{i1}, n_{i2}, …, n_{iK}) E(A)=i=1vj=1Knjj=1KnijI(ni1,ni2,,niK)
由此计算当前节点在属性 A 上的信息增益为
G a i n ( A ) = I ( n 1 , n 2 , … , n K ) − E ( A ) Gain(A) = I(n_1, n_2, …, n_K) - E(A) Gain(A)=I(n1,n2,,nK)E(A)

决策树学习过程中可能出现的问题与解决方法

不相关属性(irrelevant attribute),属性与类分布相独立。此情况下信息增益过小,可以终止决策,将当前节点标签设为最高频类。
不充足属性(inadequate attribute),不同类的样本有完全相同特征。此情况下信息增益为 0 0 0,可以终止决策,将当前节点标签设为最高频类。
未知属性值(unknown value),数据集中某些属性值不确定。可以通过预处理剔除含有未知属性值的样本或属性。
过拟合(overfitting),决策树泛化能力不足。可以约束决策树生长参数。
空分支(empty branch),学习过程中某节点样本数为 0 0 0 v ≥ 3 v≥3 v3 才会发生。可以将当前节点标签设为父节点的最高频类。

参考代码如下(仅能处理离散属性值状态):

import numpy as np
class ID3:def __init__(self, max_depth = 0, min_samples_split = 0):self.max_depth, self.min_samples_split = max_depth, min_samples_splitdef __EI(self, *n):n = np.array([i for i in n if i > 0])if n.shape[0] <= 1:return 0p = n / np.sum(n)return -np.dot(p, np.log2(p))def __Gain(self, A: np.ndarray):return self.__EI(*np.sum(A, axis = 0)) - np.average(np.frompyfunc(self.__EI, A.shape[1], 1)(*A.T), weights = np.sum(A, axis = 1))def fit(self, X: np.ndarray, y):self.DX, (self.Dy, yn) = [np.unique(X[:, i]) for i in range(X.shape[1])], np.unique(y, return_inverse = True)self.Dy: np.ndarrayself.value = []def fitcur(n, h, p = 0):self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))r: np.ndarray = np.unique(y[n])if r.shape[0] == 0: # Empty Branchreturn pelif r.shape[0] == 1:return yn[n[0]]elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Overfittingreturn np.argmax(np.bincount(yn[n]))else:P = [[n[np.where(X[n, i] == j)[0]] for j in self.DX[i]] for i in range(X.shape[1])]G = [self.__Gain(A) for A in [np.array([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p]) for p in P]]m = np.argmax(G)if(G[m] < 1e-9): # Inadequate attributereturn np.argmax(np.bincount(yn[n]))return (m,) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m])self.tree = fitcur(np.arange(X.shape[0]), 0)def predict(self, X):def precur(n, x):return precur(n[1 + np.where(self.DX[n[0]] == x[n[0]])[0][0]], x) if isinstance(n, tuple) else self.Dy[n]return np.array([precur(self.tree, x) for x in X])def visualize(self, header):i = iter(self.value)def visval():v = next(i)print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')def viscur(n, h, c):for i in h[:-1]:print('%c   ' % ('│' if i else ' '), end = '')if len(h) > 0:print('%c── ' % ('├' if h[-1] else '└'), end = '')print('[%s] ' % c, end = '')if isinstance(n, tuple):print(header[n[0]], end = '')visval()print()for i in range(len(n) - 1):viscur(n[i + 1], h + [i < len(n) - 2], str(self.DX[n[0]][i]))else:print(self.Dy[n], end = '')visval()print()viscur(self.tree, [], '')

连续属性值的离散化

对于某个连续属性,取训练集中所有属性值的相邻两点中点生成界点集,按每个界点将当前节点样本分为 2 2 2 类,算出界点集中最大信息增益的界点。

在上文代码的基础上加以改动,得到能处理连续属性值状态的代码如下:

import numpy as np
class ID3:def __init__(self, max_depth = 0, min_samples_split = 0):self.max_depth, self.min_samples_split = max_depth, min_samples_splitdef __EI(self, *n):n = np.array([i for i in n if i > 0])if n.shape[0] <= 1:return 0p = n / np.sum(n)return -np.dot(p, np.log2(p))def __Gain(self, A: np.ndarray):return self.__EI(*np.sum(A, axis = 0)) - np.average([self.__EI(*a) for a in A], weights = np.sum(A, axis = 1))def fit(self, X: np.ndarray, y):self.c = np.array([(np.all([isinstance(j, (int, float)) for j in i])) for i in X.T])self.DX, (self.Dy, yn) = [np.unique(X[:, i]) if not self.c[i] else None for i in range(X.shape[1])], np.unique(y, return_inverse = True)self.Dy: np.ndarrayself.value = []def Part(n, a):if self.c[a]:u = np.sort(np.unique(X[n, a]))if(u.shape[0] < 2):return Nonev = np.array([(u[i - 1] + u[i]) / 2 for i in range(1, u.shape[0])])P = [[n[np.where(X[n, a] < i)[0]], n[np.where(X[n, a] >= i)[0]]] for i in v]m = np.argmax([self.__Gain([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p]) for p in P])return v[m], P[m]else:return None, [n[np.where(X[n, a] == i)[0]] for i in self.DX[a]]def fitcur(n: np.ndarray, h, p = 0):self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))r: np.ndarray = np.unique(y[n])if r.shape[0] == 0: # Empty Branchreturn pelif r.shape[0] == 1:return yn[n[0]]elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Overfittingreturn np.argmax(np.bincount(yn[n]))else:P = [Part(n, i) for i in range(X.shape[1])]G = [self.__Gain([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p[1]]) if p != None else 0 for p in P]m = np.argmax(G)if(G[m] < 1e-9): # Inadequate attributereturn np.argmax(np.bincount(yn[n]))return ((m, P[m][0]) if self.c[m] else (m,)) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m][1])self.tree = fitcur(np.arange(X.shape[0]), 0)def predict(self, X):def precur(n, x):return precur(n[(2 if x[n[0]] < n[1] else 3) if self.c[n[0]] else (1 + np.where(self.DX[n[0]] == x[n[0]])[0][0])], x) if isinstance(n, tuple) else self.Dy[n]return np.array([precur(self.tree, x) for x in X])def visualize(self, header):i = iter(self.value)def visval():v = next(i)print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')def viscur(n, h, c):for i in h[:-1]:print('%c   ' % ('│' if i else ' '), end = '')if len(h) > 0:print('%c── ' % ('├' if h[-1] else '└'), end = '')print('[%s] ' % c, end = '')if isinstance(n, tuple):print(header[n[0]], end = '')visval()print()if self.c[n[0]]:for i in range(2):viscur(n[2 + i], h + [i < 1], ('< ', '>= ')[i] + str(n[1]))else:for i in range(len(n) - 1):viscur(n[1 + i], h + [i < len(n) - 2], str(self.DX[n[0]][i]))else:print(self.Dy[n], end = '')visval()print()viscur(self.tree, [], '')

实验测试

实验使用数据集如下:
Play tennis 数据集(来源:kaggle):离散属性
Mushroom classification 数据集(来源:kaggle):离散属性
Carsdata 数据集(来源:kaggle):连续属性
Iris 数据集(来源:sklearn.datasets):连续属性

其中 play_tennis.csv 内容如下:

dayoutlooktemphumiditywindplay
D1SunnyHotHighWeakNo
D2SunnyHotHighStrongNo
D3OvercastHotHighWeakYes
D4RainMildHighWeakYes
D5RainCoolNormalWeakYes
D6RainCoolNormalStrongNo
D7OvercastCoolNormalStrongYes
D8SunnyMildHighWeakNo
D9SunnyCoolNormalWeakYes
D10RainMildNormalWeakYes
D11SunnyMildNormalStrongYes
D12OvercastMildHighStrongYes
D13OvercastHotNormalWeakYes
D14RainMildHighStrongNo

Play tennis 数据集上的测试

默认属性二分类测试,代码如下:

import pandas as pd
class Datasets:def __init__(self, fn):self.df = pd.read_csv('Datasets\\%s' % fn).map(lambda x: x.strip() if isinstance(x, str) else x)self.df.rename(columns = lambda x: x.strip(), inplace = True)def getData(self, DX, Dy, drop = False):dfn = self.df.loc[~self.df.eq('').any(axis = 1)].apply(pd.to_numeric, errors = 'ignore') if drop else self.dfreturn dfn[DX].to_numpy(dtype = np.object_), dfn[Dy].to_numpy(dtype = np.object_)# play_tennis.csv
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
dt11 = ID3()
dt11.fit(X, y)
dt11.visualize(a)
print()

结果如下:

outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])├── [High] No (entropy = 0, samples = 3, value = [3 0])└── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])

不充足属性测试

更换属性三分类,不充足属性测试,代码如下:

# play_tennis.csv for inadequate attribute test and class > 2
a = ['temp', 'humidity', 'wind', 'play']
X, y = Datasets('play_tennis.csv').getData(a, 'outlook')
dt12 = ID3(10)
dt12.fit(X, y)
dt12.visualize(a)
print(dt12.predict([['Cool', 'Normal', 'Weak', 'Yes']]))
print()

结果如下:

play (entropy = 1.5774062828523454, samples = 14, value = [4 5 5])
├── [No] temp (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   ├── [Cool] Rain (entropy = 0, samples = 1, value = [0 1 0])
│   ├── [Hot] Sunny (entropy = 0, samples = 2, value = [0 0 2])
│   └── [Mild] wind (entropy = 1.0, samples = 2, value = [0 1 1])
│       ├── [Strong] Rain (entropy = 0, samples = 1, value = [0 1 0])
│       └── [Weak] Sunny (entropy = 0, samples = 1, value = [0 0 1])
└── [Yes] temp (entropy = 1.5304930567574826, samples = 9, value = [4 3 2])├── [Cool] wind (entropy = 1.584962500721156, samples = 3, value = [1 1 1])│   ├── [Strong] Overcast (entropy = 0, samples = 1, value = [1 0 0])│   └── [Weak] Rain (entropy = 1.0, samples = 2, value = [0 1 1])├── [Hot] Overcast (entropy = 0, samples = 2, value = [2 0 0])└── [Mild] wind (entropy = 1.5, samples = 4, value = [1 2 1])├── [Strong] humidity (entropy = 1.0, samples = 2, value = [1 0 1])│   ├── [High] Overcast (entropy = 0, samples = 1, value = [1 0 0])│   └── [Normal] Sunny (entropy = 0, samples = 1, value = [0 0 1])└── [Weak] Rain (entropy = 0, samples = 2, value = [0 2 0])
['Rain']

空分支测试

默认属性二分类,修改部分数据,空分支测试,代码如下:

# play_tennis.csv modified to generate empty branch
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
X[2, 2], X[13, 2] = 'Low', 'Low'
dt13 = ID3()
dt13.fit(X, y)
dt13.visualize(a)
print(dt13.predict([['Sunny', 'Hot', 'Low', 'Weak']]))
print()

结果如下:

outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])├── [High] No (entropy = 0, samples = 3, value = [3 0])├── [Low] No (entropy = 0, samples = 0, value = [0 0])└── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])
['No']

Mushroom classification 数据集上的测试

默认属性二分类,忽略有未知值的属性,划分训练集和测试集,代码如下:

# mushrooms.csv ignoring attribute 'stalk-root' with unknown value
a = ['cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor', 'gill-attachment', 'gill-spacing', 'gill-size','gill-color', 'stalk-shape', 'stalk-surface-above-ring', 'stalk-surface-below-ring', 'stalk-color-above-ring', 'stalk-color-below-ring', 'veil-type','veil-color', 'ring-number', 'ring-type', 'spore-print-color', 'population', 'habitat']
X, y = Datasets('mushrooms.csv').getData(a, 'class')
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt21 = ID3()
dt21.fit(X_train, y_train)
dt21.visualize(a)
print()
y_pred = dt21.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

odor (entropy = 0.9990161113058208, samples = 6093, value = [3159 2934])
├── [a] e (entropy = 0, samples = 298, value = [298   0])
├── [c] p (entropy = 0, samples = 138, value = [  0 138])
├── [f] p (entropy = 0, samples = 1636, value = [   0 1636])
├── [l] e (entropy = 0, samples = 297, value = [297   0])
├── [m] p (entropy = 0, samples = 26, value = [ 0 26])
├── [n] spore-print-color (entropy = 0.19751069442516636, samples = 2645, value = [2564   81])
│   ├── [b] e (entropy = 0, samples = 32, value = [32  0])
│   ├── [h] e (entropy = 0, samples = 35, value = [35  0])
│   ├── [k] e (entropy = 0, samples = 974, value = [974   0])
│   ├── [n] e (entropy = 0, samples = 1013, value = [1013    0])
│   ├── [o] e (entropy = 0, samples = 33, value = [33  0])
│   ├── [r] p (entropy = 0, samples = 50, value = [ 0 50])
│   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   ├── [w] habitat (entropy = 0.34905151737109524, samples = 473, value = [442  31])
│   │   ├── [d] gill-size (entropy = 0.7062740891876007, samples = 26, value = [ 5 21])
│   │   │   ├── [b] e (entropy = 0, samples = 5, value = [5 0])
│   │   │   └── [n] p (entropy = 0, samples = 21, value = [ 0 21])
│   │   ├── [g] e (entropy = 0, samples = 222, value = [222   0])
│   │   ├── [l] cap-color (entropy = 0.7553754125614287, samples = 46, value = [36 10])
│   │   │   ├── [b] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [c] e (entropy = 0, samples = 20, value = [20  0])
│   │   │   ├── [e] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [g] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [n] e (entropy = 0, samples = 16, value = [16  0])
│   │   │   ├── [p] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [r] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [w] p (entropy = 0, samples = 7, value = [0 7])
│   │   │   └── [y] p (entropy = 0, samples = 3, value = [0 3])
│   │   ├── [m] e (entropy = 0, samples = 0, value = [0 0])
│   │   ├── [p] e (entropy = 0, samples = 35, value = [35  0])
│   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   └── [w] e (entropy = 0, samples = 144, value = [144   0])
│   └── [y] e (entropy = 0, samples = 35, value = [35  0])
├── [p] p (entropy = 0, samples = 187, value = [  0 187])
├── [s] p (entropy = 0, samples = 433, value = [  0 433])
└── [y] p (entropy = 0, samples = 433, value = [  0 433])precision    recall  f1-score   supporte       1.00      1.00      1.00      1049p       1.00      1.00      1.00       982accuracy                           1.00      2031macro avg       1.00      1.00      1.00      2031
weighted avg       1.00      1.00      1.00      2031

Carsdata 数据集上的测试

默认属性三分类,忽略有未知值的样本,划分训练集和测试集,约束决策树生长最大深度为 5,节点最小样本数为 3,代码如下:

# cars.csv ignoring samples with unknown value
a = ['mpg', 'cylinders', 'cubicinches', 'hp', 'weightlbs', 'time-to-60', 'year']
X, y = Datasets('cars.csv').getData(a, 'brand', True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt31 = ID3(5, 3)
dt31.fit(X_train, y_train)
dt31.visualize(a)
print()
y_pred = dt31.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [< 191.0] year (entropy = 1.5833913647120852, samples = 105, value = [37 33 35])
│   ├── [< 1981.5] cubicinches (entropy = 1.5558899087683136, samples = 87, value = [37 27 23])
│   │   ├── [< 121.5] cubicinches (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   │   ├── [< 114.0] cubicinches (entropy = 1.4844331941390079, samples = 47, value = [18 21  8])
│   │   │   │   ├── [< 87.0] Japan. (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │   │   └── [>= 87.0] Europe. (entropy = 1.521560239117063, samples = 39, value = [17 14  8])
│   │   │   └── [>= 114.0] weightlbs (entropy = 0.5665095065529053, samples = 15, value = [13  2  0])
│   │   │       ├── [< 2571.0] Europe. (entropy = 0.9709505944546686, samples = 5, value = [3 2 0])
│   │   │       └── [>= 2571.0] Europe. (entropy = 0, samples = 10, value = [10  0  0])
│   │   └── [>= 121.5] weightlbs (entropy = 1.3593308322365363, samples = 25, value = [ 6  4 15])
│   │       ├── [< 3076.5] hp (entropy = 0.9917601481809735, samples = 20, value = [ 1  4 15])
│   │       │   ├── [< 92.5] US. (entropy = 0, samples = 11, value = [ 0  0 11])
│   │       │   └── [>= 92.5] Japan. (entropy = 1.3921472236645345, samples = 9, value = [1 4 4])
│   │       └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
│   └── [>= 1981.5] mpg (entropy = 0.9182958340544896, samples = 18, value = [ 0  6 12])
│       ├── [< 31.3] US. (entropy = 0, samples = 9, value = [0 0 9])
│       └── [>= 31.3] mpg (entropy = 0.9182958340544896, samples = 9, value = [0 6 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 4, value = [0 4 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
└── [>= 191.0] US. (entropy = 0, samples = 87, value = [ 0  0 87])precision    recall  f1-score   supportEurope.       0.50      0.80      0.62        10Japan.       0.83      0.56      0.67        18US.       0.94      0.94      0.94        36accuracy                           0.81        64macro avg       0.76      0.77      0.74        64
weighted avg       0.84      0.81      0.81        64

离散属性和连续属性混合分类测试

根据上文决策树节点划分结果对其中某个属性进行预离散化,相同方式划分训练集和测试集,约束决策树生长参数不变,离散属性和连续属性混合分类测试,代码如下:

# cars.csv with attribute 'cubicinches' discretized 
def find(a, n):def findcur(r):return findcur(r + 1) if r < len(a) and a[r] < n else rreturn findcur(0)
X[:, 2] = np.array([('a', 'b', 'c')[find([121.5, 191.0], i)] for i in X[:, 2]])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt32 = ID3(5, 3)
dt32.fit(X_train, y_train)
dt32.visualize(a)
y_pred = dt32.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [a] year (entropy = 1.5231103605784926, samples = 74, value = [31 28 15])
│   ├── [< 1981.5] weightlbs (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   ├── [< 2571.0] weightlbs (entropy = 1.4729350396193688, samples = 50, value = [20 22  8])
│   │   │   ├── [< 2271.5] mpg (entropy = 1.5038892873131435, samples = 42, value = [19 15  8])
│   │   │   │   ├── [< 30.25] Europe. (entropy = 1.2640886121123147, samples = 23, value = [15  5  3])
│   │   │   │   └── [>= 30.25] Japan. (entropy = 1.4674579648482995, samples = 19, value = [ 4 10  5])
│   │   │   └── [>= 2271.5] mpg (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │       ├── [< 37.9] Japan. (entropy = 0, samples = 7, value = [0 7 0])
│   │   │       └── [>= 37.9] Europe. (entropy = 0, samples = 1, value = [1 0 0])
│   │   └── [>= 2571.0] cylinders (entropy = 0.41381685030363374, samples = 12, value = [11  1  0])
│   │       ├── [< 3.5] Japan. (entropy = 0, samples = 1, value = [0 1 0])
│   │       └── [>= 3.5] Europe. (entropy = 0, samples = 11, value = [11  0  0])
│   └── [>= 1981.5] mpg (entropy = 0.9798687566511528, samples = 12, value = [0 5 7])
│       ├── [< 31.3] US. (entropy = 0, samples = 4, value = [0 0 4])
│       └── [>= 31.3] mpg (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
├── [b] weightlbs (entropy = 1.2910357498542626, samples = 31, value = [ 6  5 20])
│   ├── [< 3076.5] hp (entropy = 0.9293550115186283, samples = 26, value = [ 1  5 20])
│   │   ├── [< 93.5] US. (entropy = 0, samples = 16, value = [ 0  0 16])
│   │   └── [>= 93.5] time-to-60 (entropy = 1.360964047443681, samples = 10, value = [1 5 4])
│   │       ├── [< 15.5] cylinders (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│   │       │   ├── [< 5.0] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│   │       │   └── [>= 5.0] US. (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   │       └── [>= 15.5] Europe. (entropy = 1.0, samples = 2, value = [1 0 1])
│   └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
└── [c] US. (entropy = 0, samples = 87, value = [ 0  0 87])precision    recall  f1-score   supportEurope.       0.57      0.40      0.47        10Japan.       0.65      0.72      0.68        18US.       0.92      0.94      0.93        36accuracy                           0.80        64macro avg       0.71      0.69      0.70        64
weighted avg       0.79      0.80      0.79        64

Iris 数据集上的测试

默认属性三分类,划分训练集和测试集,限制决策树生长最大深度为 3,代码如下:

# iris dataset
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt41 = ID3(3)
dt41.fit(X_train, y_train)
dt41.visualize(iris['feature_names'])
print()
y_pred = dt41.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

petal length (cm) (entropy = 1.5807197138422102, samples = 112, value = [34 41 37])
├── [< 2.45] 0 (entropy = 0, samples = 34, value = [34  0  0])
└── [>= 2.45] petal width (cm) (entropy = 0.9981021327390103, samples = 78, value = [ 0 41 37])├── [< 1.75] petal length (cm) (entropy = 0.4394969869215134, samples = 44, value = [ 0 40  4])│   ├── [< 4.95] 1 (entropy = 0.17203694935311378, samples = 39, value = [ 0 38  1])│   └── [>= 4.95] 2 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])└── [>= 1.75] petal length (cm) (entropy = 0.19143325481419343, samples = 34, value = [ 0  1 33])├── [< 4.85] 2 (entropy = 0.9182958340544896, samples = 3, value = [0 1 2])└── [>= 4.85] 2 (entropy = 0, samples = 31, value = [ 0  0 31])precision    recall  f1-score   support0       1.00      1.00      1.00        161       1.00      1.00      1.00         92       1.00      1.00      1.00        13accuracy                           1.00        38macro avg       1.00      1.00      1.00        38
weighted avg       1.00      1.00      1.00        38

这篇关于ID3算法 决策树学习 Python实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/660550

相关文章

使用Python删除Excel中的行列和单元格示例详解

《使用Python删除Excel中的行列和单元格示例详解》在处理Excel数据时,删除不需要的行、列或单元格是一项常见且必要的操作,本文将使用Python脚本实现对Excel表格的高效自动化处理,感兴... 目录开发环境准备使用 python 删除 Excphpel 表格中的行删除特定行删除空白行删除含指定

Linux下删除乱码文件和目录的实现方式

《Linux下删除乱码文件和目录的实现方式》:本文主要介绍Linux下删除乱码文件和目录的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux下删除乱码文件和目录方法1方法2总结Linux下删除乱码文件和目录方法1使用ls -i命令找到文件或目录

SpringBoot+EasyExcel实现自定义复杂样式导入导出

《SpringBoot+EasyExcel实现自定义复杂样式导入导出》这篇文章主要为大家详细介绍了SpringBoot如何结果EasyExcel实现自定义复杂样式导入导出功能,文中的示例代码讲解详细,... 目录安装处理自定义导出复杂场景1、列不固定,动态列2、动态下拉3、自定义锁定行/列,添加密码4、合并

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Spring Boot集成Druid实现数据源管理与监控的详细步骤

《SpringBoot集成Druid实现数据源管理与监控的详细步骤》本文介绍如何在SpringBoot项目中集成Druid数据库连接池,包括环境搭建、Maven依赖配置、SpringBoot配置文件... 目录1. 引言1.1 环境准备1.2 Druid介绍2. 配置Druid连接池3. 查看Druid监控

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

Linux在线解压jar包的实现方式

《Linux在线解压jar包的实现方式》:本文主要介绍Linux在线解压jar包的实现方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录linux在线解压jar包解压 jar包的步骤总结Linux在线解压jar包在 Centos 中解压 jar 包可以使用 u

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

c++ 类成员变量默认初始值的实现

《c++类成员变量默认初始值的实现》本文主要介绍了c++类成员变量默认初始值,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录C++类成员变量初始化c++类的变量的初始化在C++中,如果使用类成员变量时未给定其初始值,那么它将被

Python包管理工具pip的升级指南

《Python包管理工具pip的升级指南》本文全面探讨Python包管理工具pip的升级策略,从基础升级方法到高级技巧,涵盖不同操作系统环境下的最佳实践,我们将深入分析pip的工作原理,介绍多种升级方... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核