ID3算法 决策树学习 Python实现

2024-01-30 14:36

本文主要是介绍ID3算法 决策树学习 Python实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

算法流程

输入:约束决策树生长参数(最大深度,节点最小样本数,可选),训练集(特征值离散或连续,标签离散)。
输出:决策树。
过程:每次选择信息增益最大的属性决策分类,直到当前节点样本均为同一类,或者信息增益过小。

信息增益

设样本需分为 K K K 类,当前节点待分类样本中每类样本的个数分别为 n 1 , n 2 , … , n K n_1, n_2, …, n_K n1,n2,,nK,则该节点信息熵为
I ( n 1 , n 2 , … , n K ) = − ∑ i = 1 K n i ∑ j = 1 K n j log ⁡ 2 n i ∑ j = 1 K n j I(n_1, n_2, …, n_K) = -\sum_{i=1}^K \frac{n_i}{\sum_{j=1}^K n_j} \log_2 \frac{n_i}{\sum_{j=1}^K n_j} I(n1,n2,,nK)=i=1Kj=1Knjnilog2j=1Knjni
设属性 A A A v v v 种取值,当前节点样本按属性 A A A 决策分类为 v v v 个子节点,第 i i i 个子节点待分类样本中每类样本的个数分别为 n i 1 , n i 2 , … , n i K n_{i1}, n_{i2}, …, n_{iK} ni1,ni2,,niK,则父节点按属性 A A A 决策分类的类信息熵为
E ( A ) = ∑ i = 1 v ∑ j = 1 K n i j ∑ j = 1 K n j I ( n i 1 , n i 2 , … , n i K ) E(A) = \sum_{i=1}^v \frac{\sum_{j=1}^K n_{ij}}{\sum_{j=1}^K n_j} I(n_{i1}, n_{i2}, …, n_{iK}) E(A)=i=1vj=1Knjj=1KnijI(ni1,ni2,,niK)
由此计算当前节点在属性 A 上的信息增益为
G a i n ( A ) = I ( n 1 , n 2 , … , n K ) − E ( A ) Gain(A) = I(n_1, n_2, …, n_K) - E(A) Gain(A)=I(n1,n2,,nK)E(A)

决策树学习过程中可能出现的问题与解决方法

不相关属性(irrelevant attribute),属性与类分布相独立。此情况下信息增益过小,可以终止决策,将当前节点标签设为最高频类。
不充足属性(inadequate attribute),不同类的样本有完全相同特征。此情况下信息增益为 0 0 0,可以终止决策,将当前节点标签设为最高频类。
未知属性值(unknown value),数据集中某些属性值不确定。可以通过预处理剔除含有未知属性值的样本或属性。
过拟合(overfitting),决策树泛化能力不足。可以约束决策树生长参数。
空分支(empty branch),学习过程中某节点样本数为 0 0 0 v ≥ 3 v≥3 v3 才会发生。可以将当前节点标签设为父节点的最高频类。

参考代码如下(仅能处理离散属性值状态):

import numpy as np
class ID3:def __init__(self, max_depth = 0, min_samples_split = 0):self.max_depth, self.min_samples_split = max_depth, min_samples_splitdef __EI(self, *n):n = np.array([i for i in n if i > 0])if n.shape[0] <= 1:return 0p = n / np.sum(n)return -np.dot(p, np.log2(p))def __Gain(self, A: np.ndarray):return self.__EI(*np.sum(A, axis = 0)) - np.average(np.frompyfunc(self.__EI, A.shape[1], 1)(*A.T), weights = np.sum(A, axis = 1))def fit(self, X: np.ndarray, y):self.DX, (self.Dy, yn) = [np.unique(X[:, i]) for i in range(X.shape[1])], np.unique(y, return_inverse = True)self.Dy: np.ndarrayself.value = []def fitcur(n, h, p = 0):self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))r: np.ndarray = np.unique(y[n])if r.shape[0] == 0: # Empty Branchreturn pelif r.shape[0] == 1:return yn[n[0]]elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Overfittingreturn np.argmax(np.bincount(yn[n]))else:P = [[n[np.where(X[n, i] == j)[0]] for j in self.DX[i]] for i in range(X.shape[1])]G = [self.__Gain(A) for A in [np.array([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p]) for p in P]]m = np.argmax(G)if(G[m] < 1e-9): # Inadequate attributereturn np.argmax(np.bincount(yn[n]))return (m,) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m])self.tree = fitcur(np.arange(X.shape[0]), 0)def predict(self, X):def precur(n, x):return precur(n[1 + np.where(self.DX[n[0]] == x[n[0]])[0][0]], x) if isinstance(n, tuple) else self.Dy[n]return np.array([precur(self.tree, x) for x in X])def visualize(self, header):i = iter(self.value)def visval():v = next(i)print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')def viscur(n, h, c):for i in h[:-1]:print('%c   ' % ('│' if i else ' '), end = '')if len(h) > 0:print('%c── ' % ('├' if h[-1] else '└'), end = '')print('[%s] ' % c, end = '')if isinstance(n, tuple):print(header[n[0]], end = '')visval()print()for i in range(len(n) - 1):viscur(n[i + 1], h + [i < len(n) - 2], str(self.DX[n[0]][i]))else:print(self.Dy[n], end = '')visval()print()viscur(self.tree, [], '')

