|
11 | 11 | from sklearn import svm |
12 | 12 | from sklearn.linear_model import Perceptron |
13 | 13 |
|
| 14 | +training_emails = [] |
| 15 | +training_labels = [] |
| 16 | +test_emails = [] |
| 17 | +test_labels = [] |
14 | 18 |
|
15 | 19 | def main(): |
16 | | - parsed_args = parse_command_line_arguments() |
17 | | - |
| 20 | + global training_emails, training_labels, test_emails, test_labels |
| 21 | + parsed_args = parse_command_line_arguments() |
18 | 22 | data_path = parsed_args.path |
19 | 23 | training_emails, training_labels, = load_dataset(data_path + "/training") |
20 | 24 | test_emails, test_labels = load_dataset(data_path + "/test") |
21 | 25 |
|
22 | | - svm_classifier = Pipeline([ |
23 | | - ('count_vectorizer', CountVectorizer(stop_words="english")), |
24 | | - ('tfidf_transformer', TfidfTransformer()), |
25 | | - ('svm_classifier', svm.SVC(gamma="scale", |
26 | | - C=1, |
27 | | - verbose=False)) |
28 | | - ]) |
29 | | - |
30 | | - # Train the svm |
31 | | - svm_classifier.fit(training_emails, training_labels) |
32 | | - |
33 | | - num_correctly_classified = 0 |
34 | | - num_wrongly_classified = 0 |
35 | | - # Test it |
36 | | - results = svm_classifier.predict(test_emails) |
37 | | - for i in range(len(results)): |
38 | | - if (results[i] == test_labels[i]): |
39 | | - num_correctly_classified += 1 |
40 | | - else: |
41 | | - num_wrongly_classified += 1 |
42 | | - |
43 | | - accuracy = str("{0:.3%}").format(numpy.mean(results == test_labels)) |
44 | | - print("Accuracy of SVM classifier was " + accuracy) |
| 26 | + test_svm_classifier(C=1, use_idf=True, kernel='linear') |
| 27 | + test_svm_classifier(C=2, use_idf=True, kernel='linear') |
| 28 | + test_svm_classifier(C=10, use_idf=True, kernel='linear') |
45 | 29 |
|
46 | 30 |
|
47 | 31 | def parse_command_line_arguments(): |
48 | 32 | parser = argparse.ArgumentParser( |
49 | | - description="Classifies emails as spam or ham") |
| 33 | + description="Tests two types of email spam classifiers") |
50 | 34 | parser.add_argument( |
51 | 35 | 'path', |
52 | 36 | help="The path to the two folders, \'test\' and \'training\'" |
@@ -74,6 +58,36 @@ def load_dataset(data_path): |
74 | 58 | return email_data, email_labels |
75 | 59 |
|
76 | 60 |
|
| 61 | +def test_svm_classifier(C, use_idf, kernel): |
| 62 | + svm_classifier = Pipeline([ |
| 63 | + ('count_vectorizer', CountVectorizer(stop_words="english")), |
| 64 | + ('tfidf_transformer', TfidfTransformer(use_idf=use_idf)), |
| 65 | + ('svm_classifier', svm.SVC(gamma="scale", |
| 66 | + C=C, |
| 67 | + class_weight=None, |
| 68 | + kernel=kernel |
| 69 | + verbose=False)) |
| 70 | + ]) |
| 71 | + |
| 72 | + print(svm_classifier.get_params()['svm_classifier']) |
| 73 | + |
| 74 | + # Train the svm |
| 75 | + svm_classifier.fit(training_emails, training_labels) |
| 76 | + |
| 77 | + # Test it |
| 78 | + num_correctly_classified = 0 |
| 79 | + num_wrongly_classified = 0 |
| 80 | + results = svm_classifier.predict(test_emails) |
| 81 | + for i in range(len(results)): |
| 82 | + if (results[i] == test_labels[i]): |
| 83 | + num_correctly_classified += 1 |
| 84 | + else: |
| 85 | + num_wrongly_classified += 1 |
| 86 | + |
| 87 | + accuracy = str("{0:.3%}").format(numpy.mean(results == test_labels)) |
| 88 | + print("Accuracy of SVM classifier was " + accuracy) |
| 89 | + |
| 90 | + |
77 | 91 | if __name__ == "__main__": |
78 | 92 | main() |
79 | 93 |
|
|
0 commit comments