Skip to content

Commit 3d84543

Browse files
committed
Add code to experiment with param values
1 parent 3aa2712 commit 3d84543

File tree

1 file changed

+40
-26
lines changed

1 file changed

+40
-26
lines changed

part2/experiments.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,42 +11,26 @@
1111
from sklearn import svm
1212
from sklearn.linear_model import Perceptron
1313

14+
# Module-level holders for the dataset splits. main() rebinds them via
# `global` after loading, and test_svm_classifier() reads them.
training_emails = []
training_labels = []
test_emails = []
test_labels = []
1418

1519
def main():
    """Load the training and test email sets, then run the SVM experiments.

    The dataset root is taken from the ``path`` command-line argument; the
    ``/training`` and ``/test`` sub-folders under it are loaded with
    ``load_dataset``. The results are stored in the module-level
    ``training_emails``, ``training_labels``, ``test_emails`` and
    ``test_labels`` names, which ``test_svm_classifier`` reads.
    """
    # test_svm_classifier() reads these module-level names, so rebind them
    # here instead of creating locals.
    global training_emails, training_labels, test_emails, test_labels

    parsed_args = parse_command_line_arguments()
    data_path = parsed_args.path
    training_emails, training_labels = load_dataset(data_path + "/training")
    test_emails, test_labels = load_dataset(data_path + "/test")

    # Sweep C while holding the other settings fixed.
    for c_value in (1, 2, 10):
        test_svm_classifier(C=c_value, use_idf=True, kernel='linear')
4529

4630

4731
def parse_command_line_arguments():
4832
parser = argparse.ArgumentParser(
49-
description="Classifies emails as spam or ham")
33+
description="Tests two types of email spam classifiers")
5034
parser.add_argument(
5135
'path',
5236
help="The path to the two folders, \'test\' and \'training\'"
@@ -74,6 +58,36 @@ def load_dataset(data_path):
7458
return email_data, email_labels
7559

7660

61+
def test_svm_classifier(C, use_idf, kernel):
    """Train an SVM text pipeline with the given settings and print its accuracy.

    The pipeline is a ``CountVectorizer`` (English stop words removed), a
    ``TfidfTransformer`` and an ``svm.SVC``. It is fitted on the module-level
    ``training_emails`` / ``training_labels`` and evaluated on the
    module-level ``test_emails`` / ``test_labels``, so those must be
    populated before this function is called.

    Args:
        C: Value forwarded to ``svm.SVC`` as its ``C`` argument.
        use_idf: Forwarded to ``TfidfTransformer(use_idf=...)``.
        kernel: Kernel name forwarded to ``svm.SVC``.

    Returns:
        None. The classifier settings and the accuracy are printed.
    """
    svm_classifier = Pipeline([
        ('count_vectorizer', CountVectorizer(stop_words="english")),
        ('tfidf_transformer', TfidfTransformer(use_idf=use_idf)),
        ('svm_classifier', svm.SVC(gamma="scale",
                                   C=C,
                                   class_weight=None,
                                   kernel=kernel,
                                   verbose=False)),
    ])

    # Print the classifier's configuration so each run's output records
    # which settings produced the accuracy that follows.
    print(svm_classifier.get_params()['svm_classifier'])

    # Train the svm
    svm_classifier.fit(training_emails, training_labels)

    # Test it: accuracy is the fraction of predictions equal to the labels.
    results = svm_classifier.predict(test_emails)
    # asarray makes the comparison element-wise whether the labels are held
    # in a list or an array.
    matches = results == numpy.asarray(test_labels)
    accuracy = "{0:.3%}".format(numpy.mean(matches))
    print("Accuracy of SVM classifier was " + accuracy)
89+
90+
7791
# Run the experiments only when this file is executed as a script.
if __name__ == "__main__":
    main()
7993

0 commit comments

Comments
 (0)