Skip to content

Commit dba0de7

Browse files
committed
update data processing
1 parent 9c83911 commit dba0de7

File tree

3 files changed

+92
-32
lines changed

3 files changed

+92
-32
lines changed

src/main/resources/param_tuning.md

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,26 @@
77

88
## xgboost
99

10-
|object|booster|eval_metric|max_depth|eta|gamma|alpha|subsample|num_round|**train_result**|**valid_result**|**test_result**|
11-
|---|---|---|---|---|---|----|---|----|---|---|---|
12-
|reg:linear|gbtree|rmse|4|0.01|0.5|0.007|0.7|800|3.99 ~ 3.81|3.740|3.884|
13-
|reg:linear|gbtree|rmse|5|0.01|0.3|0.001|0.8|800|3.99 ~ 3.76|3.742|na|
14-
|reg:linear|gbtree|rmse|5|0.01|0.5|na|0.5|800|3.99 ~ 3.76|3.74|3.885|
15-
|reg:linear|gbtree|rmse|5|0.1|0.5|0.007|0.7|800|3.99 ~ 3.40|3.81|3.952|
16-
|reg:linear|gbtree|rmse|5|0.1|0.5|0.01|0.5|400|3.97 ~ 3.57|3.789|na|
17-
|reg:linear|gbtree|rmse|7|0.01|0.5|na|0.7|800|3.91 ~ 3.50|3.87|3.886|
18-
|reg:linear|gbtree|rmse|6|0.01|0.3|na|0.8|800|3.99 ~ 3.69|3.745|na|
10+
### objective=reg:linear, booster=gbtree, eval_metric=rmse
11+
12+
|max_depth|eta|gamma|alpha|subsample|min_child_weight|num_round|**train_result**|**valid_result**|**test_result**|
13+
|---|---|---|----|---|----|---|---|---|---|
14+
|4|0.01|0.5|0.007|0.7||800|3.99 ~ 3.81|3.740|3.884|
15+
|4|0.01|0.1| |0.7| 0.5|400|3.99 ~ 3.83|3.742|3.884|
16+
|4|0.005| 0.5|na|0.7|1|400|3.99 ~ 3.85|3.749|3.886|
17+
|5|0.01|0.5|na|0.5||800|3.99 ~ 3.76|3.74|3.885|
18+
|5|0.1|0.5|0.007|0.7||800|3.99 ~ 3.40|3.81|3.952|
19+
|5|0.1|0.5|0.01|0.5||400|3.97 ~ 3.57|3.789|na|
20+
|5|0.1|0.7|0.01|0.7|0.8|400|3.97 ~ 3.57|3.788|3.937|
21+
|5|0.01|0.7|0.01|0.7|0.8|400|3.99 ~ 3.80|3.741|3.884|
22+
|5|0.03|0.5|0.05|0.7| |400|3.99 ~ 3.732 |3.746 | |
23+
|5|0.03|0.5|0.03|0.7| |500| | | |
24+
|5|0.1|1|0.05|0.5|1|400|3.97 ~ 3.57|3.78|na|
25+
|5|0.5|1|0.05|0.5|1|400|3.89 ~ 3.30|4.19|na|
26+
|5|0.05|1|0.01|0.7|1|400|3.98 ~ 3.67 |3.754 |na|
27+
|5|0.03|1.5|0.05|0.7|1.5|400| 3.99 ~ 3.73|3.745|3.889|
28+
|6|0.01|0.3|na|0.8||800|3.99 ~ 3.69|3.745|na|
29+
|6|0.02| 1.5 | 0.1|0.7|1|400|3.99 ~ 3.69|3.745 |na|
30+
|6|0.01| 1.0 | 0.1|0.7|1|400|3.99 ~ 3.76 |3.742 |na|
31+
|7|0.01|0.5|na|0.7||800|3.91 ~ 3.50|3.87|3.886|
32+

src/main/scala/com/hyzs/spark/ml/ModelPrediction.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,17 @@ object ModelPrediction {
205205
} else if (goal == "Regression"){
206206
val xgbParam = Map(
207207
"max_depth" -> 5,
208-
"alpha" -> 0.01f,
209-
"subsample" -> 0.5,
208+
"alpha" -> 0.03f,
209+
"subsample" -> 0.7,
210210
//"colsample_bytree" -> 0.7,
211+
//"min_child_weight" -> 0.5,
211212
"objective" -> "reg:linear",
212213
//"top_k" -> "13",
213214
"booster" -> "gbtree",
214-
"eta" -> 0.1f,
215+
"eta" -> 0.03f,
215216
"gamma" -> 0.5,
216217
"eval_metric" -> "rmse",
217-
"num_round" -> 400)
218+
"num_round" -> 500)
218219
val xgbReg = new XGBoostRegressor(xgbParam)
219220
.setFeaturesCol("features").setLabelCol("label")
220221

src/main/scala/com/hyzs/spark/sql/NewDataProcess.scala

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ object NewDataProcess {
131131
}
132132

