|
5 | 5 | import ConfigParser |
6 | 6 | import pandas as pd |
7 | 7 |
|
8 | | -from pyspark import SparkContext, SparkConf, SparkSession |
| 8 | +from pyspark import SparkContext, SparkConf |
| 9 | +from pyspark.sql import SparkSession |
9 | 10 | from pyspark.ml.feature import HashingTF, IDF, Tokenizer |
| 11 | +from pyspark.sql import Row |
| 13 | +from pyspark.ml.classification import NaiveBayes |
| 14 | +from pyspark.ml.evaluation import MulticlassClassificationEvaluator |
| 15 | + |
10 | 16 |
|
11 | 17 | logging.basicConfig() |
12 | 18 | logger=logging.getLogger('model_generation') |
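| | +# basicConfig() leaves the root level at WARNING, so the debug traces below would be silent without this |
| | +logger.setLevel(logging.DEBUG) |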
|
26 | 32 | # Try to initialize a Spark cluster; master can be a local or Mesos URL, configurable in the config file |
27 | 33 | try: |
28 | 34 |     logger.debug("Initializing Spark cluster") |
29 | | -conf=SparkConf() |
30 | | -conf.setAppName('model_generation').setMaster(master) |
| 35 | +    conf=SparkConf().setAppName('model_generation').setMaster(master) |
31 | 36 |     sc=SparkContext(conf=conf) |
32 | | -spark=SparkSession.builder.config(conf).getOrCreate() |
33 | 37 |     logger.debug("Created Spark cluster successfully") |
34 | 38 | except Exception: |
35 | 39 |     logger.exception("Failed to initialize Spark cluster") |
36 | 40 |
|
| 41 | +try: |
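| | +    # getOrCreate() attaches the session to the SparkContext created above instead of starting a second one |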
| 42 | +    spark=SparkSession.builder.config(conf=conf).getOrCreate() |
| 43 | +    logger.debug("Initialized Spark session successfully") |
| 44 | +except Exception: |
| 45 | +    logger.exception("Failed to start Spark session") |
| 46 | + |
37 | 47 | # Read in the input dataset |
38 | 48 | try: |
39 | 49 |     logger.debug("Start to read the input dataset") |
40 | 50 |     posts_df=spark.read.csv(posts_file, header=True) |
41 | 51 |     tags_df=spark.read.csv(tags_file, header=True) |
42 | 52 |     selected_tags=pd.read_csv(selected_tags_file, header=None) |
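| | +    # Assumption: selected_tags_file lists one tag per line with no header, hence header=None and column 0 below |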
43 | | -tag_set=sc.broadcast(set(selected_tags[0])) |
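| | +    # Broadcast the small tag<->category-id lookup tables so every executor gets a single read-only copy |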
| 53 | +    local_tags_to_catId=dict(zip(selected_tags[0], list(selected_tags.index))) |
| 54 | +    local_catId_to_tags=dict(zip(list(selected_tags.index), selected_tags[0])) |
| 55 | +    tags_to_catId=sc.broadcast(local_tags_to_catId) |
| 56 | +    catId_to_tags=sc.broadcast(local_catId_to_tags) |
| 57 | +    tags_set=sc.broadcast(set(selected_tags[0])) |
44 | 58 |     logger.debug("Read in dataset successfully") |
45 | 59 | except Exception: |
46 | | -logger.debug("Can't input dataset") |
| 60 | +    logger.exception("Can't read the input dataset") |
47 | 61 |
|
48 | 62 | # Join posts_df and tags_df together and prepare training dataset |
49 | | -selected_tags_df=tags_df.filter(tags_df['Tag'] in tag_set) |
| 63 | +selected_tags_df=tags_df.filter(tags_df.Tag.isin(list(tags_set.value))) |
50 | 64 | tags_questions_df=posts_df.join(selected_tags_df, posts_df.Id==selected_tags_df.Id) |
51 | 65 | training_df=tags_questions_df.select(['Tag', 'Body']) |
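| | +# training_df holds one (Tag, Body) row per selected tag on a question; Body is the raw post text |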
52 | 66 |
|
53 | 67 | # tokenize post texts and get term frequency and inverted document frequency |
54 | 68 | tokenizer=Tokenizer(inputCol="Body", outputCol="Words") |
55 | 69 | tokenized_words=tokenizer.transform(training_df) |
56 | | -hashing_TF=HashingTF(inputCol="Words", outputCol="Features") |
| 70 | +hashing_TF=HashingTF(inputCol="Words", outputCol="Features", numFeatures=200) |
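| | +# Note: 200 hash buckets is small and will collide heavily; HashingTF defaults to 2**18 features if left unset |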
57 | 71 | TFfeatures=hashing_TF.transform(tokenized_words) |
58 | 72 |
|
59 | 73 | idf=IDF(inputCol="Features", outputCol="IDF_features") |
60 | 74 | idfModel=idf.fit(TFfeatures) |
61 | | -features=idfModel.transform(TFfeatures) |
| 75 | +TFIDFfeatures=idfModel.transform(TFfeatures) |
| 76 | + |
| 77 | +for feature in TFIDFfeatures.select("IDF_features", "Tag").take(3): |
| 78 | +    logger.info(feature) |
| 79 | + |
| 80 | +# Row(IDF_features=SparseVector(200, {7: 2.3773, 9: 2.1588, 32: 2.0067, 37: 1.7143, 49: 2.6727, 59: 2.9361, 114: 1.0654, 145: 2.9522, 167: 2.3751}), Tag=u'asp.net') |
| 81 | +# Transform each row into (label, features) columns, the names the ml NaiveBayes estimator expects; IDF already yields ml vectors, so no mllib conversion is needed |
| 82 | +labeled_points=TFIDFfeatures.rdd.map(lambda row: Row(label=float(tags_to_catId.value[row.Tag]), features=row.IDF_features)).toDF() |
| 83 | +training, test=labeled_points.randomSplit([0.7, 0.3], seed=0) |
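| | +# The fixed seed keeps the 70/30 train/test split reproducible across runs |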
| 84 | + |
| 85 | +# Train Naive Bayes model |
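| | +# Multinomial NB expects nonnegative feature values, which TF-IDF weights satisfy |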
| 86 | +logger.debug(training.take(3)) |
| 87 | +nb=NaiveBayes(smoothing=1.0, modelType="multinomial") |
| 88 | +nb_model=nb.fit(training) |
| 89 | + |
| 90 | +# Evaluate the model |
| 91 | +predictions=nb_model.transform(test) |
| 92 | +logger.debug(predictions.take(10)) |
| 93 | +evaluator=MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy") |
| 94 | +accuracy=evaluator.evaluate(predictions) |
| 95 | +logger.info("Test set accuracy: %s" % accuracy) |
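| | +# A possible next step (path is illustrative, not from the original): persist the model for the serving layer |
| | +# nb_model.save("naive_bayes_model") |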
| 96 | + |
62 | 99 |
|
63 | | - |
64 | 100 |
|
65 | 101 |
|
66 | 102 |
|
|