Spark2.x 入门:决策树分类器

2024-09-06 15:52

本文主要是介绍Spark2.x 入门:决策树分类器,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、方法简介 ​

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

决策树学习通常包括3个步骤:特征选择、决策树的生成和决策树的剪枝。

示例代码

我们以iris数据集(iris)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。决策树可以用于分类和回归,接下来我们将在代码中分别进行介绍。

1. 导入需要的包:

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector,Vectors}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

2. 读取数据,简要分析:

导入spark.implicits._,使其支持把一个RDD隐式转换为一个DataFrame。我们用case class定义一个schema:Iris,Iris就是我们需要的数据的结构;然后读取文本文件,第一个map把每行的数据用“,”隔开,比如在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类;我们这里把特征存储在Vector中,创建一个Iris模式的RDD,然后转化成dataframe;然后把刚刚得到的数据注册成一个表iris,注册成这个表之后,我们就可以通过sql语句进行数据查询;选出我们需要的数据后,我们可以把结果打印出来查看一下数据。

scala> import spark.implicits._
import spark.implicits._scala> case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
defined class Irisscala> val data = spark.read.textFile("file:///root/data/iris.txt").map(_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble,p(3).toDouble),p(4).toString())).toDF()scala> data.createOrReplaceTempView("iris")scala> val df = spark.sql("select * from iris")
df: org.apache.spark.sql.DataFrame = [features: vector, label: string]scala> df.map(t => t(1)+":"+t(0)).collect().foreach(println)
Iris-setosa:[5.1,3.5,1.4,0.2]
Iris-setosa:[4.9,3.0,1.4,0.2]
Iris-setosa:[4.7,3.2,1.3,0.2]
Iris-setosa:[4.6,3.1,1.5,0.2]
Iris-setosa:[5.0,3.6,1.4,0.2]
Iris-setosa:[5.4,3.9,1.7,0.4]
Iris-setosa:[4.6,3.4,1.4,0.3]
......

3. 进一步处理特征和标签,以及数据分组:

//分别获取标签列和特征列,进行索引,并进行了重命名。
scala> val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df) 
labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_6c3c138d61bfscala> val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(df)
featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_08c01d7fd953//这里我们设置一个labelConverter,目的是把预测的类别重新转化成字符型的。
scala> val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_11ce3220e43a//接下来,我们把数据集随机分成训练集和测试集,其中训练集占70%。
scala> val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3))
trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]

4. 构建决策树分类模型:

//导入所需要的包
scala> import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassificationModelscala> import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassifierscala> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator//训练决策树模型,这里我们可以通过setter的方法来设置决策树的参数,也可以用ParamMap来设置(具体的可以查看spark mllib的官网)。具体的可以设置的参数可以通过explainParams()来获取。
scala> val dtClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtClassifier: org.apache.spark.ml.classification.DecisionTreeClassifier = dtc_7948c1724433//在pipeline中进行设置
scala> val pipelinedClassifier = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))
pipelinedClassifier: org.apache.spark.ml.Pipeline = pipeline_b5a49e693b35//训练决策树模型
scala> val modelClassifier = pipelinedClassifier.fit(trainingData)
modelClassifier: org.apache.spark.ml.PipelineModel = pipeline_b5a49e693b35//进行预测
scala> val predictionsClassifier = modelClassifier.transform(testData)
predictionsClassifier: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]//查看部分预测的结果
scala> predictionsClassifier.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

5. 评估决策树分类模型:

scala> val evaluatorClassifier = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
evaluatorClassifier: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_8059f30a8634scala> val accuracy = evaluatorClassifier.evaluate(predictionsClassifier)
accuracy: Double = 0.94scala> println("Test Error = " + (1.0 - accuracy))
Test Error = 0.06000000000000005scala> val treeModelClassifier = modelClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
treeModelClassifier: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodesscala> println("Learned classification tree model:\n" + treeModelClassifier.toDebugString)
Learned classification tree model:
DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodesIf (feature 2 <= 1.9)Predict: 0.0Else (feature 2 > 1.9)If (feature 3 <= 1.6)If (feature 2 <= 4.9)Predict: 1.0Else (feature 2 > 4.9)If (feature 0 <= 6.0)Predict: 1.0Else (feature 0 > 6.0)Predict: 2.0Else (feature 3 > 1.6)If (feature 2 <= 4.8)If (feature 1 <= 2.8)Predict: 2.0Else (feature 1 > 2.8)Predict: 1.0Else (feature 2 > 4.8)Predict: 2.0

从上述结果可以看到模型的预测准确率为 0.94 以及训练的决策树模型结构。

6. 构建决策树回归模型:

//导入所需要的包
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.regression.DecisionTreeRegressor//训练决策树模型
scala> val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtRegressor: org.apache.spark.ml.regression.DecisionTreeRegressor = dtr_e98e9ef10e22//在pipeline中进行设置
scala> val pipelineRegressor = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtRegressor, labelConverter))
pipelineRegressor: org.apache.spark.ml.Pipeline = pipeline_9f0fb530c801//训练决策树模型
scala> val modelRegressor = pipelineRegressor.fit(trainingData)
modelRegressor: org.apache.spark.ml.PipelineModel = pipeline_9f0fb530c801//进行预测
scala> val predictionsRegressor = modelRegressor.transform(testData)
predictionsRegressor: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 4 more fields]//查看部分预测结果
scala> predictionsRegressor.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

