@@ -141,15 +141,42 @@ object ModelPrediction {
141141 }
142142
143143
144+ def xgboost_ml (): Unit = {
145+ val data = spark.read.format(" libsvm" ).load(libsvmPath)
146+ val Array (trainingData, testData) = data.randomSplit(Array (0.6 , 0.4 ))
147+ val xgbParam = Map (" eta" -> 0.1f ,
148+ " max_depth" -> 6 ,
149+ " objective" -> " binary:logistic" ,
150+ " num_round" -> 10 )
144151
152+ val xgbClassifier = new XGBoostClassifier (xgbParam).
153+ setFeaturesCol(" features" ).
154+ setLabelCol(" label" )
145155
146- def main (args : Array [String ]): Unit = {
147- val (trainingData, validData, testData) = prepareData()
148- println(" random forest =======" )
149- val randomForestModel = randomForest(trainingData, validData, testData)
150- println(" gbt =======" )
151- val gbt = GBT (trainingData, validData, testData)
156+ val xgbClassificationModel = xgbClassifier.fit(trainingData)
157+ val predictions = xgbClassificationModel.transform(testData)
158+
159+ val predAndLabels = predictions.select(" prediction" , " label" )
160+ .map(row => (row.getDouble(0 ), row.getDouble(1 )))
161+ .rdd
162+ .collect()
163+ val confusion = new ConfusionMatrix (predAndLabels)
164+ println(" model precision: " + confusion.precision)
165+ println(" model recall: " + confusion.recall)
166+ println(" model accuracy: " + confusion.accuracy)
167+ println(" model f1: " + confusion.f1_score)
152168
169+ }
170+
171+
172+ def main (args : Array [String ]): Unit = {
173+ /* val (trainingData, validData, testData) = prepareData()
174+ println("random forest =======")
175+ val randomForestModel = randomForest(trainingData, validData, testData)
176+ println("gbt =======")
177+ val gbt = GBT(trainingData, validData, testData)*/
178+ println(" xgboost =======" )
179+ xgboost_ml()
153180
154181 }
155182}
0 commit comments