11package  com .hyzs .spark .ml 
22
3+ import  com .hyzs .spark .ml .ConvertLibsvm .saveRdd 
34import  com .hyzs .spark .mllib .evaluation .ConfusionMatrix 
45import  com .hyzs .spark .utils .SparkUtils ._ 
56import  org .apache .spark .ml .classification .GBTClassifier 
@@ -9,10 +10,11 @@ import org.apache.spark.mllib.regression.LabeledPoint
910import  org .apache .spark .mllib .tree .{GradientBoostedTrees , RandomForest }
1011import  org .apache .spark .rdd .RDD 
1112import  org .apache .spark .mllib .tree .configuration .BoostingStrategy 
12- import  org .apache .spark .mllib .tree .impurity .{Entropy , Gini }
13+ import  org .apache .spark .mllib .tree .impurity .{Entropy , Gini ,  Variance }
1314import  org .apache .spark .mllib .util .MLUtils 
1415import  ml .dmlc .xgboost4j .scala .spark .XGBoostClassifier 
1516import  org .apache .spark .mllib .evaluation .BinaryClassificationMetrics 
17+ import  org .apache .spark .mllib .tree .model .GradientBoostedTreesModel 
1618/** 
1719 * Created by xk on 2018/10/26. 
1820 */  
@@ -71,38 +73,59 @@ object ModelPrediction {
7173
7274 def  GBT (trainingData : RDD [LabeledPoint ],
7375 validData : RDD [LabeledPoint ],
74-  testData: RDD [LabeledPoint ]):  Unit  =  {
76+  testData: RDD [LabeledPoint ],
77+  goal: String ):  Unit  =  {
78+ 
7579
7680 //  Train a GradientBoostedTrees model.
7781 //  The defaultParams for Classification use LogLoss by default.
78-  val  boostingStrategy :  BoostingStrategy  =  BoostingStrategy .defaultParams(" Classification" 
79-  boostingStrategy.setNumIterations(10 ) //  Note: Use more iterations in practice. eg. 10, 20
80-  boostingStrategy.treeStrategy.setNumClasses(2 )
81-  boostingStrategy.treeStrategy.setMaxDepth(6 )
82-  //  boostingStrategy.treeStrategy.setMaxBins(32)
82+  //  goal should be "Classification" or "Regression"
83+  val  boostingStrategy :  BoostingStrategy  =  BoostingStrategy .defaultParams(goal)
84+  boostingStrategy.setNumIterations(100 ) //  Note: Use more iterations in practice. eg. 10, 20
85+  boostingStrategy.setLearningRate(0.005 )
86+  // boostingStrategy.treeStrategy.setNumClasses(2)
87+  boostingStrategy.treeStrategy.setMaxDepth(5 )
88+  boostingStrategy.treeStrategy.setImpurity(Variance )
89+  boostingStrategy.treeStrategy.setMaxBins(32 )
8390 //  Empty categoricalFeaturesInfo indicates all features are continuous.
8491 //  boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
8592 // boostingStrategy.treeStrategy.setImpurity(Entropy)
8693 // boostingStrategy.treeStrategy.setImpurity(Gini)
8794
8895 //  without validation
89-  //  val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
90-  val  model  =  new  GradientBoostedTrees (boostingStrategy).runWithValidation(trainingData, validData)
91- 
92-  val  predAndLabels  =  testData.map { point => 
93-  val  prediction  =  model.predict(point.features)
94-  (prediction, point.label)
95-  }.collect()
96-  val  confusion  =  new  ConfusionMatrix (predAndLabels)
96+  val  model  =  GradientBoostedTrees .train(trainingData, boostingStrategy)
97+  // val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(trainingData, validData)
98+ 
99+  if (goal ==  " Classification" 
100+  val  predAndLabels  =  testData.map { point => 
101+  val  prediction  =  model.predict(point.features)
102+  (prediction, point.label)
103+  }.collect()
104+  val  confusion  =  new  ConfusionMatrix (predAndLabels)
105+  println(" model precision: " +  confusion.precision)
106+  println(" model recall: " +  confusion.recall)
107+  println(" model accuracy: " +  confusion.accuracy)
108+  println(" model f1: " +  confusion.f1_score)
109+  } else  if  (goal ==  " Regression" 
110+  val  labelsAndPredictions  =  testData.map { point => 
111+  val  prediction  =  model.predict(point.features)
112+  (point.label, prediction)
113+  }
114+  val  testMSE  =  labelsAndPredictions.map{ case  (v, p) =>  math.pow(v -  p, 2 ) }.mean()
115+  val  rmse  =  math.sqrt(testMSE)
116+  println(s " Test Mean Squared Error =  $testMSE" )
117+  println(s " Root Mean Squared Error =  $rmse" )
118+  println(s " Learned regression tree model: \n   ${model.toDebugString}" )
119+  val  modelPath  =  " /user/hyzs/model/gbt_regression" 
120+  println(s " save model to  $modelPath" )
121+  if (checkHDFileExist(modelPath)) dropHDFiles(modelPath)
122+  model.save(sc, modelPath)
123+  } else  throw  new  IllegalArgumentException (s " $goal is not supported by boosting. " )
97124
98-  println(" model precision: " +  confusion.precision)
99-  println(" model recall: " +  confusion.recall)
100-  println(" model accuracy: " +  confusion.accuracy)
101-  println(" model f1: " +  confusion.f1_score)
102125 }
103126
104127
105-  def  GBT_ml ():  Unit  =  {
128+  def  GBT_classifier ():  Unit  =  {
106129 val  data  =  spark.read.format(" libsvm" 
107130 val  Array (trainingData, testData) =  data.randomSplit(Array (0.6 , 0.4 ))
108131
@@ -140,7 +163,6 @@ object ModelPrediction {
140163
141164 }
142165
143- 
144166 def  xgboost_ml ():  Unit  =  {
145167 val  data  =  spark.read.format(" libsvm" 
146168 val  Array (trainingData, testData) =  data.randomSplit(Array (0.6 , 0.4 ))
@@ -168,15 +190,31 @@ object ModelPrediction {
168190
169191 }
170192
193+  def  predictModel ():  Unit  =  {
194+  val  modelPath  =  " /user/hyzs/model/gbt_regression" 
195+  val  dataPath  =  " /user/hyzs/convert/test_result/test_result.libsvm" 
196+  val  model  =  GradientBoostedTreesModel .load(sc, modelPath)
197+  val  testData  =  MLUtils .loadLibSVMFile(sc, dataPath)
198+  val  preds  =  testData.map{ record => 
199+  model.predict(record.features).toString
200+  }
201+  saveRdd(preds, " /user/hyzs/convert/test_result/preds.txt" 
202+ 
203+  }
171204
172205 def  main (args : Array [String ]):  Unit  =  {
173206 /*  val (trainingData, validData, testData) = prepareData()
174207 println("random forest =======") 
175208 val randomForestModel = randomForest(trainingData, validData, testData) 
176209 println("gbt =======") 
177210 val gbt = GBT(trainingData, validData, testData)*/  
178-  println(" xgboost =======" 
179-  xgboost_ml()
211+ /*  println("xgboost =======")
212+  xgboost_ml()*/  
213+  val  rawData  =  MLUtils 
214+  .loadLibSVMFile(sc, " /user/hyzs/convert/train_result/train_result.libsvm" 
215+  .randomSplit(Array (0.7 , 0.3 ))
216+  GBT (rawData(0 ), null , rawData(1 ), " Regression" 
217+  predictModel()
180218
181219 }
182220}
0 commit comments