Spark ML机器学习库评估指标示例

2024-05-28 02:58

本文主要是介绍Spark ML机器学习库评估指标示例,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

本文主要对 Spark ML库下模型评估指标的讲解,以下代码均以Jupyter Notebook进行讲解,Spark版本为2.4.5。模型评估指标位于包org.apache.spark.ml.evaluation下。

模型评估指标是指测试集的评估指标,而不是训练集的评估指标

1、回归评估指标

RegressionEvaluator

Evaluator for regression, which expects two input columns: prediction and label.

评估指标支持以下几种:

val metricName: Param[String]

  • "rmse" (default): root mean squared error
  • "mse": mean squared error
  • "r2": R2 metric
  • "mae": mean absolute error

Examples

# import dependencies
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.evaluation.RegressionEvaluator// Load training data
val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_linear_regression_data.txt")val lr = new LinearRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)// Fit the model
val lrModel = lr.fit(training)// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
println(s"Train MSE: ${trainingSummary.meanSquaredError}")
println(s"Train RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"Train MAE: ${trainingSummary.meanAbsoluteError}")
println(s"Train r2: ${trainingSummary.r2}")val predictions = lrModel.transform(test)// 计算精度
val evaluator = new RegressionEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("mse")
val accuracy = evaluator.evaluate(predictions)
print(s"Test MSE: ${accuracy}")

输出:

Train MSE: 101.57870147367461
Train RMSE: 10.078625971513905
Train MAE: 8.108865602095849
Train r2: 0.039467152584195975Test MSE: 114.28454406581636

2、分类评估指标

2.1 BinaryClassificationEvaluator

Evaluator for binary classification, which expects two input columns: rawPrediction and label. The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).

评估指标支持以下几种:

val metricName: Param[String]
param for metric name in evaluation (supports "areaUnderROC" (default), "areaUnderPR")

Examples

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator// Load training data
val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")val Array(train, test) = data.randomSplit(Array(0.8, 0.2))val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)// Fit the model
val lrModel = lr.fit(train)// Summarize the model over the training set and print out some metrics
val trainSummary = lrModel.summary
println(s"Train accuracy: ${trainSummary.accuracy}")
println(s"Train weightedPrecision: ${trainSummary.weightedPrecision}")
println(s"Train weightedRecall: ${trainSummary.weightedRecall}")
println(s"Train weightedFMeasure: ${trainSummary.weightedFMeasure}")val predictions = lrModel.transform(test)
predictions.show(5)// 模型评估
val evaluator = new BinaryClassificationEvaluator().setLabelCol("label").setRawPredictionCol("rawPrediction").setMetricName("areaUnderROC")
val auc = evaluator.evaluate(predictions)
print(s"Test AUC: ${auc}")val mulEvaluator = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("weightedPrecision")
val precision = evaluator.evaluate(predictions)
print(s"Test weightedPrecision: ${precision}")

输出结果:

Train accuracy: 0.9873417721518988
Train weightedPrecision: 0.9876110961486668
Train weightedRecall: 0.9873417721518987
Train weightedFMeasure: 0.9873124561568825+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[122,123,148...|[0.29746771419036...|[0.57382336211209...|       0.0|
|  0.0|(692,[125,126,127...|[0.42262389447949...|[0.60411095396791...|       0.0|
|  0.0|(692,[126,127,128...|[0.74220898710237...|[0.67747871191347...|       0.0|
|  0.0|(692,[126,127,128...|[0.77729372618481...|[0.68509655708828...|       0.0|
|  0.0|(692,[127,128,129...|[0.70928896866149...|[0.67024402884354...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+Test AUC: 1.0Test weightedPrecision: 1.0

2.2 MulticlassClassificationEvaluator

Evaluator for multiclass classification, which expects two input columns: prediction and label.

注:既然适用于多分类,当然适用于上面的二分类

评估指标支持如下几种:

val metricName: Param[String]
param for metric name in evaluation (supports "f1" (default), "weightedPrecision", "weightedRecall", "accuracy")

Examples

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}// Load the data stored in LIBSVM format as a DataFrame.
val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
// Automatically identify categorical features, and index them.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4) // features with > 4 distinct values are treated as continuous..fit(data)// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))// Train a DecisionTree model.
val dt = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")// Convert indexed labels back to original labels.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)// Make predictions.
val predictions = model.transform(testData)// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")

输出结果:

+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[95,96,97,12...|
|           0.0|  0.0|(692,[122,123,124...|
|           0.0|  0.0|(692,[122,123,148...|
|           0.0|  0.0|(692,[126,127,128...|
|           0.0|  0.0|(692,[126,127,128...|
+--------------+-----+--------------------+
only showing top 5 rowsTest Error = 0.040000000000000036

欢迎关注微信公众号

这篇关于Spark ML机器学习库评估指标示例的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1009351

相关文章

前端CSS Grid 布局示例详解

《前端CSSGrid布局示例详解》CSSGrid是一种二维布局系统,可以同时控制行和列,相比Flex(一维布局),更适合用在整体页面布局或复杂模块结构中,:本文主要介绍前端CSSGri... 目录css Grid 布局详解(通俗易懂版)一、概述二、基础概念三、创建 Grid 容器四、定义网格行和列五、设置行

Node.js 数据库 CRUD 项目示例详解(完美解决方案)

《Node.js数据库CRUD项目示例详解(完美解决方案)》:本文主要介绍Node.js数据库CRUD项目示例详解(完美解决方案),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考... 目录项目结构1. 初始化项目2. 配置数据库连接 (config/db.js)3. 创建模型 (models/

使用Python实现全能手机虚拟键盘的示例代码

《使用Python实现全能手机虚拟键盘的示例代码》在数字化办公时代,你是否遇到过这样的场景:会议室投影电脑突然键盘失灵、躺在沙发上想远程控制书房电脑、或者需要给长辈远程协助操作?今天我要分享的Pyth... 目录一、项目概述:不止于键盘的远程控制方案1.1 创新价值1.2 技术栈全景二、需求实现步骤一、需求

Spring LDAP目录服务的使用示例

《SpringLDAP目录服务的使用示例》本文主要介绍了SpringLDAP目录服务的使用示例... 目录引言一、Spring LDAP基础二、LdapTemplate详解三、LDAP对象映射四、基本LDAP操作4.1 查询操作4.2 添加操作4.3 修改操作4.4 删除操作五、认证与授权六、高级特性与最佳

CSS will-change 属性示例详解

《CSSwill-change属性示例详解》will-change是一个CSS属性,用于告诉浏览器某个元素在未来可能会发生哪些变化,本文给大家介绍CSSwill-change属性详解,感... will-change 是一个 css 属性,用于告诉浏览器某个元素在未来可能会发生哪些变化。这可以帮助浏览器优化

C++中std::distance使用方法示例

《C++中std::distance使用方法示例》std::distance是C++标准库中的一个函数,用于计算两个迭代器之间的距离,本文主要介绍了C++中std::distance使用方法示例,具... 目录语法使用方式解释示例输出:其他说明:总结std::distance&n编程bsp;是 C++ 标准

前端高级CSS用法示例详解

《前端高级CSS用法示例详解》在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交互和动态效果的关键技术之一,随着前端技术的不断发展,CSS的用法也日益丰富和高级,本文将深... 前端高级css用法在前端开发中,CSS(层叠样式表)不仅是用来控制网页的外观和布局,更是实现复杂交

C#使用SQLite进行大数据量高效处理的代码示例

《C#使用SQLite进行大数据量高效处理的代码示例》在软件开发中,高效处理大数据量是一个常见且具有挑战性的任务,SQLite因其零配置、嵌入式、跨平台的特性,成为许多开发者的首选数据库,本文将深入探... 目录前言准备工作数据实体核心技术批量插入:从乌龟到猎豹的蜕变分页查询:加载百万数据异步处理:拒绝界面

用js控制视频播放进度基本示例代码

《用js控制视频播放进度基本示例代码》写前端的时候,很多的时候是需要支持要网页视频播放的功能,下面这篇文章主要给大家介绍了关于用js控制视频播放进度的相关资料,文中通过代码介绍的非常详细,需要的朋友可... 目录前言html部分:JavaScript部分:注意:总结前言在javascript中控制视频播放

Java中StopWatch的使用示例详解

《Java中StopWatch的使用示例详解》stopWatch是org.springframework.util包下的一个工具类,使用它可直观的输出代码执行耗时,以及执行时间百分比,这篇文章主要介绍... 目录stopWatch 是org.springframework.util 包下的一个工具类,使用它