连续属性值的离散化

对于某个连续属性,取训练集中所有属性值的相邻两点中点生成界点集,按每个界点将当前节点样本分为 2 2 2 类,算出界点集中最大信息增益的界点。

在上文代码的基础上加以改动,得到能处理连续属性值状态的代码如下:

import numpy as np
class ID3:def __init__(self, max_depth = 0, min_samples_split = 0):self.max_depth, self.min_samples_split = max_depth, min_samples_splitdef __EI(self, *n):n = np.array([i for i in n if i > 0])if n.shape[0] <= 1:return 0p = n / np.sum(n)return -np.dot(p, np.log2(p))def __Gain(self, A: np.ndarray):return self.__EI(*np.sum(A, axis = 0)) - np.average([self.__EI(*a) for a in A], weights = np.sum(A, axis = 1))def fit(self, X: np.ndarray, y):self.c = np.array([(np.all([isinstance(j, (int, float)) for j in i])) for i in X.T])self.DX, (self.Dy, yn) = [np.unique(X[:, i]) if not self.c[i] else None for i in range(X.shape[1])], np.unique(y, return_inverse = True)self.Dy: np.ndarrayself.value = []def Part(n, a):if self.c[a]:u = np.sort(np.unique(X[n, a]))if(u.shape[0] < 2):return Nonev = np.array([(u[i - 1] + u[i]) / 2 for i in range(1, u.shape[0])])P = [[n[np.where(X[n, a] < i)[0]], n[np.where(X[n, a] >= i)[0]]] for i in v]m = np.argmax([self.__Gain([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p]) for p in P])return v[m], P[m]else:return None, [n[np.where(X[n, a] == i)[0]] for i in self.DX[a]]def fitcur(n: np.ndarray, h, p = 0):self.value.append(np.bincount(yn[n], minlength = self.Dy.shape[0]))r: np.ndarray = np.unique(y[n])if r.shape[0] == 0: # Empty Branchreturn pelif r.shape[0] == 1:return yn[n[0]]elif self.max_depth > 0 and h >= self.max_depth or n.shape[0] <= self.min_samples_split: # Overfittingreturn np.argmax(np.bincount(yn[n]))else:P = [Part(n, i) for i in range(X.shape[1])]G = [self.__Gain([[np.where(y[i] == j)[0].shape[0] for j in self.Dy] for i in p[1]]) if p != None else 0 for p in P]m = np.argmax(G)if(G[m] < 1e-9): # Inadequate attributereturn np.argmax(np.bincount(yn[n]))return ((m, P[m][0]) if self.c[m] else (m,)) + tuple(fitcur(i, h + 1, np.argmax(np.bincount(yn[n]))) for i in P[m][1])self.tree = fitcur(np.arange(X.shape[0]), 0)def predict(self, X):def precur(n, x):return precur(n[(2 if x[n[0]] < n[1] else 3) if self.c[n[0]] else (1 + np.where(self.DX[n[0]] == x[n[0]])[0][0])], x) if isinstance(n, tuple) else self.Dy[n]return np.array([precur(self.tree, x) for x in X])def visualize(self, header):i = iter(self.value)def visval():v = next(i)print(' (entropy = {}, samples = {}, value = {})'.format(self.__EI(*v), np.sum(v), v), end = '')def viscur(n, h, c):for i in h[:-1]:print('%c   ' % ('│' if i else ' '), end = '')if len(h) > 0:print('%c── ' % ('├' if h[-1] else '└'), end = '')print('[%s] ' % c, end = '')if isinstance(n, tuple):print(header[n[0]], end = '')visval()print()if self.c[n[0]]:for i in range(2):viscur(n[2 + i], h + [i < 1], ('< ', '>= ')[i] + str(n[1]))else:for i in range(len(n) - 1):viscur(n[1 + i], h + [i < len(n) - 2], str(self.DX[n[0]][i]))else:print(self.Dy[n], end = '')visval()print()viscur(self.tree, [], '')

