@@ -15,6 +15,7 @@ import org.apache.spark.mllib.util.MLUtils
1515import ml .dmlc .xgboost4j .scala .spark .XGBoostClassifier
1616import org .apache .spark .mllib .evaluation .BinaryClassificationMetrics
1717import org .apache .spark .mllib .tree .model .GradientBoostedTreesModel
18+ import org .apache .spark .mllib .evaluation .RegressionMetrics
1819/**
1920 * Created by xk on 2018/10/26.
2021 */
@@ -82,7 +83,8 @@ object ModelPrediction {
8283 // goal should be "Classification" or "Regression"
8384 val boostingStrategy : BoostingStrategy = BoostingStrategy .defaultParams(goal)
8485 boostingStrategy.setNumIterations(100 ) // Note: Use more iterations in practice. eg. 10, 20
85- boostingStrategy.setLearningRate(0.005 )
86+ boostingStrategy.setLearningRate(0.001 )
87+ boostingStrategy.setValidationTol(0.000001 )
8688 // boostingStrategy.treeStrategy.setNumClasses(2)
8789 boostingStrategy.treeStrategy.setMaxDepth(5 )
8890 boostingStrategy.treeStrategy.setImpurity(Variance )
@@ -93,8 +95,8 @@ object ModelPrediction {
9395 // boostingStrategy.treeStrategy.setImpurity(Gini)
9496
9597 // without validation
96- val model = GradientBoostedTrees .train(trainingData, boostingStrategy)
97- // val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(trainingData, validData)
98+ // val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
99+ val model = new GradientBoostedTrees (boostingStrategy).runWithValidation(trainingData, validData)
98100
99101 if (goal == " Classification" ){
100102 val predAndLabels = testData.map { point =>
@@ -107,14 +109,23 @@ object ModelPrediction {
107109 println(" model accuracy: " + confusion.accuracy)
108110 println(" model f1: " + confusion.f1_score)
109111 } else if (goal == " Regression" ){
110- val labelsAndPredictions = testData .map { point =>
112+ val trainPreds = trainingData .map{ point =>
111113 val prediction = model.predict(point.features)
112- (point.label, prediction )
114+ (prediction, point.label)
113115 }
114- val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2 ) }.mean()
116+ val predAndLabels = testData.map { point =>
117+ val prediction = model.predict(point.features)
118+ (prediction, point.label)
119+ }
120+
121+ /* val testMSE = predAndLabels.map{ case (p, l) => math.pow(p-l, 2) }.mean()
115122 val rmse = math.sqrt(testMSE)
116- println(s " Test Mean Squared Error = $testMSE" )
117- println(s " Root Mean Squared Error = $rmse" )
123+ println(s"Root Mean Squared Error = $rmse")*/
124+
125+ val trainMetric = new RegressionMetrics (trainPreds)
126+ val testMetric = new RegressionMetrics (predAndLabels)
127+ println(s " train RMSE = ${trainMetric.rootMeanSquaredError}" )
128+ println(s " test RMSE = ${testMetric.rootMeanSquaredError}" )
118129 println(s " Learned regression tree model: \n ${model.toDebugString}" )
119130 val modelPath = " /user/hyzs/model/gbt_regression"
120131 println(s " save model to $modelPath" )
@@ -212,8 +223,8 @@ object ModelPrediction {
212223 xgboost_ml()*/
213224 val rawData = MLUtils
214225 .loadLibSVMFile(sc, " /user/hyzs/convert/train_result/train_result.libsvm" )
215- .randomSplit(Array (0.7 , 0.3 ))
216- GBT (rawData(0 ), null , rawData(1 ), " Regression" )
226+ .randomSplit(Array (0.4 , 0.2 , 0.4 ))
227+ GBT (rawData(0 ), rawData( 1 ) , rawData(2 ), " Regression" )
217228 predictModel()
218229
219230 }
0 commit comments