Skip to content

Commit a1e97ac

Browse files
committed
update model prediction
1 parent 9cd91e5 commit a1e97ac

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

src/main/scala/com/hyzs/spark/ml/ModelPrediction.scala

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import org.apache.spark.mllib.util.MLUtils
1515
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
1616
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
1717
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
18+
import org.apache.spark.mllib.evaluation.RegressionMetrics
1819
/**
1920
* Created by xk on 2018/10/26.
2021
*/
@@ -82,7 +83,8 @@ object ModelPrediction {
8283
// goal should be "Classification" or "Regression"
8384
val boostingStrategy: BoostingStrategy = BoostingStrategy.defaultParams(goal)
8485
boostingStrategy.setNumIterations(100) // Note: Use more iterations in practice. eg. 10, 20
85-
boostingStrategy.setLearningRate(0.005)
86+
boostingStrategy.setLearningRate(0.001)
87+
boostingStrategy.setValidationTol(0.000001)
8688
//boostingStrategy.treeStrategy.setNumClasses(2)
8789
boostingStrategy.treeStrategy.setMaxDepth(5)
8890
boostingStrategy.treeStrategy.setImpurity(Variance)
@@ -93,8 +95,8 @@ object ModelPrediction {
9395
//boostingStrategy.treeStrategy.setImpurity(Gini)
9496

9597
// train with a validation set (the train-only call without validation is kept commented out)
96-
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
97-
//val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(trainingData, validData)
98+
//val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
99+
val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(trainingData, validData)
98100

99101
if(goal == "Classification"){
100102
val predAndLabels = testData.map { point =>
@@ -107,14 +109,23 @@ object ModelPrediction {
107109
println("model accuracy: " + confusion.accuracy)
108110
println("model f1: " + confusion.f1_score)
109111
} else if (goal == "Regression"){
110-
val labelsAndPredictions = testData.map { point =>
112+
val trainPreds = trainingData.map{ point =>
111113
val prediction = model.predict(point.features)
112-
(point.label, prediction)
114+
(prediction, point.label)
113115
}
114-
val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean()
116+
val predAndLabels = testData.map { point =>
117+
val prediction = model.predict(point.features)
118+
(prediction, point.label)
119+
}
120+
121+
/* val testMSE = predAndLabels.map{ case (p, l) => math.pow(p-l, 2) }.mean()
115122
val rmse = math.sqrt(testMSE)
116-
println(s"Test Mean Squared Error = $testMSE")
117-
println(s"Root Mean Squared Error = $rmse")
123+
println(s"Root Mean Squared Error = $rmse")*/
124+
125+
val trainMetric = new RegressionMetrics(trainPreds)
126+
val testMetric= new RegressionMetrics(predAndLabels)
127+
println(s"train RMSE = ${trainMetric.rootMeanSquaredError}")
128+
println(s"test RMSE = ${testMetric.rootMeanSquaredError}")
118129
println(s"Learned regression tree model:\n ${model.toDebugString}")
119130
val modelPath = "/user/hyzs/model/gbt_regression"
120131
println(s"save model to $modelPath")
@@ -212,8 +223,8 @@ object ModelPrediction {
212223
xgboost_ml()*/
213224
val rawData = MLUtils
214225
.loadLibSVMFile(sc, "/user/hyzs/convert/train_result/train_result.libsvm")
215-
.randomSplit(Array(0.7, 0.3))
216-
GBT(rawData(0), null, rawData(1), "Regression")
226+
.randomSplit(Array(0.4, 0.2, 0.4))
227+
GBT(rawData(0), rawData(1), rawData(2), "Regression")
217228
predictModel()
218229

219230
}

0 commit comments

Comments
 (0)