实验测试

实验使用数据集如下:
Play tennis 数据集(来源:kaggle):离散属性
Mushroom classification 数据集(来源:kaggle):离散属性
Carsdata 数据集(来源:kaggle):连续属性
Iris 数据集(来源:sklearn.datasets):连续属性

其中 play_tennis.csv 内容如下:

dayoutlooktemphumiditywindplay
D1SunnyHotHighWeakNo
D2SunnyHotHighStrongNo
D3OvercastHotHighWeakYes
D4RainMildHighWeakYes
D5RainCoolNormalWeakYes
D6RainCoolNormalStrongNo
D7OvercastCoolNormalStrongYes
D8SunnyMildHighWeakNo
D9SunnyCoolNormalWeakYes
D10RainMildNormalWeakYes
D11SunnyMildNormalStrongYes
D12OvercastMildHighStrongYes
D13OvercastHotNormalWeakYes
D14RainMildHighStrongNo

Play tennis 数据集上的测试

默认属性二分类测试,代码如下:

import pandas as pd
class Datasets:def __init__(self, fn):self.df = pd.read_csv('Datasets\\%s' % fn).map(lambda x: x.strip() if isinstance(x, str) else x)self.df.rename(columns = lambda x: x.strip(), inplace = True)def getData(self, DX, Dy, drop = False):dfn = self.df.loc[~self.df.eq('').any(axis = 1)].apply(pd.to_numeric, errors = 'ignore') if drop else self.dfreturn dfn[DX].to_numpy(dtype = np.object_), dfn[Dy].to_numpy(dtype = np.object_)# play_tennis.csv
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
dt11 = ID3()
dt11.fit(X, y)
dt11.visualize(a)
print()

结果如下:

outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])├── [High] No (entropy = 0, samples = 3, value = [3 0])└── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])

不充足属性测试

更换属性三分类,不充足属性测试,代码如下:

# play_tennis.csv for inadequate attribute test and class > 2
a = ['temp', 'humidity', 'wind', 'play']
X, y = Datasets('play_tennis.csv').getData(a, 'outlook')
dt12 = ID3(10)
dt12.fit(X, y)
dt12.visualize(a)
print(dt12.predict([['Cool', 'Normal', 'Weak', 'Yes']]))
print()

结果如下:

