手写决策树ID3算法(python)

2024-08-31 03:32

本文主要是介绍手写决策树ID3算法(python),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

决策数(Decision Tree)在机器学习中也是比较常见的一种算法,属于监督学习中的一种。看字面意思应该也比较容易理解,相比其他算法比如支持向量机(SVM)或神经网络,似乎决策树感觉“亲切”许多。

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失值不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配的问题。
使用数据类型:数值型和标称型。
简单介绍完毕,让我们来通过一个例子让决策树“原形毕露”。

一天,老师问了个问题,只根据头发和声音怎么判断一位同学的性别。
为了解决这个问题,同学们马上简单的统计了7位同学的相关特征,数据如下:

机智的同学A想了想,先根据头发判断,若判断不出,再根据声音判断,于是画了一幅图,如下:

于是,一个简单、直观的决策树就这么出来了。头发长、声音粗就是男生;头发长、声音细就是女生;头发短、声音粗是男生;头发短、声音细是女生。
原来机器学习中决策树就这玩意,这也太简单了吧。。。
这时又蹦出个同学B,想先根据声音判断,然后再根据头发来判断,如是大手一挥也画了个决策树:

同学B的决策树:首先判断声音,声音细,就是女生;声音粗、头发长是男生;声音粗、头发长是女生。

那么问题来了:同学A和同学B谁的决策树好些?计算机做决策树的时候,面对多个特征,该如何选哪个特征为最佳的划分特征?

划分数据集的大原则是:将无序的数据变得更加有序。
我们可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。于是我们这么想,如果我们能测量数据的复杂度,对比按不同特征分类后的数据复杂度,若按某一特征分类后复杂度减少的更多,那么这个特征即为最佳分类特征。
Claude Shannon 定义了熵(entropy)和信息增益(information gain)。
用熵来表示信息的复杂度,熵越大,则信息越复杂。公式如下:

信息增益(information gain),表示两个信息熵的差值。
首先计算未分类前的熵,总共有8位同学,男生3位,女生5位。
熵(总)=-3/8log2(3/8)-5/8log2(5/8)=0.9544
接着分别计算同学A和同学B分类后信息熵。
同学A首先按头发分类,分类后的结果为:长头发中有1男3女。短头发中有2男2女。
熵(同学A长发)=-1/4log2(1/4)-3/4log2(3/4)=0.8113
熵(同学A短发)=-2/4log2(2/4)-2/4log2(2/4)=1
熵(同学A)=4/80.8113+4/81=0.9057
信息增益(同学A)=熵(总)-熵(同学A)=0.9544-0.9057=0.0487
同理,按同学B的方法,首先按声音特征来分,分类后的结果为:声音粗中有3男3女。声音细中有0男2女。
熵(同学B声音粗)=-3/6log2(3/6)-3/6log2(3/6)=1
熵(同学B声音粗)=-2/2log2(2/2)=0
熵(同学B)=6/81+2/8*0=0.75
信息增益(同学B)=熵(总)-熵(同学B)=0.9544-0.75=0.2087

按同学B的方法,先按声音特征分类,信息增益更大,区分样本的能力更强,更具有代表性。
以上就是决策树ID3算法的核心思想。
接下来用python代码来实现ID3算法:
 