7. 评估决策树回归模型:

scala> val evaluatorRegressor = new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("rmse")
evaluatorRegressor: org.apache.spark.ml.evaluation.RegressionEvaluator = regEval_162861380a26scala> val rmse = evaluatorRegressor.evaluate(predictionsRegressor)
rmse: Double = 0.2449489742783178scala> println("Root Mean Squared Error (RMSE) on test data = " + rmse)
Root Mean Squared Error (RMSE) on test data = 0.2449489742783178scala> val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[DecisionTreeRegressionModel]
treeModelRegressor: org.apache.spark.ml.regression.DecisionTreeRegressionModel = DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodesscala> println("Learned regression tree model:\n" + treeModelRegressor.toDebugString)
Learned regression tree model:
DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodesIf (feature 2 <= 1.9)Predict: 0.0Else (feature 2 > 1.9)If (feature 3 <= 1.6)If (feature 2 <= 4.9)Predict: 1.0Else (feature 2 > 4.9)If (feature 0 <= 6.0)Predict: 1.0Else (feature 0 > 6.0)Predict: 2.0Else (feature 3 > 1.6)If (feature 2 <= 4.8)If (feature 1 <= 2.8)Predict: 2.0Else (feature 1 > 2.8)Predict: 1.0Else (feature 2 > 4.8)Predict: 2.0

这篇关于Spark2.x 入门:决策树分类器的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1142453

相关文章

从入门到精通MySQL联合查询

《从入门到精通MySQL联合查询》:本文主要介绍从入门到精通MySQL联合查询,本文通过实例代码给大家介绍的非常详细,需要的朋友可以参考下... 目录摘要1. 多表联合查询时mysql内部原理2. 内连接3. 外连接4. 自连接5. 子查询6. 合并查询7. 插入查询结果摘要前面我们学习了数据库设计时要满

从入门到精通C++11 <chrono> 库特性

《从入门到精通C++11<chrono>库特性》chrono库是C++11中一个非常强大和实用的库,它为时间处理提供了丰富的功能和类型安全的接口,通过本文的介绍,我们了解了chrono库的基本概念... 目录一、引言1.1 为什么需要<chrono>库1.2<chrono>库的基本概念二、时间段(Durat

解析C++11 static_assert及与Boost库的关联从入门到精通

《解析C++11static_assert及与Boost库的关联从入门到精通》static_assert是C++中强大的编译时验证工具,它能够在编译阶段拦截不符合预期的类型或值,增强代码的健壮性,通... 目录一、背景知识:传统断言方法的局限性1.1 assert宏1.2 #error指令1.3 第三方解决

从入门到精通MySQL 数据库索引(实战案例)

《从入门到精通MySQL数据库索引(实战案例)》索引是数据库的目录,提升查询速度,主要类型包括BTree、Hash、全文、空间索引,需根据场景选择,建议用于高频查询、关联字段、排序等,避免重复率高或... 目录一、索引是什么?能干嘛?核心作用:二、索引的 4 种主要类型(附通俗例子)1. BTree 索引(

Redis 配置文件使用建议redis.conf 从入门到实战

《Redis配置文件使用建议redis.conf从入门到实战》Redis配置方式包括配置文件、命令行参数、运行时CONFIG命令,支持动态修改参数及持久化,常用项涉及端口、绑定、内存策略等,版本8... 目录一、Redis.conf 是什么?二、命令行方式传参(适用于测试)三、运行时动态修改配置(不重启服务

MySQL DQL从入门到精通

《MySQLDQL从入门到精通》通过DQL,我们可以从数据库中检索出所需的数据,进行各种复杂的数据分析和处理,本文将深入探讨MySQLDQL的各个方面,帮助你全面掌握这一重要技能,感兴趣的朋友跟随小... 目录一、DQL 基础:SELECT 语句入门二、数据过滤:WHERE 子句的使用三、结果排序:ORDE

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

POI从入门到实战轻松完成EasyExcel使用及Excel导入导出功能

《POI从入门到实战轻松完成EasyExcel使用及Excel导入导出功能》ApachePOI是一个流行的Java库,用于处理MicrosoftOffice格式文件,提供丰富API来创建、读取和修改O... 目录前言:Apache POIEasyPoiEasyExcel一、EasyExcel1.1、核心特性

Python中模块graphviz使用入门

《Python中模块graphviz使用入门》graphviz是一个用于创建和操作图形的Python库,本文主要介绍了Python中模块graphviz使用入门,具有一定的参考价值,感兴趣的可以了解一... 目录1.安装2. 基本用法2.1 输出图像格式2.2 图像style设置2.3 属性2.4 子图和聚

Spring Boot + MyBatis Plus 高效开发实战从入门到进阶优化(推荐)

《SpringBoot+MyBatisPlus高效开发实战从入门到进阶优化(推荐)》本文将详细介绍SpringBoot+MyBatisPlus的完整开发流程,并深入剖析分页查询、批量操作、动... 目录Spring Boot + MyBATis Plus 高效开发实战:从入门到进阶优化1. MyBatis