Skip to content

Commit 9cd91e5

Browse files
committed
update model prediction
1 parent a0392ab commit 9cd91e5

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/main/scala/com/hyzs/spark/sql/NewDataProcess.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ object NewDataProcess {
164164
}
165165
val aggTable = getColumnAgg(trans, keyColumn, "purchase_amount")
166166
ids = ids.join(aggTable, Seq(keyColumn), "left")
167+
ids = addColumnsPrefix(ids, "historical", Array(keyColumn))
167168
saveTable(ids, "historical_transactions_processed", "merchant")
168169
}
169170

@@ -174,15 +175,19 @@ object NewDataProcess {
174175
var testTable = spark.table("merchant.test").withColumn("target", lit(0))
175176
.select("card_id", "target", "feature_1", "feature_2", "feature_3")
176177
val transTable = spark.table("merchant.new_transactions_processed")
178+
val hisTable = spark.table("merchant.historical_transactions_processed")
177179

178180
trainTable = trainTable.join(transTable, Seq(keyColumn), "left")
181+
.join(hisTable, Seq(keyColumn), "left")
179182
saveTable(trainTable, "train_result", "merchant")
180183
testTable = testTable.join(transTable, Seq(keyColumn), "left")
184+
.join(hisTable, Seq(keyColumn), "left")
181185
saveTable(testTable, "test_result", "merchant")
182186
}
183187

184188
def main(args: Array[String]): Unit = {
  // Pipeline entry point: each stage reads/writes Hive tables in the
  // "merchant" database, so order matters — transaction aggregation
  // (new + historical) must run before the merchant train/test merge.
  val stages: Seq[() => Unit] =
    Seq(transProcess _, hisProcess _, merchantProcess _)
  stages.foreach(stage => stage())
}
188193

src/main/scala/com/hyzs/spark/utils/SparkUtils.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import org.apache.spark.sql.{Dataset, Row, SQLContext, SparkSession}
1515
import org.apache.spark.{SparkConf, SparkContext}
1616
import org.apache.spark.sql.types.{StringType, StructField, StructType}
1717
import org.apache.spark.util.SizeEstimator
18-
18+
import org.apache.spark.sql.functions._
1919
import scala.util.Try
2020

2121

@@ -206,14 +206,14 @@ object SparkUtils {
206206
SizeEstimator.estimate(rdd)
207207
}
208208

209-
// for test
210-
/* case class Person(name : String , age : Int)
211-
212-
def createDatasetTest(): Unit ={
213-
val personRDD = sc.makeRDD(Seq(Person("A",10),Person("B",20)))
214-
val personDF = spark.createDataFrame(personRDD)
215-
val ds:Dataset[Person] = personDF.as[Person]
216-
}*/
209+
/** Returns `dataSet` with every column renamed to `"<colPrefix>__<name>"`,
  * except columns listed in `ignoreCols`, which keep their original names.
  *
  * Used to namespace feature columns (e.g. "historical__purchase_amount")
  * before joining several aggregate tables that share column names.
  *
  * @param dataSet    input dataset whose columns are to be renamed
  * @param colPrefix  prefix prepended (with a double underscore) to each name
  * @param ignoreCols column names left untouched (typically the join key)
  * @return dataset with the same rows and renamed columns
  */
def addColumnsPrefix(dataSet:Dataset[Row],
                     colPrefix:String,
                     ignoreCols:Array[String]): Dataset[Row] = {
  val projected = dataSet.columns.map { name =>
    if (ignoreCols.contains(name)) col(name)
    else col(name).as(s"${colPrefix}__$name")
  }
  dataSet.select(projected: _*)
}
217217

218218
}
219219

0 commit comments

Comments (0)