133133
def getColumnAgg(transData:Dataset[Row], keyColumn:String, aggColumn:String): Dataset[Row] = {
134-
transData.groupBy(keyColumn)
134+
transData.select(col(keyColumn), col(aggColumn).cast("double"))
135+
.groupBy(keyColumn)
135136
.agg(sum(aggColumn).as(s"sum_$aggColumn"),
136137
avg(aggColumn).as(s"avg_$aggColumn")
137138
)
@@ -140,32 +141,76 @@ object NewDataProcess {
140141
def transProcess(): Unit = {
141142
val keyColumn = "card_id"
142143
val newTrans = spark.table("merchant.new_merchant_transactions")
143-
var ids = newTrans.select(keyColumn).distinct()
144+
val ids = newTrans.select(keyColumn).distinct()
145+
ids.cache()
146+
val dropCols = Array("merchant_category_id", "subsector_id", "category_1", "city_id", "state_id", "category_2")
147+
val merchants = spark.table("merchant.merchants").dropDuplicates(Array("merchant_id"))
148+
.drop(dropCols: _*)
149+
val tmpTrans = newTrans.join(merchants, Seq("merchant_id"), "left")
150+
tmpTrans.cache()
151+
152+
var tmpModeRes = ids
144153
val processMode = Seq("city_id", "category_1", "installments", "category_3",
145-
"merchant_category_id", "category_2", "state_id", "subsector_id")
154+
"merchant_category_id", "category_2", "state_id", "subsector_id",
155+
"merchant_group_id", "category_4")
146156
for(colName <- processMode){
147-
val modeTmpTable = getColumnMode(newTrans, keyColumn, colName)
148-
ids = ids.join(modeTmpTable, Seq(keyColumn), "left")
157+
val modeTmpTable = getColumnMode(tmpTrans, keyColumn, colName)
158+
tmpModeRes = tmpModeRes.join(modeTmpTable, Seq(keyColumn), "left")
149159
}
150-
val aggTable = getColumnAgg(newTrans, keyColumn, "purchase_amount")
151-
ids = ids.join(aggTable, Seq(keyColumn), "left")
152-
saveTable(ids, "new_transactions_processed", "merchant")
160+
saveTable(tmpModeRes, "tmp_mode_res", "merchant")
161+
162+
var tmpAggRes = ids
163+
val aggCols = Array("purchase_amount", "numerical_1", "numerical_2", "avg_sales_lag3", "avg_purchases_lag3",
164+
"avg_sales_lag6", "avg_purchases_lag6", "avg_sales_lag12", "avg_purchases_lag12")
165+
for(aggCol <- aggCols){
166+
val aggTable = getColumnAgg(tmpTrans, keyColumn, aggCol)
167+
tmpAggRes = tmpAggRes.join(aggTable, Seq(keyColumn), "left")
168+
}
169+
saveTable(tmpAggRes, "tmp_agg_res", "merchant")
170+
171+
val modeRes = spark.table("merchant.tmp_mode_res")
172+
val aggRes = spark.table("merchant.tmp_agg_res")
173+
val tmpRes = modeRes.join(aggRes, Seq(keyColumn), "left")
174+
175+
saveTable(tmpRes, "new_transactions_processed", "merchant")
153176
}
154177

155178
def hisProcess(): Unit = {
156179
val keyColumn = "card_id"
157-
val trans = spark.table("merchant.historical_transactions")
158-
var ids = trans.select(keyColumn).distinct()
180+
val hisTrans = spark.table("merchant.historical_transactions")
181+
val ids = hisTrans.select(keyColumn).distinct()
182+
ids.cache()
183+
val dropCols = Array("merchant_category_id", "subsector_id", "category_1", "city_id", "state_id", "category_2")
184+
val merchants = spark.table("merchant.merchants").dropDuplicates(Array("merchant_id"))
185+
.drop(dropCols: _*)
186+
val tmpTrans = hisTrans.join(merchants, Seq("merchant_id"), "left")
187+
//tmpTrans.cache()
188+
189+
var tmpModeRes = ids
159190
val processMode = Seq("city_id", "category_1", "installments", "category_3",
160-
"merchant_category_id", "category_2", "state_id", "subsector_id")
191+
"merchant_category_id", "category_2", "state_id", "subsector_id",
192+
"merchant_group_id", "category_4")
161193
for(colName <- processMode){
162-
val modeTmpTable = getColumnMode(trans, keyColumn, colName)
163-
ids = ids.join(modeTmpTable, Seq(keyColumn), "left")
194+
val modeTmpTable = getColumnMode(tmpTrans, keyColumn, colName)
195+
tmpModeRes = tmpModeRes.join(modeTmpTable, Seq(keyColumn), "left")
164196
}
165-
val aggTable = getColumnAgg(trans, keyColumn, "purchase_amount")
166-
ids = ids.join(aggTable, Seq(keyColumn), "left")
167-
ids = addColumnsPrefix(ids, "historical", Array(keyColumn))
168-
saveTable(ids, "historical_transactions_processed", "merchant")
197+
saveTable(tmpModeRes, "tmp_mode_res", "merchant")
198+
199+
var tmpAggRes = ids
200+
val aggCols = Array("purchase_amount", "numerical_1", "numerical_2")
201+
//"avg_sales_lag3", "avg_purchases_lag3", "avg_sales_lag6", "avg_purchases_lag6", "avg_sales_lag12", "avg_purchases_lag12")
202+
for(aggCol <- aggCols){
203+
val aggTable = getColumnAgg(tmpTrans, keyColumn, aggCol)
204+
tmpAggRes = tmpAggRes.join(aggTable, Seq(keyColumn), "left")
205+
}
206+
saveTable(tmpAggRes, "tmp_agg_res", "merchant")
207+
208+
val modeRes = spark.table("merchant.tmp_mode_res")
209+
val aggRes = spark.table("merchant.tmp_agg_res")
210+
val tmpRes = addColumnsPrefix(modeRes.join(aggRes, Seq(keyColumn), "left"),
211+
"historical", Array(keyColumn))
212+
213+
saveTable(tmpRes, "historical_transactions_processed", "merchant")
169214
}
170215

171216
def merchantProcess(): Unit = {
@@ -186,8 +231,8 @@ object NewDataProcess {
186231
}
187232

188233
def main(args: Array[String]): Unit = {
189-
transProcess()
190-
hisProcess()
234+
//transProcess()
235+
//hisProcess()
191236
merchantProcess()
192237
}
193238

0 commit comments

Comments
 (0)