play (entropy = 1.5774062828523454, samples = 14, value = [4 5 5])
├── [No] temp (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   ├── [Cool] Rain (entropy = 0, samples = 1, value = [0 1 0])
│   ├── [Hot] Sunny (entropy = 0, samples = 2, value = [0 0 2])
│   └── [Mild] wind (entropy = 1.0, samples = 2, value = [0 1 1])
│       ├── [Strong] Rain (entropy = 0, samples = 1, value = [0 1 0])
│       └── [Weak] Sunny (entropy = 0, samples = 1, value = [0 0 1])
└── [Yes] temp (entropy = 1.5304930567574826, samples = 9, value = [4 3 2])├── [Cool] wind (entropy = 1.584962500721156, samples = 3, value = [1 1 1])│   ├── [Strong] Overcast (entropy = 0, samples = 1, value = [1 0 0])│   └── [Weak] Rain (entropy = 1.0, samples = 2, value = [0 1 1])├── [Hot] Overcast (entropy = 0, samples = 2, value = [2 0 0])└── [Mild] wind (entropy = 1.5, samples = 4, value = [1 2 1])├── [Strong] humidity (entropy = 1.0, samples = 2, value = [1 0 1])│   ├── [High] Overcast (entropy = 0, samples = 1, value = [1 0 0])│   └── [Normal] Sunny (entropy = 0, samples = 1, value = [0 0 1])└── [Weak] Rain (entropy = 0, samples = 2, value = [0 2 0])
['Rain']

空分支测试

默认属性二分类,修改部分数据,空分支测试,代码如下:

# play_tennis.csv modified to generate empty branch
a = ['outlook', 'temp', 'humidity', 'wind']
X, y = Datasets('play_tennis.csv').getData(a, 'play')
X[2, 2], X[13, 2] = 'Low', 'Low'
dt13 = ID3()
dt13.fit(X, y)
dt13.visualize(a)
print(dt13.predict([['Sunny', 'Hot', 'Low', 'Weak']]))
print()

结果如下:

outlook (entropy = 0.9402859586706311, samples = 14, value = [5 9])
├── [Overcast] Yes (entropy = 0, samples = 4, value = [0 4])
├── [Rain] wind (entropy = 0.9709505944546686, samples = 5, value = [2 3])
│   ├── [Strong] No (entropy = 0, samples = 2, value = [2 0])
│   └── [Weak] Yes (entropy = 0, samples = 3, value = [0 3])
└── [Sunny] humidity (entropy = 0.9709505944546686, samples = 5, value = [3 2])├── [High] No (entropy = 0, samples = 3, value = [3 0])├── [Low] No (entropy = 0, samples = 0, value = [0 0])└── [Normal] Yes (entropy = 0, samples = 2, value = [0 2])
['No']

Mushroom classification 数据集上的测试

默认属性二分类,忽略有未知值的属性,划分训练集和测试集,代码如下:

# mushrooms.csv ignoring attribute 'stalk-root' with unknown value
a = ['cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor', 'gill-attachment', 'gill-spacing', 'gill-size','gill-color', 'stalk-shape', 'stalk-surface-above-ring', 'stalk-surface-below-ring', 'stalk-color-above-ring', 'stalk-color-below-ring', 'veil-type','veil-color', 'ring-number', 'ring-type', 'spore-print-color', 'population', 'habitat']
X, y = Datasets('mushrooms.csv').getData(a, 'class')
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt21 = ID3()
dt21.fit(X_train, y_train)
dt21.visualize(a)
print()
y_pred = dt21.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

odor (entropy = 0.9990161113058208, samples = 6093, value = [3159 2934])
├── [a] e (entropy = 0, samples = 298, value = [298   0])
├── [c] p (entropy = 0, samples = 138, value = [  0 138])
├── [f] p (entropy = 0, samples = 1636, value = [   0 1636])
├── [l] e (entropy = 0, samples = 297, value = [297   0])
├── [m] p (entropy = 0, samples = 26, value = [ 0 26])
├── [n] spore-print-color (entropy = 0.19751069442516636, samples = 2645, value = [2564   81])
│   ├── [b] e (entropy = 0, samples = 32, value = [32  0])
│   ├── [h] e (entropy = 0, samples = 35, value = [35  0])
│   ├── [k] e (entropy = 0, samples = 974, value = [974   0])
│   ├── [n] e (entropy = 0, samples = 1013, value = [1013    0])
│   ├── [o] e (entropy = 0, samples = 33, value = [33  0])
│   ├── [r] p (entropy = 0, samples = 50, value = [ 0 50])
│   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   ├── [w] habitat (entropy = 0.34905151737109524, samples = 473, value = [442  31])
│   │   ├── [d] gill-size (entropy = 0.7062740891876007, samples = 26, value = [ 5 21])
│   │   │   ├── [b] e (entropy = 0, samples = 5, value = [5 0])
│   │   │   └── [n] p (entropy = 0, samples = 21, value = [ 0 21])
│   │   ├── [g] e (entropy = 0, samples = 222, value = [222   0])
│   │   ├── [l] cap-color (entropy = 0.7553754125614287, samples = 46, value = [36 10])
│   │   │   ├── [b] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [c] e (entropy = 0, samples = 20, value = [20  0])
│   │   │   ├── [e] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [g] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [n] e (entropy = 0, samples = 16, value = [16  0])
│   │   │   ├── [p] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [r] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   │   ├── [w] p (entropy = 0, samples = 7, value = [0 7])
│   │   │   └── [y] p (entropy = 0, samples = 3, value = [0 3])
│   │   ├── [m] e (entropy = 0, samples = 0, value = [0 0])
│   │   ├── [p] e (entropy = 0, samples = 35, value = [35  0])
│   │   ├── [u] e (entropy = 0, samples = 0, value = [0 0])
│   │   └── [w] e (entropy = 0, samples = 144, value = [144   0])
│   └── [y] e (entropy = 0, samples = 35, value = [35  0])
├── [p] p (entropy = 0, samples = 187, value = [  0 187])
├── [s] p (entropy = 0, samples = 433, value = [  0 433])
└── [y] p (entropy = 0, samples = 433, value = [  0 433])precision    recall  f1-score   supporte       1.00      1.00      1.00      1049p       1.00      1.00      1.00       982accuracy                           1.00      2031macro avg       1.00      1.00      1.00      2031
weighted avg       1.00      1.00      1.00      2031

Carsdata 数据集上的测试

默认属性三分类,忽略有未知值的样本,划分训练集和测试集,约束决策树生长最大深度为 5,节点最小样本数为 3,代码如下:

# cars.csv ignoring samples with unknown value
a = ['mpg', 'cylinders', 'cubicinches', 'hp', 'weightlbs', 'time-to-60', 'year']
X, y = Datasets('cars.csv').getData(a, 'brand', True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt31 = ID3(5, 3)
dt31.fit(X_train, y_train)
dt31.visualize(a)
print()
y_pred = dt31.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [< 191.0] year (entropy = 1.5833913647120852, samples = 105, value = [37 33 35])
│   ├── [< 1981.5] cubicinches (entropy = 1.5558899087683136, samples = 87, value = [37 27 23])
│   │   ├── [< 121.5] cubicinches (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   │   ├── [< 114.0] cubicinches (entropy = 1.4844331941390079, samples = 47, value = [18 21  8])
│   │   │   │   ├── [< 87.0] Japan. (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │   │   └── [>= 87.0] Europe. (entropy = 1.521560239117063, samples = 39, value = [17 14  8])
│   │   │   └── [>= 114.0] weightlbs (entropy = 0.5665095065529053, samples = 15, value = [13  2  0])
│   │   │       ├── [< 2571.0] Europe. (entropy = 0.9709505944546686, samples = 5, value = [3 2 0])
│   │   │       └── [>= 2571.0] Europe. (entropy = 0, samples = 10, value = [10  0  0])
│   │   └── [>= 121.5] weightlbs (entropy = 1.3593308322365363, samples = 25, value = [ 6  4 15])
│   │       ├── [< 3076.5] hp (entropy = 0.9917601481809735, samples = 20, value = [ 1  4 15])
│   │       │   ├── [< 92.5] US. (entropy = 0, samples = 11, value = [ 0  0 11])
│   │       │   └── [>= 92.5] Japan. (entropy = 1.3921472236645345, samples = 9, value = [1 4 4])
│   │       └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
│   └── [>= 1981.5] mpg (entropy = 0.9182958340544896, samples = 18, value = [ 0  6 12])
│       ├── [< 31.3] US. (entropy = 0, samples = 9, value = [0 0 9])
│       └── [>= 31.3] mpg (entropy = 0.9182958340544896, samples = 9, value = [0 6 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 4, value = [0 4 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
└── [>= 191.0] US. (entropy = 0, samples = 87, value = [ 0  0 87])precision    recall  f1-score   supportEurope.       0.50      0.80      0.62        10Japan.       0.83      0.56      0.67        18US.       0.94      0.94      0.94        36accuracy                           0.81        64macro avg       0.76      0.77      0.74        64
weighted avg       0.84      0.81      0.81        64

离散属性和连续属性混合分类测试

根据上文决策树节点划分结果对其中某个属性进行预离散化,相同方式划分训练集和测试集,约束决策树生长参数不变,离散属性和连续属性混合分类测试,代码如下:

# cars.csv with attribute 'cubicinches' discretized 
def find(a, n):def findcur(r):return findcur(r + 1) if r < len(a) and a[r] < n else rreturn findcur(0)
X[:, 2] = np.array([('a', 'b', 'c')[find([121.5, 191.0], i)] for i in X[:, 2]])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt32 = ID3(5, 3)
dt32.fit(X_train, y_train)
dt32.visualize(a)
y_pred = dt32.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

cubicinches (entropy = 1.3101461692119258, samples = 192, value = [ 37  33 122])
├── [a] year (entropy = 1.5231103605784926, samples = 74, value = [31 28 15])
│   ├── [< 1981.5] weightlbs (entropy = 1.4119058166561587, samples = 62, value = [31 23  8])
│   │   ├── [< 2571.0] weightlbs (entropy = 1.4729350396193688, samples = 50, value = [20 22  8])
│   │   │   ├── [< 2271.5] mpg (entropy = 1.5038892873131435, samples = 42, value = [19 15  8])
│   │   │   │   ├── [< 30.25] Europe. (entropy = 1.2640886121123147, samples = 23, value = [15  5  3])
│   │   │   │   └── [>= 30.25] Japan. (entropy = 1.4674579648482995, samples = 19, value = [ 4 10  5])
│   │   │   └── [>= 2271.5] mpg (entropy = 0.5435644431995964, samples = 8, value = [1 7 0])
│   │   │       ├── [< 37.9] Japan. (entropy = 0, samples = 7, value = [0 7 0])
│   │   │       └── [>= 37.9] Europe. (entropy = 0, samples = 1, value = [1 0 0])
│   │   └── [>= 2571.0] cylinders (entropy = 0.41381685030363374, samples = 12, value = [11  1  0])
│   │       ├── [< 3.5] Japan. (entropy = 0, samples = 1, value = [0 1 0])
│   │       └── [>= 3.5] Europe. (entropy = 0, samples = 11, value = [11  0  0])
│   └── [>= 1981.5] mpg (entropy = 0.9798687566511528, samples = 12, value = [0 5 7])
│       ├── [< 31.3] US. (entropy = 0, samples = 4, value = [0 0 4])
│       └── [>= 31.3] mpg (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│           ├── [< 33.2] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│           └── [>= 33.2] time-to-60 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│               ├── [< 16.5] US. (entropy = 0, samples = 2, value = [0 0 2])
│               └── [>= 16.5] Japan. (entropy = 0.9182958340544896, samples = 3, value = [0 2 1])
├── [b] weightlbs (entropy = 1.2910357498542626, samples = 31, value = [ 6  5 20])
│   ├── [< 3076.5] hp (entropy = 0.9293550115186283, samples = 26, value = [ 1  5 20])
│   │   ├── [< 93.5] US. (entropy = 0, samples = 16, value = [ 0  0 16])
│   │   └── [>= 93.5] time-to-60 (entropy = 1.360964047443681, samples = 10, value = [1 5 4])
│   │       ├── [< 15.5] cylinders (entropy = 0.954434002924965, samples = 8, value = [0 5 3])
│   │       │   ├── [< 5.0] Japan. (entropy = 0, samples = 3, value = [0 3 0])
│   │       │   └── [>= 5.0] US. (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])
│   │       └── [>= 15.5] Europe. (entropy = 1.0, samples = 2, value = [1 0 1])
│   └── [>= 3076.5] Europe. (entropy = 0, samples = 5, value = [5 0 0])
└── [c] US. (entropy = 0, samples = 87, value = [ 0  0 87])precision    recall  f1-score   supportEurope.       0.57      0.40      0.47        10Japan.       0.65      0.72      0.68        18US.       0.92      0.94      0.93        36accuracy                           0.80        64macro avg       0.71      0.69      0.70        64
weighted avg       0.79      0.80      0.79        64

Iris 数据集上的测试

默认属性三分类,划分训练集和测试集,限制决策树生长最大深度为 3,代码如下:

# iris dataset
from sklearn.datasets import load_iris
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 20231218)
dt41 = ID3(3)
dt41.fit(X_train, y_train)
dt41.visualize(iris['feature_names'])
print()
y_pred = dt41.predict(X_test)
print(classification_report(y_test, y_pred))

结果如下:

petal length (cm) (entropy = 1.5807197138422102, samples = 112, value = [34 41 37])
├── [< 2.45] 0 (entropy = 0, samples = 34, value = [34  0  0])
└── [>= 2.45] petal width (cm) (entropy = 0.9981021327390103, samples = 78, value = [ 0 41 37])├── [< 1.75] petal length (cm) (entropy = 0.4394969869215134, samples = 44, value = [ 0 40  4])│   ├── [< 4.95] 1 (entropy = 0.17203694935311378, samples = 39, value = [ 0 38  1])│   └── [>= 4.95] 2 (entropy = 0.9709505944546686, samples = 5, value = [0 2 3])└── [>= 1.75] petal length (cm) (entropy = 0.19143325481419343, samples = 34, value = [ 0  1 33])├── [< 4.85] 2 (entropy = 0.9182958340544896, samples = 3, value = [0 1 2])└── [>= 4.85] 2 (entropy = 0, samples = 31, value = [ 0  0 31])precision    recall  f1-score   support0       1.00      1.00      1.00        161       1.00      1.00      1.00         92       1.00      1.00      1.00        13accuracy                           1.00        38macro avg       1.00      1.00      1.00        38
weighted avg       1.00      1.00      1.00        38

这篇关于ID3算法 决策树学习 Python实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/660550

相关文章

C++使用栈实现括号匹配的代码详解

《C++使用栈实现括号匹配的代码详解》在编程中,括号匹配是一个常见问题,尤其是在处理数学表达式、编译器解析等任务时,栈是一种非常适合处理此类问题的数据结构,能够精确地管理括号的匹配问题,本文将通过C+... 目录引言问题描述代码讲解代码解析栈的状态表示测试总结引言在编程中,括号匹配是一个常见问题,尤其是在

Python调用Orator ORM进行数据库操作

《Python调用OratorORM进行数据库操作》OratorORM是一个功能丰富且灵活的PythonORM库,旨在简化数据库操作,它支持多种数据库并提供了简洁且直观的API,下面我们就... 目录Orator ORM 主要特点安装使用示例总结Orator ORM 是一个功能丰富且灵活的 python O

Java实现检查多个时间段是否有重合

《Java实现检查多个时间段是否有重合》这篇文章主要为大家详细介绍了如何使用Java实现检查多个时间段是否有重合,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录流程概述步骤详解China编程步骤1:定义时间段类步骤2:添加时间段步骤3:检查时间段是否有重合步骤4:输出结果示例代码结语作

Python使用国内镜像加速pip安装的方法讲解

《Python使用国内镜像加速pip安装的方法讲解》在Python开发中,pip是一个非常重要的工具,用于安装和管理Python的第三方库,然而,在国内使用pip安装依赖时,往往会因为网络问题而导致速... 目录一、pip 工具简介1. 什么是 pip?2. 什么是 -i 参数?二、国内镜像源的选择三、如何

使用C++实现链表元素的反转

《使用C++实现链表元素的反转》反转链表是链表操作中一个经典的问题,也是面试中常见的考题,本文将从思路到实现一步步地讲解如何实现链表的反转,帮助初学者理解这一操作,我们将使用C++代码演示具体实现,同... 目录问题定义思路分析代码实现带头节点的链表代码讲解其他实现方式时间和空间复杂度分析总结问题定义给定

Java覆盖第三方jar包中的某一个类的实现方法

《Java覆盖第三方jar包中的某一个类的实现方法》在我们日常的开发中,经常需要使用第三方的jar包,有时候我们会发现第三方的jar包中的某一个类有问题,或者我们需要定制化修改其中的逻辑,那么应该如何... 目录一、需求描述二、示例描述三、操作步骤四、验证结果五、实现原理一、需求描述需求描述如下:需要在

如何使用Java实现请求deepseek

《如何使用Java实现请求deepseek》这篇文章主要为大家详细介绍了如何使用Java实现请求deepseek功能,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录1.deepseek的api创建2.Java实现请求deepseek2.1 pom文件2.2 json转化文件2.2

python使用fastapi实现多语言国际化的操作指南

《python使用fastapi实现多语言国际化的操作指南》本文介绍了使用Python和FastAPI实现多语言国际化的操作指南,包括多语言架构技术栈、翻译管理、前端本地化、语言切换机制以及常见陷阱和... 目录多语言国际化实现指南项目多语言架构技术栈目录结构翻译工作流1. 翻译数据存储2. 翻译生成脚本

如何通过Python实现一个消息队列

《如何通过Python实现一个消息队列》这篇文章主要为大家详细介绍了如何通过Python实现一个简单的消息队列,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录如何通过 python 实现消息队列如何把 http 请求放在队列中执行1. 使用 queue.Queue 和 reque

Python如何实现PDF隐私信息检测

《Python如何实现PDF隐私信息检测》随着越来越多的个人信息以电子形式存储和传输,确保这些信息的安全至关重要,本文将介绍如何使用Python检测PDF文件中的隐私信息,需要的可以参考下... 目录项目背景技术栈代码解析功能说明运行结php果在当今,数据隐私保护变得尤为重要。随着越来越多的个人信息以电子形