#决策树ID3算法
from math import log
import operatordef calcShannonEnt(dataSet):  # 计算数据的熵(entropy)numEntries=len(dataSet)  # 数据条数labelCounts={}for featVec in dataSet:currentLabel=featVec[-1] # 每行数据的最后一个字(类别)if currentLabel not in labelCounts.keys():labelCounts[currentLabel]=0labelCounts[currentLabel]+=1  # 统计有多少个类以及每个类的数量shannonEnt=0for key in labelCounts:prob=float(labelCounts[key])/numEntries # 计算单个类的熵值shannonEnt-=prob*log(prob,2) # 累加每个类的熵值return shannonEntdef createDataSet1():    # 创造示例数据dataSet = [['长', '粗', '男'],['短', '粗', '男'],['短', '粗', '男'],['长', '细', '女'],['短', '细', '女'],['短', '粗', '女'],['长', '粗', '女'],['长', '粗', '女']]labels = ['头发','声音']  #两个特征return dataSet,labelsdef splitDataSet(dataSet,axis,value): # 按某个特征分类后的数据retDataSet=[]for featVec in dataSet:if featVec[axis]==value:reducedFeatVec =featVec[:axis]reducedFeatVec.extend(featVec[axis+1:])retDataSet.append(reducedFeatVec)return retDataSetdef chooseBestFeatureToSplit(dataSet):  # 选择最优的分类特征numFeatures = len(dataSet[0])-1baseEntropy = calcShannonEnt(dataSet)  # 原始的熵bestInfoGain = 0bestFeature = -1for i in range(numFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)newEntropy = 0for value in uniqueVals:subDataSet = splitDataSet(dataSet,i,value)prob =len(subDataSet)/float(len(dataSet))newEntropy +=prob*calcShannonEnt(subDataSet)  # 按特征分类后的熵infoGain = baseEntropy - newEntropy  # 原始熵与按特征分类后的熵的差值if (infoGain>bestInfoGain):   # 若按某特征划分后,熵值减少的最大,则次特征为最优分类特征bestInfoGain=infoGainbestFeature = ireturn bestFeaturedef majorityCnt(classList):    #按分类后类别数量排序,比如:最后分类为2男1女,则判定为男;classCount={}for vote in classList:if vote not in classCount.keys():classCount[vote]=0classCount[vote]+=1sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)return sortedClassCount[0][0]def createTree(dataSet,labels):classList=[example[-1] for example in dataSet]  # 类别:男或女if classList.count(classList[0])==len(classList):return classList[0]if len(dataSet[0])==1:return majorityCnt(classList)bestFeat=chooseBestFeatureToSplit(dataSet) #选择最优特征bestFeatLabel=labels[bestFeat]myTree={bestFeatLabel:{}} #分类结果以字典形式保存del(labels[bestFeat])featValues=[example[bestFeat] for example in dataSet]uniqueVals=set(featValues)for value in uniqueVals:subLabels=labels[:]myTree[bestFeatLabel][value]=createTree(splitDataSet\(dataSet,bestFeat,value),subLabels)return myTreedef predict(mytree, tips, list1):res = []for item in list1:tmp_tree = mytreeiter = tmp_tree.__iter__()     while 1:try:key = iter.__next__()if isinstance(key, str) and (key == "男" or key == "女"):res.append(key)breakv = tmp_tree[key]index = tips[key]item_res = item[index]tmp_tree = v[item_res]iter = tmp_tree.__iter__()except StopIteration:breakreturn resif __name__=='__main__':dataSet, labels=createDataSet1()  # 创造示列数据mytree = createTree(dataSet, labels)print(mytree)  # 输出决策树模型结果#预测tips = {"头发":0, "声音":1}res = predict(mytree, tips, [['长', '粗'], ['短', '粗']])print(res)

 

 

 

这篇关于手写决策树ID3算法(python)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1122762

相关文章

不懂推荐算法也能设计推荐系统

本文以商业化应用推荐为例,告诉我们不懂推荐算法的产品,也能从产品侧出发, 设计出一款不错的推荐系统。 相信很多新手产品,看到算法二字,多是懵圈的。 什么排序算法、最短路径等都是相对传统的算法(注:传统是指科班出身的产品都会接触过)。但对于推荐算法,多数产品对着网上搜到的资源,都会无从下手。特别当某些推荐算法 和 “AI”扯上关系后,更是加大了理解的难度。 但,不了解推荐算法,就无法做推荐系

python: 多模块(.py)中全局变量的导入

