Commit 358360d

Add config file

1 parent 03fb435 commit 358360d

2 files changed (+50, −14 lines)

model_generation.cfg (4 additions, 4 deletions)

@@ -1,8 +1,8 @@
 [io]
-post_file = POST_FILE
-tags_file = TAGS_FILE
-selected_tags_file = SELECTED_TAGS
+post_file = /Users/QiaoLiu1/Autotag/test/Questions.csv
+tags_file = /Users/QiaoLiu1/Autotag/test/Tags.csv
+selected_tags_file = /Users/QiaoLiu1/Autotag/test/top140Tags.csv
 
 [spark]
-master = MASTER
+master = local[2]
 
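
The commit replaces the placeholder values (POST_FILE, TAGS_FILE, SELECTED_TAGS, MASTER) with concrete local test paths and a local two-core master. The loading side is not part of this diff; here is a minimal sketch of how model_generation.py presumably consumes this file via the ConfigParser module it imports. The config-path argument is an assumption, though posts_file, tags_file, selected_tags_file, and master all appear in the script below:

    import sys
    import ConfigParser

    config = ConfigParser.ConfigParser()
    config.read(sys.argv[1])  # path to model_generation.cfg (assumed CLI argument)

    # [io] section: input CSVs for posts, tags, and the selected tag subset
    posts_file = config.get('io', 'post_file')
    tags_file = config.get('io', 'tags_file')
    selected_tags_file = config.get('io', 'selected_tags_file')

    # [spark] section: master may be local[N] or a mesos:// URL
    master = config.get('spark', 'master')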

model_generation.py (46 additions, 10 deletions)

@@ -5,8 +5,14 @@
 import ConfigParser
 import pandas as pd
 
-from pyspark import SparkContext, SparkConf, SparkSession
+from pyspark import SparkContext, SparkConf
+from pyspark.sql import SparkSession
 from pyspark.ml.feature import HashingTF, IDF, Tokenizer
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.linalg import SparseVector
+from pyspark.ml.classification import NaiveBayes
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+
 
 logging.basicConfig()
 logger=logging.getLogger('model_generation')
@@ -26,41 +32,71 @@
 # Try to initialize a spark cluster with master, master can be local or mesos URL, which is configurable in config file
 try:
     logger.debug("Initializing Spark cluster")
-    conf=SparkConf()
-    conf.setAppName('model_generation').setMaster(master)
+    conf=SparkConf().setAppName('model_generation').setMaster(master)
     sc=SparkContext(conf=conf)
-    spark=SparkSession.builder.config(conf).getOrCreate()
     logger.debug("Created Spark cluster successfully")
 except:
     logger.error("Fail to initialize spark cluster")
 
+try:
+    spark=SparkSession.builder.config(conf=conf).getOrCreate()
+    logger.debug("Initialized spark session successfully")
+except:
+    logger.error("Fail to start spark session")
+
 # Input the dataset
 try:
     logger.debug("Start to read the input dataset")
     posts_df=spark.read.csv(posts_file, header=True)
     tags_df=spark.read.csv(tags_file, header=True)
     selected_tags=pd.read_csv(selected_tags_file, header=None)
-    tag_set=sc.broadcast(set(selected_tags[0]))
+    local_tags_to_catId=dict(zip(selected_tags[0], list(selected_tags.index)))
+    local_catId_to_tags=dict(zip(list(selected_tags.index), selected_tags[0]))
+    tags_to_catId=sc.broadcast(local_tags_to_catId)
+    catId_to_tags=sc.broadcast(local_catId_to_tags)
+    tags_set=sc.broadcast(set(selected_tags[0]))
     logger.debug("Read in dataset successfully")
 except:
-    logger.debug("Can't input dataset")
+    logger.error("Can't input dataset")
 
 # Join posts_df and tags_df together and prepare training dataset
-selected_tags_df=tags_df.filter(tags_df['Tag'] in tag_set)
+selected_tags_df=tags_df.filter(tags_df.Tag.isin(tags_set.value))
 tags_questions_df=posts_df.join(selected_tags_df, posts_df.Id==selected_tags_df.Id)
 training_df=tags_questions_df.select(['Tag', 'Body'])
 
 # tokenize post texts and get term frequency and inverted document frequency
 tokenizer=Tokenizer(inputCol="Body", outputCol="Words")
 tokenized_words=tokenizer.transform(training_df)
-hashing_TF=HashingTF(inputCol="Words", outputCol="Features")
+hashing_TF=HashingTF(inputCol="Words", outputCol="Features", numFeatures=200)
 TFfeatures=hashing_TF.transform(tokenized_words)
 
 idf=IDF(inputCol="Features", outputCol="IDF_features")
 idfModel=idf.fit(TFfeatures)
-features=idfModel.transform(TFfeatures)
+TFIDFfeatures=idfModel.transform(TFfeatures)
+
+for feature in TFIDFfeatures.select("IDF_features", "Tag").take(3):
+    logger.info(feature)
+
+# Row(IDF_features=SparseVector(200, {7: 2.3773, 9: 2.1588, 32: 2.0067, 37: 1.7143, 49: 2.6727, 59: 2.9361, 114: 1.0654, 145: 2.9522, 167: 2.3751}), Tag=u'asp.net')
+# Transfer data to be in labeled point format
+labeled_points=TFIDFfeatures.rdd.map(lambda row: LabeledPoint(label=tags_to_catId.value[row.Tag], features=SparseVector(row.IDF_features.size, row.IDF_features.indices, row.IDF_features.values)))
+training, test=labeled_points.randomSplit([0.7, 0.3], seed=0)
+
+# Train Naive Bayes model
+print training.take(3)
+nb=NaiveBayes(smoothing=1.0, modelType="multinomial")
+nb_model=nb.fit(training)
+
+# Evaluate the model
+predictions=nb_model.transform(test)
+print predictions.take(10)
+# evaluator=MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction")
+# prediction_and_label = test.map(lambda point : (nb_model.predict(point.features), point.label))
+# accuracy = 1.0 * prediction_and_label.filter(lambda x: 1.0 if x[0] == x[1] else 0.0).count() / test.count()
+
+
 
-
 
 
 
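One caveat for this commit: labeled_points is an RDD of pyspark.mllib LabeledPoint, but the NaiveBayes imported here comes from pyspark.ml, whose fit() expects a DataFrame with label and features columns, so nb.fit(training) will not accept an RDD as written. A sketch of a DataFrame-only alternative, assuming Spark 2.x; StringIndexer stands in for the broadcast tag dictionaries and assigns indices by label frequency rather than CSV order:

    from pyspark.ml.feature import StringIndexer
    from pyspark.ml.classification import NaiveBayes

    # Map the Tag string column to a numeric "label" column
    indexed = StringIndexer(inputCol="Tag", outputCol="label") \
        .fit(TFIDFfeatures).transform(TFIDFfeatures)

    train_df, test_df = indexed.randomSplit([0.7, 0.3], seed=0)

    # Point the estimator at the TF-IDF column instead of the default "features"
    nb = NaiveBayes(featuresCol="IDF_features", labelCol="label",
                    smoothing=1.0, modelType="multinomial")
    nb_model = nb.fit(train_df)
    predictions = nb_model.transform(test_df)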

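The evaluation block is left commented out even though MulticlassClassificationEvaluator is already imported. With DataFrame-based predictions such as those from the sketch above, it could be finished roughly like this (metricName="accuracy" assumes Spark 2.0+):

    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    # Compare the "prediction" column produced by the model against the true labels
    evaluator = MulticlassClassificationEvaluator(labelCol="label",
                                                  predictionCol="prediction",
                                                  metricName="accuracy")
    accuracy = evaluator.evaluate(predictions)
    logger.info("Naive Bayes test accuracy: %s" % accuracy)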