本文主要是介绍SVM实例,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
数据源:R自带的iris数据(R的e1071包没装下来,so用Scala写了;鸢尾花(iris)是数据挖掘常用到的一个数据集,包含150种鸢尾花的信息,每50种取自三个鸢尾花种之一(setosa,versicolour或virginica)。每个花的特征用下面的5种属性描述萼片长度(Sepal.Length)、萼片宽度(Sepal.Width)、花瓣长度(Petal.Length)、花瓣宽度(Petal.Width)、类(Species)。);
spark mlib代码:
def svmTest(sc:SparkContext,sqlContext:SQLContext): Unit ={import org.apache.spark.mllib.classification.SVMWithSGDimport org.apache.spark.mllib.regression.LabeledPointimport sqlContext.implicits._import breeze.linalg._// Load and parse the data file val data = sc.textFile("file:///D://cs3.txt")val trainData = data.map { line =>val parts = line.split(",")val y=parts(0).toDoubleval vd0=Vectors.dense(parts(1).toDouble,parts(2).toDouble,parts(3).toDouble,parts(4).toDouble)val v1=Vectors.dense(-2.0919917512589015,7.089178225784549,5.567376955110936,0.8621925858604499) // println(parts(1).toDouble*(-2.0919917512589015)+parts(2).toDouble*(7.089178225784549)+parts(3).toDouble*(5.567376955110936)+parts(4).toDouble*(0.8621925858604499) ) LabeledPoint(y,vd0)}val testData = sc.textFile("file:///D://cs4.txt").map { line =>val parts = line.split(",")val y=parts(0).toDoubleval vd0=Vectors.dense(parts(1).toDouble,parts(2).toDouble,parts(3).toDouble,parts(4).toDouble)LabeledPoint(y,vd0)} // Run training algorithm to build the model val numIterations = 20 val model = SVMWithSGD.train(trainData, numIterations)// Evaluate model on training examples and compute training error val trainLabelPreds = trainData.map { point =>val prediction = model.predict(point.features)(point.label, prediction)}println(model.toString())trainLabelPreds.toDF("label","prediction").showval trainErr = trainLabelPreds.filter(r => r._1 != r._2).count.toDouble / trainData.countprintln("Training Error = " + trainErr)// Compute raw scores on the test set. val testPredictLabels = testData.map { point => // println("feature="+point.features) val score = model.predict(point.features)(score, point.label)} // testPredictLabels.collect.foreach(println) val testErr = testPredictLabels.filter(r => r._1 != r._2).count.toDouble / testData.countprintln("test Error = " + testErr)}
运行结果(错误率0,准确率100%):
其实可以加个打印看下wx+b的值(然后发现wx<0的为label=0,大于0的为label=1,b似乎没有?松弛变量的常数C也没见到?)
println(parts(1).toDouble*(0.4672035760731836)+parts(2).toDouble*(1.6471825085309382)+parts(3).toDouble*(-2.317158274255798)+parts(4).toDouble*(-0.978114663957106) )
附:
traindata 训练集cs3.txt(格式:label标签[只能是0/1],特征1......特征4):
1,5.1,3.5,1.4,0.2
1,4.9,3,1.4,0.2
1,4.7,3.2,1.3,0.2
1,4.6,3.1,1.5,0.2
1,5,3.6,1.4,0.2
1,5.4,3.9,1.7,0.4
1,4.6,3.4,1.4,0.3
1,5,3.4,1.5,0.2
1,4.4,2.9,1.4,0.2
1,4.9,3.1,1.5,0.1
1,5.4,3.7,1.5,0.2
1,4.8,3.4,1.6,0.2
1,4.8,3,1.4,0.1
1,4.3,3,1.1,0.1
1,5.8,4,1.2,0.2
1,5.7,4.4,1.5,0.4
1,5.4,3.9,1.3,0.4
1,5.1,3.5,1.4,0.3
1,5.7,3.8,1.7,0.3
1,5.1,3.8,1.5,0.3
1,5.4,3.4,1.7,0.2
1,5.1,3.7,1.5,0.4
1,4.6,3.6,1,0.2
1,5.1,3.3,1.7,0.5
1,4.8,3.4,1.9,0.2
0,7,3.2,4.7,1.4
0,6.4,3.2,4.5,1.5
0,6.9,3.1,4.9,1.5
0,5.5,2.3,4,1.3
0,6.5,2.8,4.6,1.5
0,5.7,2.8,4.5,1.3
0,6.3,3.3,4.7,1.6
0,4.9,2.4,3.3,1
0,6.6,2.9,4.6,1.3
0,5.2,2.7,3.9,1.4
0,5,2,3.5,1
0,5.9,3,4.2,1.5
0,6,2.2,4,1
0,6.1,2.9,4.7,1.4
0,5.6,2.9,3.6,1.3
0,6.7,3.1,4.4,1.4
0,5.6,3,4.5,1.5
0,5.8,2.7,4.1,1
0,6.2,2.2,4.5,1.5
0,5.6,2.5,3.9,1.1
0,5.9,3.2,4.8,1.8
0,6.1,2.8,4,1.3
0,6.3,2.5,4.9,1.5
0,6.1,2.8,4.7,1.2
0,6.4,2.9,4.3,1.3
testdata 测试集(cs4.txt):
0,6.6,3,4.4,1.4
0,6.8,2.8,4.8,1.4
0,6.7,3,5,1.7
0,6,2.9,4.5,1.5
0,5.7,2.6,3.5,1
0,5.5,2.4,3.8,1.1
0,5.5,2.4,3.7,1
0,5.8,2.7,3.9,1.2
0,6,2.7,5.1,1.6
0,5.4,3,4.5,1.5
0,6,3.4,4.5,1.6
0,6.7,3.1,4.7,1.5
0,6.3,2.3,4.4,1.3
0,5.6,3,4.1,1.3
0,5.5,2.5,4,1.3
0,5.5,2.6,4.4,1.2
0,6.1,3,4.6,1.4
0,5.8,2.6,4,1.2
0,5,2.3,3.3,1
0,5.6,2.7,4.2,1.3
0,5.7,3,4.2,1.2
0,5.7,2.9,4.2,1.3
0,6.2,2.9,4.3,1.3
0,5.1,2.5,3,1.1
0,5.7,2.8,4.1,1.3
1,5,3,1.6,0.2
1,5,3.4,1.6,0.4
1,5.2,3.5,1.5,0.2
1,5.2,3.4,1.4,0.2
1,4.7,3.2,1.6,0.2
1,4.8,3.1,1.6,0.2
1,5.4,3.4,1.5,0.4
1,5.2,4.1,1.5,0.1
1,5.5,4.2,1.4,0.2
1,4.9,3.1,1.5,0.2
1,5,3.2,1.2,0.2
1,5.5,3.5,1.3,0.2
1,4.9,3.6,1.4,0.1
1,4.4,3,1.3,0.2
1,5.1,3.4,1.5,0.2
1,5,3.5,1.3,0.3
1,4.5,2.3,1.3,0.3
1,4.4,3.2,1.3,0.2
1,5,3.5,1.6,0.6
1,5.1,3.8,1.9,0.4
1,4.8,3,1.4,0.3
1,5.1,3.8,1.6,0.2
1,4.6,3.2,1.4,0.2
1,5.3,3.7,1.5,0.2
1,5,3.3,1.4,0.2
这篇关于SVM实例的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!