文章目录 global关键字可变类型和不可变类型数据的内存地址单模块(单个py文件)的全局变量示例总结 多模块(多个py文件)的全局变量from x import x导入全局变量示例 import x导入全局变量示例 总结 global关键字 global 的作用范围是模块(.py)级别: 当你在一个模块(文件)中使用 global 声明变量时,这个变量只在该模块的全局命名空

康拓展开(hash算法中会用到)

康拓展开是一个全排列到一个自然数的双射(也就是某个全排列与某个自然数一一对应) 公式: X=a[n]*(n-1)!+a[n-1]*(n-2)!+...+a[i]*(i-1)!+...+a[1]*0! 其中,a[i]为整数,并且0<=a[i]<i,1<=i<=n。(a[i]在不同应用中的含义不同); 典型应用: 计算当前排列在所有由小到大全排列中的顺序,也就是说求当前排列是第

csu 1446 Problem J Modified LCS (扩展欧几里得算法的简单应用)

这是一道扩展欧几里得算法的简单应用题,这题是在湖南多校训练赛中队友ac的一道题,在比赛之后请教了队友,然后自己把它a掉 这也是自己独自做扩展欧几里得算法的题目 题意:把题意转变下就变成了:求d1*x - d2*y = f2 - f1的解,很明显用exgcd来解 下面介绍一下exgcd的一些知识点:求ax + by = c的解 一、首先求ax + by = gcd(a,b)的解 这个

综合安防管理平台LntonAIServer视频监控汇聚抖动检测算法优势

LntonAIServer视频质量诊断功能中的抖动检测是一个专门针对视频稳定性进行分析的功能。抖动通常是指视频帧之间的不必要运动,这种运动可能是由于摄像机的移动、传输中的错误或编解码问题导致的。抖动检测对于确保视频内容的平滑性和观看体验至关重要。 优势 1. 提高图像质量 - 清晰度提升:减少抖动,提高图像的清晰度和细节表现力,使得监控画面更加真实可信。 - 细节增强:在低光条件下,抖

【数据结构】——原来排序算法搞懂这些就行,轻松拿捏

前言:快速排序的实现最重要的是找基准值,下面让我们来了解如何实现找基准值 基准值的注释:在快排的过程中,每一次我们要取一个元素作为枢纽值,以这个数字来将序列划分为两部分。 在此我们采用三数取中法,也就是取左端、中间、右端三个数,然后进行排序,将中间数作为枢纽值。 快速排序实现主框架: //快速排序 void QuickSort(int* arr, int left, int rig

【Python编程】Linux创建虚拟环境并配置与notebook相连接

1.创建 使用 venv 创建虚拟环境。例如,在当前目录下创建一个名为 myenv 的虚拟环境: python3 -m venv myenv 2.激活 激活虚拟环境使其成为当前终端会话的活动环境。运行: source myenv/bin/activate 3.与notebook连接 在虚拟环境中,使用 pip 安装 Jupyter 和 ipykernel: pip instal

poj 3974 and hdu 3068 最长回文串的O(n)解法(Manacher算法)

求一段字符串中的最长回文串。 因为数据量比较大,用原来的O(n^2)会爆。 小白上的O(n^2)解法代码:TLE啦~ #include<stdio.h>#include<string.h>const int Maxn = 1000000;char s[Maxn];int main(){char e[] = {"END"};while(scanf("%s", s) != EO

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

秋招最新大模型算法面试,熬夜都要肝完它

💥大家在面试大模型LLM这个板块的时候,不知道面试完会不会复盘、总结,做笔记的习惯,这份大模型算法岗面试八股笔记也帮助不少人拿到过offer ✨对于面试大模型算法工程师会有一定的帮助,都附有完整答案,熬夜也要看完,祝大家一臂之力 这份《大模型算法工程师面试题》已经上传CSDN,还有完整版的大模型 AI 学习资料,朋友们如果需要可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费