Spark2.x 入门:决策树分类器

2024-09-06 15:52

本文主要是介绍Spark2.x 入门:决策树分类器,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、方法简介 ​

决策树(decision tree)是一种基本的分类与回归方法,这里主要介绍用于分类的决策树。决策树模式呈树形结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。学习时利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策树模型进行分类。

决策树学习通常包括3个步骤:特征选择、决策树的生成和决策树的剪枝。

示例代码

我们以iris数据集(iris)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。决策树可以用于分类和回归,接下来我们将在代码中分别进行介绍。

1. 导入需要的包:

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector,Vectors}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

2. 读取数据,简要分析:

导入spark.implicits._,使其支持把一个RDD隐式转换为一个DataFrame。我们用case class定义一个schema:Iris,Iris就是我们需要的数据的结构;然后读取文本文件,第一个map把每行的数据用“,”隔开,比如在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类;我们这里把特征存储在Vector中,创建一个Iris模式的RDD,然后转化成dataframe;然后把刚刚得到的数据注册成一个表iris,注册成这个表之后,我们就可以通过sql语句进行数据查询;选出我们需要的数据后,我们可以把结果打印出来查看一下数据。

scala> import spark.implicits._
import spark.implicits._scala> case class Iris(features: org.apache.spark.ml.linalg.Vector, label: String)
defined class Irisscala> val data = spark.read.textFile("file:///root/data/iris.txt").map(_.split(",")).map(p => Iris(Vectors.dense(p(0).toDouble,p(1).toDouble,p(2).toDouble,p(3).toDouble),p(4).toString())).toDF()scala> data.createOrReplaceTempView("iris")scala> val df = spark.sql("select * from iris")
df: org.apache.spark.sql.DataFrame = [features: vector, label: string]scala> df.map(t => t(1)+":"+t(0)).collect().foreach(println)
Iris-setosa:[5.1,3.5,1.4,0.2]
Iris-setosa:[4.9,3.0,1.4,0.2]
Iris-setosa:[4.7,3.2,1.3,0.2]
Iris-setosa:[4.6,3.1,1.5,0.2]
Iris-setosa:[5.0,3.6,1.4,0.2]
Iris-setosa:[5.4,3.9,1.7,0.4]
Iris-setosa:[4.6,3.4,1.4,0.3]
......

3. 进一步处理特征和标签,以及数据分组:

//分别获取标签列和特征列,进行索引,并进行了重命名。
scala> val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df) 
labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_6c3c138d61bfscala> val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(df)
featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_08c01d7fd953//这里我们设置一个labelConverter,目的是把预测的类别重新转化成字符型的。
scala> val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_11ce3220e43a//接下来,我们把数据集随机分成训练集和测试集,其中训练集占70%。
scala> val Array(trainingData, testData) = df.randomSplit(Array(0.7, 0.3))
trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]
testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [features: vector, label: string]

4. 构建决策树分类模型:

//导入所需要的包
scala> import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassificationModelscala> import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassifierscala> import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator//训练决策树模型,这里我们可以通过setter的方法来设置决策树的参数,也可以用ParamMap来设置(具体的可以查看spark mllib的官网)。具体的可以设置的参数可以通过explainParams()来获取。
scala> val dtClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtClassifier: org.apache.spark.ml.classification.DecisionTreeClassifier = dtc_7948c1724433//在pipeline中进行设置
scala> val pipelinedClassifier = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtClassifier, labelConverter))
pipelinedClassifier: org.apache.spark.ml.Pipeline = pipeline_b5a49e693b35//训练决策树模型
scala> val modelClassifier = pipelinedClassifier.fit(trainingData)
modelClassifier: org.apache.spark.ml.PipelineModel = pipeline_b5a49e693b35//进行预测
scala> val predictionsClassifier = modelClassifier.transform(testData)
predictionsClassifier: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 6 more fields]//查看部分预测的结果
scala> predictionsClassifier.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

5. 评估决策树分类模型:

scala> val evaluatorClassifier = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
evaluatorClassifier: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_8059f30a8634scala> val accuracy = evaluatorClassifier.evaluate(predictionsClassifier)
accuracy: Double = 0.94scala> println("Test Error = " + (1.0 - accuracy))
Test Error = 0.06000000000000005scala> val treeModelClassifier = modelClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
treeModelClassifier: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodesscala> println("Learned classification tree model:\n" + treeModelClassifier.toDebugString)
Learned classification tree model:
DecisionTreeClassificationModel (uid=dtc_7948c1724433) of depth 4 with 13 nodesIf (feature 2 <= 1.9)Predict: 0.0Else (feature 2 > 1.9)If (feature 3 <= 1.6)If (feature 2 <= 4.9)Predict: 1.0Else (feature 2 > 4.9)If (feature 0 <= 6.0)Predict: 1.0Else (feature 0 > 6.0)Predict: 2.0Else (feature 3 > 1.6)If (feature 2 <= 4.8)If (feature 1 <= 2.8)Predict: 2.0Else (feature 1 > 2.8)Predict: 1.0Else (feature 2 > 4.8)Predict: 2.0

从上述结果可以看到模型的预测准确率为 0.94 以及训练的决策树模型结构。

6. 构建决策树回归模型:

//导入所需要的包
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.regression.DecisionTreeRegressor//训练决策树模型
scala> val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtRegressor: org.apache.spark.ml.regression.DecisionTreeRegressor = dtr_e98e9ef10e22//在pipeline中进行设置
scala> val pipelineRegressor = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dtRegressor, labelConverter))
pipelineRegressor: org.apache.spark.ml.Pipeline = pipeline_9f0fb530c801//训练决策树模型
scala> val modelRegressor = pipelineRegressor.fit(trainingData)
modelRegressor: org.apache.spark.ml.PipelineModel = pipeline_9f0fb530c801//进行预测
scala> val predictionsRegressor = modelRegressor.transform(testData)
predictionsRegressor: org.apache.spark.sql.DataFrame = [features: vector, label: string ... 4 more fields]//查看部分预测结果
scala> predictionsRegressor.select("predictedLabel", "label", "features").show(20)
+---------------+---------------+-----------------+
| predictedLabel|          label|         features|
+---------------+---------------+-----------------+
|    Iris-setosa|    Iris-setosa|[4.4,2.9,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[4.6,3.4,1.4,0.3]|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|
|    Iris-setosa|    Iris-setosa|[4.7,3.2,1.6,0.2]|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|
|Iris-versicolor|Iris-versicolor|[5.0,2.3,3.3,1.0]|
|    Iris-setosa|    Iris-setosa|[5.0,3.2,1.2,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|
|    Iris-setosa|    Iris-setosa|[5.0,3.4,1.6,0.4]|
|    Iris-setosa|    Iris-setosa|[5.1,3.3,1.7,0.5]|
|    Iris-setosa|    Iris-setosa|[5.1,3.7,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.5,0.4]|
|    Iris-setosa|    Iris-setosa|[5.4,3.9,1.7,0.4]|
|Iris-versicolor|Iris-versicolor|[5.5,2.3,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.5,4.0,1.3]|
|Iris-versicolor|Iris-versicolor|[5.5,2.6,4.4,1.2]|
|    Iris-setosa|    Iris-setosa|[5.5,4.2,1.4,0.2]|
+---------------+---------------+-----------------+
only showing top 20 rows

7. 评估决策树回归模型:

scala> val evaluatorRegressor = new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("rmse")
evaluatorRegressor: org.apache.spark.ml.evaluation.RegressionEvaluator = regEval_162861380a26scala> val rmse = evaluatorRegressor.evaluate(predictionsRegressor)
rmse: Double = 0.2449489742783178scala> println("Root Mean Squared Error (RMSE) on test data = " + rmse)
Root Mean Squared Error (RMSE) on test data = 0.2449489742783178scala> val treeModelRegressor = modelRegressor.stages(2).asInstanceOf[DecisionTreeRegressionModel]
treeModelRegressor: org.apache.spark.ml.regression.DecisionTreeRegressionModel = DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodesscala> println("Learned regression tree model:\n" + treeModelRegressor.toDebugString)
Learned regression tree model:
DecisionTreeRegressionModel (uid=dtr_e98e9ef10e22) of depth 4 with 13 nodesIf (feature 2 <= 1.9)Predict: 0.0Else (feature 2 > 1.9)If (feature 3 <= 1.6)If (feature 2 <= 4.9)Predict: 1.0Else (feature 2 > 4.9)If (feature 0 <= 6.0)Predict: 1.0Else (feature 0 > 6.0)Predict: 2.0Else (feature 3 > 1.6)If (feature 2 <= 4.8)If (feature 1 <= 2.8)Predict: 2.0Else (feature 1 > 2.8)Predict: 1.0Else (feature 2 > 4.8)Predict: 2.0

这篇关于Spark2.x 入门:决策树分类器的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1142453

相关文章

Spring Security 从入门到进阶系列教程

Spring Security 入门系列 《保护 Web 应用的安全》 《Spring-Security-入门(一):登录与退出》 《Spring-Security-入门(二):基于数据库验证》 《Spring-Security-入门(三):密码加密》 《Spring-Security-入门(四):自定义-Filter》 《Spring-Security-入门(五):在 Sprin

数论入门整理(updating)

一、gcd lcm 基础中的基础,一般用来处理计算第一步什么的,分数化简之类。 LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; } <pre name="code" class="cpp">LL lcm(LL a, LL b){LL c = gcd(a, b);return a / c * b;} 例题:

Java 创建图形用户界面(GUI)入门指南(Swing库 JFrame 类)概述

概述 基本概念 Java Swing 的架构 Java Swing 是一个为 Java 设计的 GUI 工具包,是 JAVA 基础类的一部分,基于 Java AWT 构建,提供了一系列轻量级、可定制的图形用户界面(GUI)组件。 与 AWT 相比,Swing 提供了许多比 AWT 更好的屏幕显示元素,更加灵活和可定制,具有更好的跨平台性能。 组件和容器 Java Swing 提供了许多

【IPV6从入门到起飞】5-1 IPV6+Home Assistant(搭建基本环境)

【IPV6从入门到起飞】5-1 IPV6+Home Assistant #搭建基本环境 1 背景2 docker下载 hass3 创建容器4 浏览器访问 hass5 手机APP远程访问hass6 更多玩法 1 背景 既然电脑可以IPV6入站,手机流量可以访问IPV6网络的服务,为什么不在电脑搭建Home Assistant(hass),来控制你的设备呢?@智能家居 @万物互联

poj 2104 and hdu 2665 划分树模板入门题

题意: 给一个数组n(1e5)个数,给一个范围(fr, to, k),求这个范围中第k大的数。 解析: 划分树入门。 bing神的模板。 坑爹的地方是把-l 看成了-1........ 一直re。 代码: poj 2104: #include <iostream>#include <cstdio>#include <cstdlib>#include <al

MySQL-CRUD入门1

文章目录 认识配置文件client节点mysql节点mysqld节点 数据的添加(Create)添加一行数据添加多行数据两种添加数据的效率对比 数据的查询(Retrieve)全列查询指定列查询查询中带有表达式关于字面量关于as重命名 临时表引入distinct去重order by 排序关于NULL 认识配置文件 在我们的MySQL服务安装好了之后, 会有一个配置文件, 也就

音视频入门基础:WAV专题(10)——FFmpeg源码中计算WAV音频文件每个packet的pts、dts的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以打印WAV音频文件每个packet(也称为数据包或多媒体包)的信息,这些信息包含该packet的pts、dts: 打印出来的“pts”实际是AVPacket结构体中的成员变量pts,是以AVStream->time_base为单位的显

C语言指针入门 《C语言非常道》

C语言指针入门 《C语言非常道》 作为一个程序员,我接触 C 语言有十年了。有的朋友让我推荐 C 语言的参考书,我不敢乱推荐,尤其是国内作者写的书,往往七拼八凑,漏洞百出。 但是,李忠老师的《C语言非常道》值得一读。对了,李老师有个官网,网址是: 李忠老师官网 最棒的是,有配套的教学视频,可以试看。 试看点这里 接下来言归正传,讲解指针。以下内容很多都参考了李忠老师的《C语言非

MySQL入门到精通

一、创建数据库 CREATE DATABASE 数据库名称; 如果数据库存在,则会提示报错。 二、选择数据库 USE 数据库名称; 三、创建数据表 CREATE TABLE 数据表名称; 四、MySQL数据类型 MySQL支持多种类型,大致可以分为三类:数值、日期/时间和字符串类型 4.1 数值类型 数值类型 类型大小用途INT4Bytes整数值FLOAT4By

【QT】基础入门学习

文章目录 浅析Qt应用程序的主函数使用qDebug()函数常用快捷键Qt 编码风格信号槽连接模型实现方案 信号和槽的工作机制Qt对象树机制 浅析Qt应用程序的主函数 #include "mywindow.h"#include <QApplication>// 程序的入口int main(int argc, char *argv[]){// argc是命令行参数个数,argv是