 import argparse
 import os
 import glob
+import numpy
 from time import time
-
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.feature_extraction.text import TfidfTransformer
+from sklearn.pipeline import Pipeline
 from sklearn import svm
 from sklearn.linear_model import Perceptron
 
 
 def main():
     parsed_args = parse_command_line_arguments()
+
     data_path = parsed_args.path
-
-    # We need the tfidf_transformer and counts_vect here to pass to the test email loader.
-    # There are other ways to do this, but this is the most convenient.
-    # It's important to supply this transformer because we need to maintain the vocabulary between test and training.
-    training_emails, training_labels, tfidf_transformer, counts_vect = load_training_dataset(data_path + "/training")
-    test_emails, test_labels = load_test_dataset(data_path + "/test", tfidf_transformer, counts_vect)
-
-    svm_classifier = train_svm_classifier(training_emails, training_labels)
-
+    training_emails, training_labels = load_dataset(data_path + "/training")
+    test_emails, test_labels = load_dataset(data_path + "/test")
+
+    svm_classifier = Pipeline([
+        ('count_vectorizer', CountVectorizer(stop_words="english")),
+        ('tfidf_transformer', TfidfTransformer()),
+        ('svm_classifier', svm.SVC(gamma="scale",
+                                   C=1,
+                                   verbose=False))
+    ])
+
+    # Train the svm
+    svm_classifier.fit(training_emails, training_labels)
+
+    num_correctly_classified = 0
+    num_wrongly_classified = 0
+    # Test it
     results = svm_classifier.predict(test_emails)
-    print(results)
+    for i in range(len(results)):
+        if (results[i] == test_labels[i]):
+            num_correctly_classified += 1
+        else:
+            num_wrongly_classified += 1
+
+    print("Accuracy of SVM classifier was " + "{0:.3%}".format(numpy.mean(results == test_labels)))
 
 
 def parse_command_line_arguments():
@@ -37,7 +53,7 @@ def parse_command_line_arguments():
     return parser.parse_args()
 
 
-def load_training_dataset(data_path):
+def load_dataset(data_path):
     print("Loading training dataset at " + data_path)
 
     email_data = []
@@ -52,46 +68,9 @@ def load_training_dataset(data_path):
         with open(file, 'r') as email_file:
             email_data.append(email_file.read())
 
-    count_vectorizer = CountVectorizer(stop_words='english')
-    tfidf_transformer = TfidfTransformer(use_idf=False)
-    email_feature_data = tfidf_transformer.fit_transform(count_vectorizer.fit_transform(email_data))
-
-    print(email_feature_data)
     print("Loaded " + str(len(email_labels)) + " training emails")
 
-    return email_feature_data, email_labels, tfidf_transformer, count_vectorizer
-
-
-def load_test_dataset(data_path, tfidf_transformer, counts_vect):
-    print("Loading test dataset at " + data_path)
-
-    email_data = []
-    email_labels = []
-
-    os.chdir(data_path)
-    for file in os.listdir():
-        if (file.startswith("sp")):
-            email_labels.append("spam")
-        else:
-            email_labels.append("ham")
-        with open(file, 'r') as email_file:
-            email_data.append(email_file.read())
-
-    email_feature_data = tfidf_transformer.transform(counts_vect.transform(email_data))
-
-    print(email_feature_data)
-    print("Loaded " + str(len(email_labels)) + " test emails")
-
-    return email_feature_data, email_labels
-
-
-def train_svm_classifier(emails, labels):
-    classifier = svm.SVC(
-        gamma='scale',
-        C=1,
-        verbose=True)
-    classifier.fit(emails, labels)
-    return classifier
+    return email_data, email_labels
 
 
 if __name__ == "__main__":
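
The comments removed from main() above explain why the old code passed tfidf_transformer and counts_vect around: the vocabulary learned from the training emails must be reused when transforming the test emails. The Pipeline introduced in this commit gives that guarantee automatically, since fit() learns the vocabulary and tf-idf weights once and predict() reuses them. A minimal sketch of that behaviour follows, using made-up toy strings; the example data is illustrative only and not part of the repository.

from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn import svm

# Toy stand-ins for the email bodies and labels returned by load_dataset().
training_emails = ["win free money now", "meeting agenda for tomorrow"]
training_labels = ["spam", "ham"]

classifier = Pipeline([
    ('count_vectorizer', CountVectorizer(stop_words="english")),
    ('tfidf_transformer', TfidfTransformer()),
    ('svm_classifier', svm.SVC(gamma="scale", C=1))
])

# fit() learns the vocabulary, the tf-idf weights and the SVM from the training data...
classifier.fit(training_emails, training_labels)

# ...and predict() reuses that fitted vocabulary on unseen text, so the training and
# test feature spaces stay aligned without passing transformers between functions.
print(classifier.predict(["claim your free money"]))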