Commit a0392ab
update model prediction
1 parent 7c67ed4
4 files changed: 125 additions & 71 deletions

src/main/scala/com/hyzs/spark/ml/ConvertLibsvm.scala
Lines changed: 19 additions & 11 deletions

@@ -26,6 +26,7 @@ object ConvertLibsvm {
   val originalKey = "user_id"
   //val key = "user_id_md5"
   val key = "card_id"
+  val label = "target"
   val maxLabelMapLength = 100
   val convertPath = "/user/hyzs/convert/"

@@ -224,25 +225,26 @@ object ConvertLibsvm {
     val stringSchema = dataSchema.filter(field => field.dataType == StringType)
     val stringCols = stringSchema.map(field => field.name)
     val indexerArray = stringCols.map(field => getIndexers(dataSet, field))
-    val objectArray = buildObjectArray(Array(key, "target"), dataSchema, indexerArray)
+    val objectArray = buildObjectArray(Array(key, label), dataSchema, indexerArray)
     val objRdd:RDD[String] = buildObjectJsonRdd(objectArray)
     val nameRdd = sc.makeRDD[String](dataSet.columns)

     println(s"start save obj: ${objRdd.first()} ...")
     saveRdd(objRdd, s"$taskPath$tableName.obj")
     println(s"start save name: ${nameRdd.first()} ...")
     saveRdd(nameRdd, s"$taskPath$tableName.name")
-    println(s"merge obj files to $taskPath$tableName.obj.")
     objectArray
   }

   // dataSet id, target, feature1, feature2, ...
-  def convertLibsvmFromDataSet(dataSet:Dataset[Row]): Unit = {
-    val tableName = "tmpResult"
+  def convertLibsvmFromDataSet(dataSet:Dataset[Row], tableName:String, objs:Array[ModelObject]=null): Unit = {
     val taskPath = s"$convertPath$tableName/"
-
-    val sourceData = processNull(dataSet)
-    val objectArray = trainObjectArray(sourceData, tableName)
+    val processedData = processZeroValue(dataSet)
+    val sourceData = processNull(processedData)
+    var objectArray:Array[ModelObject] = null
+    if(objs == null){
+      objectArray = trainObjectArray(sourceData, tableName)
+    } else objectArray = objs
     val libsvm_result = replaceOldCols(sourceData, objectArray)
     val indexRdd:RDD[String] = libsvm_result.select(key).rdd.map(row => row(0).toString)
     val libsvmRdd: RDD[String] = libsvm_result.rdd.map(row => {

@@ -251,13 +253,19 @@ object ConvertLibsvm {
     })
     saveRdd(indexRdd, s"$taskPath$tableName.index")
     saveRdd(libsvmRdd, s"$taskPath$tableName.libsvm")
-
   }

-  def main(args: Array[String]): Unit = {
-    val data = spark.table("merchant.tmpResult")
-    convertLibsvmFromDataSet(data)

+  def main(args: Array[String]): Unit = {
+    val trainName = "train_result"
+    val taskPath = s"$convertPath$trainName/"
+    val trainData = spark.table(s"merchant.$trainName")
+    convertLibsvmFromDataSet(trainData, trainName)
+
+    val testName = "test_result"
+    val testData = spark.table(s"merchant.$testName")
+    val objs = readObj(s"$taskPath$trainName.obj")
+    convertLibsvmFromDataSet(testData, testName, objs)
   }
 }
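Note: the new optional objs parameter is what lets main() convert the test table with the object array learned from the training table, so both splits share one feature mapping. As a design aside, the same reuse pattern can be written with an Option default instead of null; the sketch below is hypothetical and not part of this commit, reusing ModelObject, processZeroValue, processNull, and trainObjectArray from this file:

// Sketch only: an Option-based variant of the null-default parameter above.
def convertLibsvmFromDataSetOpt(dataSet: Dataset[Row],
                                tableName: String,
                                objs: Option[Array[ModelObject]] = None): Unit = {
  val sourceData = processNull(processZeroValue(dataSet))
  // Train a fresh object array only when no pre-built one is supplied.
  val objectArray: Array[ModelObject] =
    objs.getOrElse(trainObjectArray(sourceData, tableName))
  // ... continue with replaceOldCols(sourceData, objectArray) as above.
}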

src/main/scala/com/hyzs/spark/ml/ModelPrediction.scala
Lines changed: 61 additions & 23 deletions

@@ -1,5 +1,6 @@
 package com.hyzs.spark.ml

+import com.hyzs.spark.ml.ConvertLibsvm.saveRdd
 import com.hyzs.spark.mllib.evaluation.ConfusionMatrix
 import com.hyzs.spark.utils.SparkUtils._
 import org.apache.spark.ml.classification.GBTClassifier

@@ -9,10 +10,11 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.util.MLUtils
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
 /**
   * Created by xk on 2018/10/26.
   */

@@ -71,38 +73,59 @@ object ModelPrediction {

   def GBT(trainingData: RDD[LabeledPoint],
           validData: RDD[LabeledPoint],
-          testData:RDD[LabeledPoint]): Unit = {
+          testData:RDD[LabeledPoint],
+          goal:String): Unit = {
+
     // Train a GradientBoostedTrees model.
     // The defaultParams for Classification use LogLoss by default.
-    val boostingStrategy: BoostingStrategy = BoostingStrategy.defaultParams("Classification")
-    boostingStrategy.setNumIterations(10) // Note: Use more iterations in practice. eg. 10, 20
-    boostingStrategy.treeStrategy.setNumClasses(2)
-    boostingStrategy.treeStrategy.setMaxDepth(6)
-    // boostingStrategy.treeStrategy.setMaxBins(32)
+    // goal should be "Classification" or "Regression"
+    val boostingStrategy: BoostingStrategy = BoostingStrategy.defaultParams(goal)
+    boostingStrategy.setNumIterations(100) // Note: Use more iterations in practice. eg. 10, 20
+    boostingStrategy.setLearningRate(0.005)
+    //boostingStrategy.treeStrategy.setNumClasses(2)
+    boostingStrategy.treeStrategy.setMaxDepth(5)
+    boostingStrategy.treeStrategy.setImpurity(Variance)
+    boostingStrategy.treeStrategy.setMaxBins(32)
     // Empty categoricalFeaturesInfo indicates all features are continuous.
     // boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
     //boostingStrategy.treeStrategy.setImpurity(Entropy)
     //boostingStrategy.treeStrategy.setImpurity(Gini)

     // without validation
-    // val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
-    val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(trainingData, validData)
-
-    val predAndLabels = testData.map { point =>
-      val prediction = model.predict(point.features)
-      (prediction, point.label)
-    }.collect()
-    val confusion = new ConfusionMatrix(predAndLabels)
+    val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
+    //val model = new GradientBoostedTrees(boostingStrategy).runWithValidation(trainingData, validData)
+
+    if(goal == "Classification"){
+      val predAndLabels = testData.map { point =>
+        val prediction = model.predict(point.features)
+        (prediction, point.label)
+      }.collect()
+      val confusion = new ConfusionMatrix(predAndLabels)
+      println("model precision: " + confusion.precision)
+      println("model recall: " + confusion.recall)
+      println("model accuracy: " + confusion.accuracy)
+      println("model f1: " + confusion.f1_score)
+    } else if (goal == "Regression"){
+      val labelsAndPredictions = testData.map { point =>
+        val prediction = model.predict(point.features)
+        (point.label, prediction)
+      }
+      val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean()
+      val rmse = math.sqrt(testMSE)
+      println(s"Test Mean Squared Error = $testMSE")
+      println(s"Root Mean Squared Error = $rmse")
+      println(s"Learned regression tree model:\n ${model.toDebugString}")
+      val modelPath = "/user/hyzs/model/gbt_regression"
+      println(s"save model to $modelPath")
+      if(checkHDFileExist(modelPath)) dropHDFiles(modelPath)
+      model.save(sc, modelPath)
+    } else throw new IllegalArgumentException(s"$goal is not supported by boosting.")

-    println("model precision: " + confusion.precision)
-    println("model recall: " + confusion.recall)
-    println("model accuracy: " + confusion.accuracy)
-    println("model f1: " + confusion.f1_score)
   }


-  def GBT_ml(): Unit = {
+  def GBT_classifier(): Unit = {
     val data = spark.read.format("libsvm").load(libsvmPath)
     val Array(trainingData, testData) = data.randomSplit(Array(0.6, 0.4))

@@ -140,7 +163,6 @@ object ModelPrediction {

   }

-
   def xgboost_ml(): Unit = {
     val data = spark.read.format("libsvm").load(libsvmPath)
     val Array(trainingData, testData) = data.randomSplit(Array(0.6, 0.4))

@@ -168,15 +190,31 @@ object ModelPrediction {

   }

+  def predictModel(): Unit = {
+    val modelPath = "/user/hyzs/model/gbt_regression"
+    val dataPath = "/user/hyzs/convert/test_result/test_result.libsvm"
+    val model = GradientBoostedTreesModel.load(sc, modelPath)
+    val testData = MLUtils.loadLibSVMFile(sc, dataPath)
+    val preds = testData.map{ record =>
+      model.predict(record.features).toString
+    }
+    saveRdd(preds, "/user/hyzs/convert/test_result/preds.txt")
+
+  }

   def main(args: Array[String]): Unit = {
     /* val (trainingData, validData, testData) = prepareData()
     println("random forest =======")
     val randomForestModel = randomForest(trainingData, validData, testData)
     println("gbt =======")
     val gbt = GBT(trainingData, validData, testData)*/
-    println("xgboost =======")
-    xgboost_ml()
+    /* println("xgboost =======")
+    xgboost_ml()*/
+    val rawData = MLUtils
+      .loadLibSVMFile(sc, "/user/hyzs/convert/train_result/train_result.libsvm")
+      .randomSplit(Array(0.7, 0.3))
+    GBT(rawData(0), null, rawData(1), "Regression")
+    predictModel()

   }
 }
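Note: BoostingStrategy.defaultParams("Regression") configures squared-error (L2) loss in Spark MLlib, which is consistent with the Variance impurity set above and with the MSE/RMSE evaluation. A minimal sketch of the same train/save/load round trip, assuming an active SparkContext named sc and placeholder paths:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils

def trainSaveReload(sc: SparkContext, libsvmPath: String, modelPath: String): GradientBoostedTreesModel = {
  val data = MLUtils.loadLibSVMFile(sc, libsvmPath)           // RDD[LabeledPoint]
  val strategy = BoostingStrategy.defaultParams("Regression") // squared-error loss
  strategy.setNumIterations(100)
  strategy.setLearningRate(0.005)
  val model = GradientBoostedTrees.train(data, strategy)
  model.save(sc, modelPath)                     // persists trees + metadata to HDFS
  GradientBoostedTreesModel.load(sc, modelPath) // same load call predictModel() uses
}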

src/main/scala/com/hyzs/spark/sql/NewDataProcess.scala
Lines changed: 34 additions & 33 deletions

@@ -137,52 +137,53 @@ object NewDataProcess {
     )
   }

-  def merchantProcess(): Unit = {
+  def transProcess(): Unit = {
     val keyColumn = "card_id"
-    var trainTable = spark.table("merchant.train")
     val newTrans = spark.table("merchant.new_merchant_transactions")
+    var ids = newTrans.select(keyColumn).distinct()
     val processMode = Seq("city_id", "category_1", "installments", "category_3",
       "merchant_category_id", "category_2", "state_id", "subsector_id")
     for(colName <- processMode){
       val modeTmpTable = getColumnMode(newTrans, keyColumn, colName)
-      trainTable = trainTable.join(modeTmpTable, Seq(keyColumn), "left")
+      ids = ids.join(modeTmpTable, Seq(keyColumn), "left")
     }
-    saveTable(trainTable, "train_result", "merchant")
+    val aggTable = getColumnAgg(newTrans, keyColumn, "purchase_amount")
+    ids = ids.join(aggTable, Seq(keyColumn), "left")
+    saveTable(ids, "new_transactions_processed", "merchant")
+  }

+  def hisProcess(): Unit = {
+    val keyColumn = "card_id"
+    val trans = spark.table("merchant.historical_transactions")
+    var ids = trans.select(keyColumn).distinct()
+    val processMode = Seq("city_id", "category_1", "installments", "category_3",
+      "merchant_category_id", "category_2", "state_id", "subsector_id")
+    for(colName <- processMode){
+      val modeTmpTable = getColumnMode(trans, keyColumn, colName)
+      ids = ids.join(modeTmpTable, Seq(keyColumn), "left")
+    }
+    val aggTable = getColumnAgg(trans, keyColumn, "purchase_amount")
+    ids = ids.join(aggTable, Seq(keyColumn), "left")
+    saveTable(ids, "historical_transactions_processed", "merchant")
+  }

+  def merchantProcess(): Unit = {
+    val keyColumn = "card_id"
+    var trainTable = spark.table("merchant.train")
+      .select("card_id", "target", "feature_1", "feature_2", "feature_3")
+    var testTable = spark.table("merchant.test").withColumn("target", lit(0))
+      .select("card_id", "target", "feature_1", "feature_2", "feature_3")
+    val transTable = spark.table("merchant.new_transactions_processed")

+    trainTable = trainTable.join(transTable, Seq(keyColumn), "left")
+    saveTable(trainTable, "train_result", "merchant")
+    testTable = testTable.join(transTable, Seq(keyColumn), "left")
+    saveTable(testTable, "test_result", "merchant")
   }

   def main(args: Array[String]): Unit = {
-    //preprocessOrder()
-    val key = "id"
-    val table = spark.table("sample_n_enc")
-    val sample1 = spark.table("sample_w_enc")
-    val rowTable = columnToRow(table)
-    saveTable(rowTable, "sample_2")
-
-    val sample_features = sample1.join(table, Seq("id"), "left")
-    saveTable(sample_features, "sample_features")
-
-    val all = spark.table("hyzs.all_data")
-
-    //val diffCols = all.columns diff sample_features.columns
-
-    val order = spark.table("sample_order")
-    //spark.sql("drop table sample_fix")
-    //spark.sql("create table sample_fix(id string, brs_brs_p0001308 string, mkt_schd_p0001328 string, mkt_schd_p0001327 string)")
-
-    val fix = spark.table("sample_fix")
-    val features = sample_features.join(fix, Seq("id"), "left")
-      .selectExpr("id"+:(all.columns diff Seq("user_id", "user_id_md5")): _*)
-    saveTable(features, "features")
-    val sample_all = order.join(features, Seq("id"), "right")
-    saveTable(sample_all, "sample_all")
-    sample_all
-      .coalesce(1)
-      .write.format("com.databricks.spark.csv")
-      .option("header", "true")
-      .save("/hyzs/test_data/sample_all.csv")
+    transProcess()
+    merchantProcess()
   }

 }
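Note: getColumnMode and getColumnAgg are existing helpers in this repo whose bodies are not part of this diff. For readers, a hypothetical sketch of what a per-key column mode can look like in Spark SQL; the name and shape are assumptions, not the repo's implementation:

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Sketch: one row per key, carrying the most frequent value of colName.
def columnMode(df: Dataset[Row], keyColumn: String, colName: String): Dataset[Row] = {
  val counts = df.groupBy(keyColumn, colName).count()
  // Rank each value by frequency within its key and keep the top one.
  val byFreq = Window.partitionBy(keyColumn).orderBy(col("count").desc)
  counts.withColumn("rank", row_number().over(byFreq))
    .filter(col("rank") === 1)
    .select(col(keyColumn), col(colName).as(s"${colName}_mode"))
}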

src/main/scala/com/hyzs/spark/utils/SparkUtils.scala
Lines changed: 11 additions & 4 deletions

@@ -36,7 +36,7 @@ object SparkUtils {
   val hdConf: Configuration = sc.hadoopConfiguration
   val fs: FileSystem = FileSystem.get(hdConf)

-  val warehouseDir: String = conf.getOption("spark.sql.warehouse.dir").getOrElse("/user/hive/warehouse/")
+  val warehouseDir: String = conf.getOption("spark.sql.warehouse.dir").getOrElse("/user/hive/warehouse")
   val partitionNums: Int = conf.getOption("spark.sql.shuffle.partitions").getOrElse("200").toInt
   val invalidRowPath = "/hyzs/invalidRows/"
   val mapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

@@ -110,21 +110,28 @@ object SparkUtils {
   def processNull(df: Dataset[Row]): Dataset[Row] = {
     df.na.fill(0.0)
       .na.fill("0.0")
-      .na.replace("*", Map("" -> "0.0", "null" -> "0.0", -9999 -> 0.0))
+      .na.replace("*", Map("" -> "0.0", "null" -> "0.0"))
   }

+  def processZeroValue(df: Dataset[Row]): Dataset[Row] = {
+    df.na
+      .replace("*", Map(0 -> 1, 0.0 -> 1))
+  }
+
+
   def saveTable(df: Dataset[Row], tableName:String, dbName:String = "default"): Unit = {

     spark.sql(s"drop table if exists $dbName.$tableName")
     var path = ""
     if(dbName != "default"){
-      path = s"$warehouseDir$dbName.db/$tableName"
+      path = s"$warehouseDir/$dbName.db/$tableName"
     }
     else{
-      path = s"$warehouseDir$tableName"
+      path = s"$warehouseDir/$tableName"
    }
     if(checkHDFileExist(path))dropHDFiles(path)
     df.write
+      .option("path", path)
       .saveAsTable(s"$dbName.$tableName")
   }
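Note on the new processZeroValue: na.replace("*", ...) applies the replacement map to every column whose type matches, so zeros in any numeric column become ones, likely so real zero-valued features are not dropped by the sparse libsvm encoding. A tiny local demo under that assumption:

import org.apache.spark.sql.SparkSession

// Demo of na.replace("*", Map(0 -> 1, 0.0 -> 1)); "*" targets all columns.
val spark = SparkSession.builder().master("local[*]").appName("zeroDemo").getOrCreate()
import spark.implicits._

val df = Seq((0, 0.0, "a"), (2, 3.5, "b")).toDF("i", "d", "s")
// Numeric zeros in both Int and Double columns become 1; strings are untouched.
df.na.replace("*", Map(0 -> 1, 0.0 -> 1)).show()
// First row becomes (1, 1.0, "a"); the second row is unchanged.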
