Using a Decision Tree to Predict Contact Lens Types
tree.py (code that builds the tree)
# -*- coding: UTF-8 -*-
from math import log
import operator


# Compute the Shannon entropy of the data set; the higher the entropy, the more mixed the data
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the class label is the last column
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


# Split the data set: keep the rows whose feature `axis` equals `value`, with that column removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# Choose the best feature to split on (the one with the largest information gain)
def chooseBestFeatureTosplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# Return the most common class label (majority vote)
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # items() works on both Python 2 and Python 3 (iteritems() is Python 2 only)
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# Build the tree
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureTosplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])  # note: this modifies the caller's labels list
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
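To make the entropy and information-gain arithmetic behind chooseBestFeatureTosplit() easier to check by hand, here is a minimal, self-contained sketch. The toy data set and the helper names `entropy` and `info_gain` are assumptions made for illustration; they are not part of tree.py.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Shannon entropy (in bits) of the class labels in the last column."""
    counts = Counter(row[-1] for row in rows)
    total = len(rows)
    return -sum((n / total) * log2(n / total) for n in counts.values())


def info_gain(rows, axis):
    """Entropy reduction obtained by splitting `rows` on column `axis`."""
    remainder = 0.0
    for value in set(row[axis] for row in rows):
        subset = [row for row in rows if row[axis] == value]
        remainder += len(subset) / len(rows) * entropy(subset)
    return entropy(rows) - remainder


# Hypothetical toy data: two binary features followed by a class label.
toy = [
    [1, 1, 'yes'],
    [1, 1, 'yes'],
    [1, 0, 'no'],
    [0, 1, 'no'],
    [0, 1, 'no'],
]

gains = [info_gain(toy, axis) for axis in range(2)]
best_axis = gains.index(max(gains))
print(entropy(toy), gains, best_axis)
```

On this toy data the first feature yields the larger gain, so a tree built the same way would split on it first.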
treePlotter.py (code that plots the tree)
# -*- coding: UTF-8 -*-
import sys

# Python 2 only: switch the default string encoding to GBK
if sys.version_info[0] == 2:
    reload(sys)
    sys.setdefaultencoding('gbk')

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


# Get the number of leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # list() keeps this working on Python 3
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # test to see if the nodes are dictionaries; if not, they are leaf nodes
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# Get the depth (number of levels) of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # test to see if the nodes are dictionaries; if not, they are leaf nodes
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# Draw one node with an arrow pointing to it from its parent
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# Write the branch value halfway between a child node and its parent
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        # test to see if the nodes are dictionaries; if not, they are leaf nodes
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:  # it's a leaf node, so draw the leaf
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# Main function: set up the figure and draw the whole tree
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
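The plotting layout depends on two numbers, the leaf count (width) and the tree depth (height). The following sketch shows how both can be computed on the nested-dict tree format without matplotlib. The names `count_leaves` and `tree_depth` and the sample tree are assumptions for illustration, written as a compact alternative to getNumLeafs() and getTreeDepth() rather than a replacement for them.

```python
def count_leaves(tree):
    """Count the leaf nodes of a nested-dict decision tree."""
    feature = next(iter(tree))  # the single key at this level
    total = 0
    for child in tree[feature].values():
        total += count_leaves(child) if isinstance(child, dict) else 1
    return total


def tree_depth(tree):
    """Number of decision levels on the longest root-to-leaf path."""
    feature = next(iter(tree))
    depths = [
        (1 + tree_depth(child)) if isinstance(child, dict) else 1
        for child in tree[feature].values()
    ]
    return max(depths)


# Hypothetical hand-written tree in the {feature: {value: subtree_or_label}} format.
sample_tree = {'astigmatic': {'no': 'soft',
                              'yes': {'prescript': {'myope': 'hard',
                                                    'hyper': 'no lenses'}}}}

print(count_leaves(sample_tree), tree_depth(sample_tree))
```

Checking these two values on a small tree by hand is a cheap way to confirm the tree structure before plotting it.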
- Collect the data: lenses.txt
    young	myope	no	reduced	no lenses
    young	myope	no	normal	soft
    young	myope	yes	reduced	no lenses
    young	myope	yes	normal	hard
    young	hyper	no	reduced	no lenses
    young	hyper	no	normal	soft
    young	hyper	yes	reduced	no lenses
    young	hyper	yes	normal	hard
    pre	myope	no	reduced	no lenses
    pre	myope	no	normal	soft
    pre	myope	yes	reduced	no lenses
    pre	myope	yes	normal	hard
    pre	hyper	no	reduced	no lenses
    pre	hyper	no	normal	soft
    pre	hyper	yes	reduced	no lenses
    pre	hyper	yes	normal	no lenses
    presbyopic	myope	no	reduced	no lenses
    presbyopic	myope	no	normal	no lenses
    presbyopic	myope	yes	reduced	no lenses
    presbyopic	myope	yes	normal	hard
    presbyopic	hyper	no	reduced	no lenses
    presbyopic	hyper	no	normal	soft
    presbyopic	hyper	yes	reduced	no lenses
    presbyopic	hyper	yes	normal	no lenses
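- Prepare the data: parse the tab-delimited lines.
- Analyze the data: quickly inspect the data to make sure it was parsed correctly, then use the createPlot() function to draw the final tree.
- Train the algorithm: use the createTree() function.
- Test the algorithm: write a test function to verify that the decision tree correctly classifies a given data instance.
- Use the algorithm: store the tree data structure so that the tree does not have to be rebuilt the next time it is needed.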
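The "prepare" and "analyze" steps can be sketched as follows. This is a minimal example that assumes a three-row excerpt held in memory instead of the real lenses.txt file. Splitting on the tab character rather than on any whitespace matters here, because the class label "no lenses" contains a space.

```python
import io

# Hypothetical three-row excerpt standing in for lenses.txt (fields are tab-separated).
sample = io.StringIO(
    "young\tmyope\tno\treduced\tno lenses\n"
    "young\tmyope\tno\tnormal\tsoft\n"
    "pre\thyper\tyes\tnormal\tno lenses\n"
)

# Strip the trailing newline, then split each non-empty line on tabs.
rows = [line.strip().split("\t") for line in sample if line.strip()]

# Quick checks: every row should hold four feature values plus one class label.
assert all(len(row) == 5 for row in rows), "unexpected column count"
classes = sorted({row[-1] for row in rows})
print(len(rows), classes)
```

If the column-count check fails on the real file, the usual cause is that tabs were converted to spaces when the data was copied.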
>>> import tree
>>> import treePlotter
>>> fr = open("lenses.txt")
>>> lenses = [inst.strip().split("\t") for inst in fr.readlines()]
>>> lensesLabels = ["age", "prescript", "astigmatic", "tearRate"]
>>> lensesTree = tree.createTree(lenses, lensesLabels)
>>> treePlotter.createPlot(lensesTree)
[Figure: the plotted decision tree]
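The last two steps, testing and storing the tree, are not shown in the code above. One possible sketch follows, assuming the nested-dict tree format produced by createTree(). The function names `classify`, `store_tree`, and `load_tree`, the file name, and the small sample tree are all hypothetical and chosen only for illustration.

```python
import os
import pickle
import tempfile


def classify(tree, feat_labels, test_vec):
    """Walk a nested-dict tree and return the class label for test_vec."""
    feature = next(iter(tree))
    value = test_vec[feat_labels.index(feature)]
    branch = tree[feature][value]  # raises KeyError for a value never seen in training
    if isinstance(branch, dict):
        return classify(branch, feat_labels, test_vec)
    return branch


def store_tree(tree, filename):
    """Serialize the tree to disk with pickle."""
    with open(filename, "wb") as fw:
        pickle.dump(tree, fw)


def load_tree(filename):
    """Load a previously stored tree."""
    with open(filename, "rb") as fr:
        return pickle.load(fr)


# Hypothetical small tree; a real one would come from tree.createTree().
labels = ["age", "prescript", "astigmatic", "tearRate"]
sample_tree = {"tearRate": {"reduced": "no lenses",
                            "normal": {"astigmatic": {"no": "soft",
                                                      "yes": "hard"}}}}

# Round-trip through a temporary file, then classify one instance.
path = os.path.join(tempfile.mkdtemp(), "lensesTree.pkl")
store_tree(sample_tree, path)
restored = load_tree(path)
print(classify(restored, labels, ["young", "myope", "yes", "normal"]))
```

Because createTree() deletes entries from the labels list it is given, keep a separate copy of the full label list if you plan to classify instances afterwards.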
Reference book: 机器学习实战 (Machine Learning in Action)