Skip to content

Commit 04f8cfd

Browse files
committed
update logistic regression
1 parent a4c7fb5 commit 04f8cfd

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

src/main/scala/com/hyzs/spark/ml/ConvertLibsvm_v2.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,13 @@ import com.hyzs.spark.utils.SparkUtils._
99
import com.hyzs.spark.utils.{BaseUtil, Params}
1010
import org.apache.spark.ml.Pipeline
1111
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorAssembler}
12-
import org.apache.spark.mllib.linalg.Vector
13-
import org.apache.spark.mllib.regression.LabeledPoint
14-
import org.apache.spark.mllib.util.MLUtils
1512
import org.apache.spark.rdd.RDD
1613
import org.apache.spark.sql._
1714
import org.apache.spark.sql.functions.{col, _}
1815
import org.apache.spark.sql.types._
19-
2016
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
2117
import com.hyzs.spark.sql.JDDataProcess
22-
import org.apache.spark.sql.expressions.UserDefinedFunction
18+
2319

2420
/**
2521
* Created by XIANGKUN on 2018/4/24.
@@ -171,6 +167,7 @@ object ConvertLibsvm_v2 {
171167
}
172168

173169

170+
174171
def convertLibsvm(args:Array[String]): Unit ={
175172
//TODO: switch libsvm with or without label table
176173
//val args = Array("train")

src/main/scala/com/hyzs/spark/ml/MatrixOpsInSpark.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package com.hyzs.spark.ml
33
import com.hyzs.spark.utils.SparkUtils._
44
import com.hyzs.spark.utils.BaseUtil._
55
import org.apache.spark.ml.clustering.KMeans
6+
import org.apache.spark.ml.feature.LabeledPoint
67
import org.apache.spark.ml.linalg.Vectors
78
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
89
import org.apache.spark.sql.types.{StructField, StructType}
910
import org.apache.spark.sql.{Dataset, Row}
11+
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
1012
/**
1113
* Created by xk on 2018/5/8.
1214
*/
@@ -47,5 +49,33 @@ object MatrixOpsInSpark {
4749
model.clusterCenters.foreach(println)
4850
}
4951

52+
/**
  * Fits a logistic regression model on the `kddcup_vector` table and prints
  * the learned coefficients and intercept, followed by the ROC curve and the
  * area under it when the training summary is a binary one.
  *
  * NOTE(review): no feature/label column names are set on the estimator, so
  * the table presumably exposes the estimator's default columns — confirm.
  */
def logisticRegressionTest(): Unit = {
  val training = spark.table("kddcup_vector")
  val lr = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(0.3)
    .setElasticNetParam(0.8)

  // Fit the model
  val lrModel: LogisticRegressionModel = lr.fit(training)
  // Print the coefficients and intercept for logistic regression
  println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

  // Match on the summary type instead of downcasting, so a non-binary
  // summary produces a readable message rather than a ClassCastException.
  lrModel.summary match {
    case binarySummary: BinaryLogisticRegressionSummary =>
      binarySummary.roc.show()
      println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
    case other =>
      println(s"ROC metrics unavailable: training summary ${other.getClass.getName} is not a binary summary")
  }
}
71+
72+
/**
  * Turns every row of `dataSet` into a [[LabeledPoint]].
  *
  * Each field of the row is passed through `toDoubleDynamic`; the first
  * resulting value is used as the label and the remaining values form a
  * dense feature vector.
  *
  * @param dataSet rows whose first field is the label and whose other fields are features
  * @return one labeled point per input row
  */
def convertDataSetToLabeledPoint(dataSet:Dataset[Row]): Dataset[LabeledPoint] = {
  dataSet.map { record =>
    // Convert the whole row to doubles once, then split label / features.
    val values: Array[Double] = record.toSeq.map(toDoubleDynamic).toArray
    LabeledPoint(values(0), Vectors.dense(values.drop(1)))
  }
}
5080

5181
}

0 commit comments

Comments
 (0)