Skip to content

Commit 15140ca

Browse files
committed
update XGBoost prediction
1 parent 487455c commit 15140ca

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

pom.xml

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@
224224
<target>1.8</target>
225225
</configuration>
226226
</plugin>
227-
228227
<plugin>
229228
<groupId>org.apache.maven.plugins</groupId>
230229
<artifactId>maven-surefire-plugin</artifactId>
@@ -233,6 +232,30 @@
233232
<skipTests>true</skipTests>
234233
</configuration>
235234
</plugin>
235+
<!-- <plugin>
236+
<artifactId>maven-assembly-plugin</artifactId>
237+
<configuration>
238+
<appendAssemblyId>false</appendAssemblyId>
239+
<descriptorRefs>
240+
<descriptorRef>jar-with-dependencies</descriptorRef>
241+
</descriptorRefs>
242+
<archive>
243+
<manifest>
244+
&lt;!&ndash; Specify the class that contains the main method entry point here &ndash;&gt;
245+
<mainClass>com.hyzs.spark.ml.ModelEvaluation</mainClass>
246+
</manifest>
247+
</archive>
248+
</configuration>
249+
<executions>
250+
<execution>
251+
<id>make-assembly</id>
252+
<phase>package</phase>
253+
<goals>
254+
<goal>assembly</goal>
255+
</goals>
256+
</execution>
257+
</executions>
258+
</plugin>-->
236259
</plugins>
237260
<resources>
238261
<resource>

src/main/scala/com/hyzs/spark/ml/ModelPrediction.scala

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,42 @@ object ModelPrediction {
141141
}
142142

143143

144+
def xgboost_ml(): Unit = {
145+
val data = spark.read.format("libsvm").load(libsvmPath)
146+
val Array(trainingData, testData) = data.randomSplit(Array(0.6, 0.4))
147+
val xgbParam = Map("eta" -> 0.1f,
148+
"max_depth" -> 6,
149+
"objective" -> "binary:logistic",
150+
"num_round" -> 10)
144151

152+
val xgbClassifier = new XGBoostClassifier(xgbParam).
153+
setFeaturesCol("features").
154+
setLabelCol("label")
145155

146-
def main(args: Array[String]): Unit = {
147-
val (trainingData, validData, testData) = prepareData()
148-
println("random forest =======")
149-
val randomForestModel = randomForest(trainingData, validData, testData)
150-
println("gbt =======")
151-
val gbt = GBT(trainingData, validData, testData)
156+
val xgbClassificationModel = xgbClassifier.fit(trainingData)
157+
val predictions = xgbClassificationModel.transform(testData)
158+
159+
val predAndLabels = predictions.select("prediction", "label")
160+
.map(row => (row.getDouble(0), row.getDouble(1)))
161+
.rdd
162+
.collect()
163+
val confusion = new ConfusionMatrix(predAndLabels)
164+
println("model precision: " + confusion.precision)
165+
println("model recall: " + confusion.recall)
166+
println("model accuracy: " + confusion.accuracy)
167+
println("model f1: " + confusion.f1_score)
152168

169+
}
170+
171+
172+
def main(args: Array[String]): Unit = {
173+
/* val (trainingData, validData, testData) = prepareData()
174+
println("random forest =======")
175+
val randomForestModel = randomForest(trainingData, validData, testData)
176+
println("gbt =======")
177+
val gbt = GBT(trainingData, validData, testData)*/
178+
println("xgboost =======")
179+
xgboost_ml()
153180

154181
}
155182
}

0 commit comments

Comments
 (0)