掰开揉碎机器学习系列-决策树(1)-ID3决策树

2024-03-16 12:30

本文主要是介绍掰开揉碎机器学习系列-决策树(1)-ID3决策树,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、决策树的理论依据:

1、熵的概念:

熵代表了数据分布的"稳定程度"(书上写的所谓纯度),或者说是"分布的离散程度"。用掰开揉碎的方式解释如下:
如以下数据:
技术能力 积极度 年龄 前途

6 8 old normal

8 9 old yes

3 3 old no

7 5 old normal

7 7 young normal

7 6 old normal

8 5 old normal

2 2 old no

7 5 old normal

6 6 young normal

7 4 old normal

8 4 old normal

4 3 old no

5 4 old no

5 4 old no

6 4 old no

6 3 old normal

7 8 young yes

6 8 young yes

6 5 old no

上面是包括我在内的20个同部门员工的一个真实训练样本,分为3个维度考量,分别是技术能力(0-10整型)、工作积极度(0-10整型)、年龄(bool型),样本的结果是前途(分为好中坏)。首先看看如何计算熵:

熵的计算公式是sum(0 - p(i) * log(p(i) * 2)),这个公式来源于香农,具体为什么我暂时无法解释(待续)。计算过程即: 0 - (p(yes)* log(p(yes), 2) - (p(no) * log(p(no), 2) - (p(normal), log(p(normal) * 2)= 0 -

(0.15 * log(0.15, 2)) - (0.35 * log(0.35, 2)) - (0.5 *log(0.5, 2)) = 1.44064544962

15%的人有前途,35%的人一般,50%的人没有前途,前景无望的人居多数,但依然不乏有一点前途及一小撮比较有前途的人。熵为1.44064544962。

这里要明确的发现,熵和特征分布无关,只和结果取值分布有关

现在,部门裁员了,如果样本简化为:

6 8 old normal

8 9 old yes

3 3 old no

再次计算熵,0- (1/3 * log(1/3, 2)) - (1/3 * log(1/3, 2)) - (1/3 * log(1/3, 2)) = 0- log(1/3, 2) = 1.58496250072。裁员后,分布变的更加复杂了,前景好坏的人三分天下。熵变大了。

后来,有点前景的受不了都走了,部门换来两个毫无前途的庸人,样本变为了:

3 3 old no

5 4 old no

5 4 old no

再次计算熵,0- (1 * log(1, 2)) - (1 * log(1, 2)) - (1 * log(1, 2)) = 0 - log(1,2) = 0

可见,部门现在情况很稳定,完全都是前景无望的庸人了。熵降到冰点0了。

 

现在可以总结:

熵是什么?熵反应了当前样本的概率分布的稳定性,如果概率分布非常"分散""平均",什么样的情况都有而且分布平均,那么熵会比较大,相反会更小

2、信息增益

         谈到信息增益,必须首先看2.1。

         2.1、决策树大概是什么样子的:

                  

                  蓝色代表了特征,红色代表了结果。

那么,很可能下面这样的一个样本,会训练出上面这样的决策树:

                           天气 老婆是否在家 老婆是否例假 采取的行动(结果)

                           好 不在 有 跟小三出去

好 不在 没有 跟小三出去

好 在 有 跟老婆出去

好 在 没有 跟老婆出去

不好 不在 有 玩游戏

不好 不在 没有 玩游戏

不好 在 有 玩游戏

不好 在 没有 啪啪啪

         2.2、决策树希望是什么样子的:

                   决策树,作为由训练样本生成的模型,要尽力体现共性,避免过拟合。对于决策树的树形结构来说就要避免过多的分支,即避免过拟合。关于过拟合,在接下来的回归算法文章中还会不停的强调。

                   决策树如何避免过拟合?

1、 从决策树的正常创建过程来说:

树的每一层的根节点的特征不是随便取的,要根据当前这一层,样本数据以哪个特征作为这一层的根节点,样本数据的概率分布更难体现共性,即概率分布更为平稳,来决定由哪个特征作为该层的根节点。

2、 从剪枝的角度来说(后面CART/C4.5具体描述):

前剪枝:在创建时就设置以某些条件来避免过拟合的生长

后剪枝:在决策树生成后修剪

关于2后面的决策树的改进版C4.5、CART讨论。

关于1,是ID3决策树的创建原则,这就要引入”信息增益”的概念。

2.3、信息增益

         信息增益的定义:一个特征能够为分类系统带来多少信息,带来的信息越多,该特征越重要。它的计算方式是通过熵。

         进一步就是样本数据中的特征A,特征A的信息增益 = 样本数据的熵 - 它的各个取值里的条件熵之和,它的各个取值里的取值概率 *条件熵之和越小,则信息增益越大,则特征A越应该成为当前样本的决策树的根节点或者说,作为根特征的特征A,其各个取值必须和各自的结果,有更强的相关性。

         条件熵:在特征A的某个取值不变时,得到的子样本数据的熵。

即如何确定决策树的各层特征。举例样本数据如下:

老婆是否例假 天气如何 决定

是       好     玩游戏

是       不好 玩游戏

不是   好     啪啪啪

不是   不好 啪啪啪

样本熵 = 0 – 1/2 * log(1/2, 2) – 1/2 * log(1/2, 2) = 1

1、如果以”老婆是否例假”作为决策树的根特征:

                                     取值”是”: 取值概率为1/2,子样本是:

                                                        好     玩游戏

                                                        不好 玩游戏

                                     子样本的熵:0 – 1 *log(1, 2) – 1 * log(1, 2) = 0

                                     取值”不是”: 取值概率为1/2,子样本是:

                                                        好     啪啪啪

                                                        不好 啪啪啪

                                     子样本的熵:0 – 1 *log(1, 2) - 1 * log(1, 2) = 0

                                     信息增益 = 1(样本熵) – 1/2(取值”是”概率) * 0(取值”是”的子样本熵) – 1/2(取值”不是”概率) * 0(取值”不是”的子样本熵) = 1

                            2、如果以”天气如何”作为决策树的根特征:

                                     取值”好”: 取值概率为1/2,子样本是:

                                               是 玩游戏

                                               不是 啪啪啪

                                     子样本的熵:0 - 1 /2 *log(1/2, 2) - 1 /2 * log(1/2, 2) = 1,取值概率为1/2

                                     取值”不好”: 取值概率为1/2,子样本是:

                                               是 玩游戏

                                               不是 啪啪啪

                                     子样本的熵:0 - 1 /2 *log(1/2, 2) - 1 /2 * log(1/2, 2) = 1,取值概率为1/2

                                     信息增益 = 1(样本熵) – 1/2(取值”好”概率) * 1(取值”好”的子样本熵) – 1/2(取值”好”概率) * 1(取值”不好”的子样本熵) = 0

                            所以,以老婆是否例假作为决策树的根特征,比天气如何作为决策树的根特征,信息增益更大,应该以老婆是否例假作为根特征,用matplotlib画图如下:

                           

                            该图的含义是:只要老婆没有例假就啪啪啪,否则就玩游戏。反之,如果以”天气如何”作为根特征,不论天气是”好”还是”不好”,都要再根据”老婆是否在家”的情况,做出不同的决定。

                   2.4、总结

                            当样本SN个特征(N > 1),作为根特征的特征A,必须符合特征A的各个取值,满足公式:min(sum(p(i)* Ent(S|A = Ai))),含义是:作为根特征的特征A,它的每个取值,都要尽可能分散度更小(熵更小)的结果

3、递归决策树

                   回到最开始的样本,这不是自黑,是一个真实的样本:

                            技术能力 积极度 年龄 前途

                   6 8 old normal

8 9 old yes

3 3 old no

7 5 old normal

7 7 young normal

7 6 old normal

8 5 old normal

2 2 old no

7 5 old normal

6 6 young normal

7 4 old normal

8 4 old normal

4 3 old no

5 4 old no

5 4 old no

6 4 old no

6 3 old normal

7 8 young yes

6 8 young yes

6 5 old no

共有3个特征,技术能力'tech', 积极程度'ispositive', 年龄'age',作为决策树,按上面描述的方法,可以计算出根特征。根特征是”ispositive”。然后就需要计算在根特征是”ispositive”的各种取值下,哪个特征作为接下来的根特征。举例如下:比如说,决定职业球员能否取得成功,根特征是”身体素质”,那么在身体素质打9分的情况下,还有其他的特征进一步决定能否取得多大的成功,比如”职业态度”,在身体素质9分职业态度9分的情况下,还会有很多因素进一步影响能取得多大成功,事实上现实生活中,每一个结果也确实都是由多种多样的因素最终决定的。

ID3递归决策树,就是通过概率和熵,由min(sum(p(i)* Ent(S|A = Ai)))这个结论,一层一层的计算出当前样本中最具广泛意义的特征,根据其不同的取值,进一步收缩样本,再计算收缩后样本的最具广泛意义的特征,直到找到最终结果。

下面直接给出程序:

训练数据就是上面的数据,制表符分隔。

#coding:utf8
fromnumpy import *
frommath import log
importsys
importoperator
fromtreeplot import *
#status.txt就是上面的训练数据
def createdataset ():#dataset = [[1, 1, 'yes'], [1, 1,'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]#labels = ['no surfacing', 'flippers']f = open('status.txt')items = {}dataset = []while 1:l = f.readline().strip('\n')if l == "":breakary = l.split('\t')name = ary[0]items[name] = ary[1:]dataset.append(ary[1:])f.close()labels = ['tech', 'is positive', 'age']#labels = ['no problem', 'weather']return dataset, labels
	
 
#熵越大,说明情况越复杂,反之,什么情况越清晰
def calcEnt (dataset):map = {}for data in dataset:label = data[-1]if map.has_key(label):map[label] += 1else:map[label] = 1ent = 0.0for label in map:p = float(map[label])/float(len(dataset))print label, p, p * log(p, 2)ent -= p * log(p, 2)return ent
def split_by_feature (dataset, idx, val):res = []for data in dataset:if data[idx] == val:vec = data[:idx]vec.extend(data[idx + 1:])res.append(vec)return res
#对每个特征进行分析,计算每个特征的信息增益,每个取值的"概率 * 该特征值下子数据的熵"的和,找出变化最小即最稳定的是哪个特征
def find_bestfeature_tosplit_dataset (dataset, labels):feature_num = len(dataset[0]) - 1bestentgain = 0.0bestfeature = -1ent = calcEnt(dataset)#对每个特征进行分析#print "ent: %f" % entfor i in range(feature_num):values = set([data[i] for data in dataset])cur_ent = 0.0#print "\nfeature %s" % labels[i]#计算每个特征的信息增益,每个取值的"概率 * 该特征值下子数据的熵"的和,找出变化最小即最稳定的是哪个特征for value in values:res = split_by_feature(dataset, i, value)p = float(len(res))/float(len(dataset))cur_ent += p * calcEnt(res)#print "value %s, p(%f), ent(%f)" % (value, p, cur_ent)#print resentgain = ent - cur_ent#print "%s,  entgain(%f)" % (labels[i], entgain)#entgain越大,即cur_ent越小,即(熵*概率)越小,即该情况越清晰if entgain > bestentgain:bestentgain = entgainbestfeature = ireturn bestfeature
def vote(classes):classcount = {}for vote in classes:if classcount.has_key(vote):classcount[vote] += 1else:classcount[vote] = 1sortedclasscount  = sorted(classcount.iteritems(), key = operator.itemgetter(1), reverse = True)return sortedclasscount[0][0]
#递归决策树,依次找每个变化最稳定的特征,构成决策树
def createdtree (dataset, labels):#classes是当前所有的结果classes = [data[-1] for data in dataset]#如果当前都没有特征了,只剩下结果了,那就简单的看下哪个结果多就算是哪个if len(dataset[0]) == 1:#print "no feature"return vote(classes)#就一种结果了,不用计算什么根特征了,肯定就这个结果if len(classes) == classes.count(classes[0]):#print "direct result %s" % (classes[0]), datasetreturn classes[0]#计算根特征,进而构建当前的决策树bestfeatureidx = find_bestfeature_tosplit_dataset(dataset, labels)bestfeature = labels[bestfeatureidx]tree = {labels[bestfeatureidx]:{}}#print "best feature: %s" % bestfeature, datasetvalues = set(data[bestfeatureidx] for data in dataset)#当前特征已为决策树的根特征,干掉del(labels[bestfeatureidx])#当前根特征下,各个特征值的子数据的再决策for value in values:#这里千万不可以newlabels = labels,这样是引用,会破坏递归前labels。要newlabels = labels[:],这样是拷贝newlabels = labels[:]print value, newlabels#按根特征的当前的取值,获取收缩后的样本newdataset = split_by_feature(dataset, bestfeatureidx, value)tree[bestfeature][value] = createdtree(newdataset, newlabels)return tree
if __name__ == "__main__":#加载训练样本数据dataset, labels = createdataset()#构建递归决策树tree = createdtree(dataset, labels)#画图createPlot(tree)

关于matplotlib画决策层的图,直接贴出程序,暂先不讨论细节,matplotlib可能需要作为一个大专题来讨论。这是一个比较通用的程序,接收决策树参数即可使用。

#coding: utf8
import matplotlib.pyplot as plt#定义文本框和箭头格式  
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #定义判断节点形态  
leafNode = dict(boxstyle="round4", fc="0.8") #定义叶节点形态  
arrow_args = dict(arrowstyle="<-") #定义箭头  #绘制带箭头的注解  
#nodeTxt:节点的文字标注, centerPt:节点中心位置,  
#parentPt:箭头起点位置(上一节点位置), nodeType:节点属性  
def plotNode(nodeTxt, centerPt, parentPt, nodeType):  createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',  xytext=centerPt, textcoords='axes fraction',  va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )#计算叶节点数  
def getNumLeafs(myTree):  numLeafs = 0  firstStr = myTree.keys()[0]   secondDict = myTree[firstStr]   for key in secondDict.keys():  if type(secondDict[key]).__name__=='dict':#是否是字典  numLeafs += getNumLeafs(secondDict[key]) #递归调用getNumLeafs  else:   numLeafs +=1 #如果是叶节点,则叶节点+1  return numLeafs  #计算数的层数  
def getTreeDepth(myTree):  maxDepth = 0  firstStr = myTree.keys()[0]  secondDict = myTree[firstStr]  for key in secondDict.keys():  if type(secondDict[key]).__name__=='dict':#是否是字典  thisDepth = 1 + getTreeDepth(secondDict[key]) #如果是字典,则层数加1,再递归调用getTreeDepth  else:   thisDepth = 1  #得到最大层数  if thisDepth > maxDepth:  maxDepth = thisDepth  return maxDepth#在父子节点间填充文本信息  
#cntrPt:子节点位置, parentPt:父节点位置, txtString:标注内容  
def plotMidText(cntrPt, parentPt, txtString):  xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)#绘制树形图  
#myTree:树的字典, parentPt:父节点, nodeTxt:节点的文字标注  
def plotTree(myTree, parentPt, nodeTxt):  numLeafs = getNumLeafs(myTree)  #树叶节点数  depth = getTreeDepth(myTree)    #树的层数  firstStr = myTree.keys()[0]     #节点标签  #计算当前节点的位置  cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  plotMidText(cntrPt, parentPt, nodeTxt) #在父子节点间填充文本信息  plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制带箭头的注解  secondDict = myTree[firstStr]  plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  for key in secondDict.keys():  if type(secondDict[key]).__name__=='dict':#判断是不是字典,  plotTree(secondDict[key],cntrPt,str(key))        #递归绘制树形图  else:   #如果是叶节点  plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  def createPlot(inTree):  fig = plt.figure(1, facecolor='white')  fig.clf()  axprops = dict(xticks=[], yticks=[])  createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)      plotTree.totalW = float(getNumLeafs(inTree)) #树的宽度  plotTree.totalD = float(getTreeDepth(inTree)) #树的深度  plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;  plotTree(inTree, (0.5,1.0), '')  plt.show()  

结果如下图:



但可以感觉到,分支很多,有过拟合的感觉。针对这个样本,主要是训练样本数据的原因,但也引出了决策树如何避免过拟合的问题,接下来就需要学习更广泛的CART/C4.5决策树改进算法, 以及决策树的剪枝。

这篇关于掰开揉碎机器学习系列-决策树(1)-ID3决策树的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/815499

相关文章

HarmonyOS学习(七)——UI(五)常用布局总结

自适应布局 1.1、线性布局(LinearLayout) 通过线性容器Row和Column实现线性布局。Column容器内的子组件按照垂直方向排列,Row组件中的子组件按照水平方向排列。 属性说明space通过space参数设置主轴上子组件的间距,达到各子组件在排列上的等间距效果alignItems设置子组件在交叉轴上的对齐方式,且在各类尺寸屏幕上表现一致,其中交叉轴为垂直时,取值为Vert

Ilya-AI分享的他在OpenAI学习到的15个提示工程技巧

Ilya(不是本人,claude AI)在社交媒体上分享了他在OpenAI学习到的15个Prompt撰写技巧。 以下是详细的内容: 提示精确化:在编写提示时,力求表达清晰准确。清楚地阐述任务需求和概念定义至关重要。例:不用"分析文本",而用"判断这段话的情感倾向:积极、消极还是中性"。 快速迭代:善于快速连续调整提示。熟练的提示工程师能够灵活地进行多轮优化。例:从"总结文章"到"用

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

【前端学习】AntV G6-08 深入图形与图形分组、自定义节点、节点动画(下)

【课程链接】 AntV G6:深入图形与图形分组、自定义节点、节点动画(下)_哔哩哔哩_bilibili 本章十吾老师讲解了一个复杂的自定义节点中,应该怎样去计算和绘制图形,如何给一个图形制作不间断的动画,以及在鼠标事件之后产生动画。(有点难,需要好好理解) <!DOCTYPE html><html><head><meta charset="UTF-8"><title>06

学习hash总结

2014/1/29/   最近刚开始学hash,名字很陌生,但是hash的思想却很熟悉,以前早就做过此类的题,但是不知道这就是hash思想而已,说白了hash就是一个映射,往往灵活利用数组的下标来实现算法,hash的作用:1、判重;2、统计次数;

零基础学习Redis(10) -- zset类型命令使用

zset是有序集合,内部除了存储元素外,还会存储一个score,存储在zset中的元素会按照score的大小升序排列,不同元素的score可以重复,score相同的元素会按照元素的字典序排列。 1. zset常用命令 1.1 zadd  zadd key [NX | XX] [GT | LT]   [CH] [INCR] score member [score member ...]

科研绘图系列:R语言扩展物种堆积图(Extended Stacked Barplot)

介绍 R语言的扩展物种堆积图是一种数据可视化工具,它不仅展示了物种的堆积结果,还整合了不同样本分组之间的差异性分析结果。这种图形表示方法能够直观地比较不同物种在各个分组中的显著性差异,为研究者提供了一种有效的数据解读方式。 加载R包 knitr::opts_chunk$set(warning = F, message = F)library(tidyverse)library(phyl

【机器学习】高斯过程的基本概念和应用领域以及在python中的实例

引言 高斯过程(Gaussian Process,简称GP)是一种概率模型,用于描述一组随机变量的联合概率分布,其中任何一个有限维度的子集都具有高斯分布 文章目录 引言一、高斯过程1.1 基本定义1.1.1 随机过程1.1.2 高斯分布 1.2 高斯过程的特性1.2.1 联合高斯性1.2.2 均值函数1.2.3 协方差函数(或核函数) 1.3 核函数1.4 高斯过程回归(Gauss

【生成模型系列(初级)】嵌入(Embedding)方程——自然语言处理的数学灵魂【通俗理解】

【通俗理解】嵌入(Embedding)方程——自然语言处理的数学灵魂 关键词提炼 #嵌入方程 #自然语言处理 #词向量 #机器学习 #神经网络 #向量空间模型 #Siri #Google翻译 #AlexNet 第一节:嵌入方程的类比与核心概念【尽可能通俗】 嵌入方程可以被看作是自然语言处理中的“翻译机”,它将文本中的单词或短语转换成计算机能够理解的数学形式,即向量。 正如翻译机将一种语言

【学习笔记】 陈强-机器学习-Python-Ch15 人工神经网络(1)sklearn

系列文章目录 监督学习:参数方法 【学习笔记】 陈强-机器学习-Python-Ch4 线性回归 【学习笔记】 陈强-机器学习-Python-Ch5 逻辑回归 【课后题练习】 陈强-机器学习-Python-Ch5 逻辑回归(SAheart.csv) 【学习笔记】 陈强-机器学习-Python-Ch6 多项逻辑回归 【学习笔记 及 课后题练习】 陈强-机器学习-Python-Ch7 判别分析 【学