Skip to content

Commit a26828e

Browse files
committed
Add accuracy printout and use pipelined flow
1 parent a3be71b commit a26828e

File tree

1 file changed

+29
-50
lines changed

1 file changed

+29
-50
lines changed

part2/experiments.py

Lines changed: 29 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,44 @@
33
import argparse
44
import os
55
import glob
6+
import numpy
67
from time import time
7-
88
from sklearn.feature_extraction.text import CountVectorizer
99
from sklearn.feature_extraction.text import TfidfTransformer
10+
from sklearn.pipeline import Pipeline
1011
from sklearn import svm
1112
from sklearn.linear_model import Perceptron
1213

1314

1415
def main():
1516
parsed_args = parse_command_line_arguments()
17+
1618
data_path = parsed_args.path
17-
18-
# We need the tfidf_transformer and counts_vect here to pass the the test email loader.
19-
# There are other ways to do this, but this is the most convenient.
20-
# It's important to supply this tranformer because we need to maintain the vocabulary between test and training
21-
training_emails, training_labels, tfidf_transformer, counts_vect = load_training_dataset(data_path + "/training")
22-
test_emails, test_labels = load_test_dataset(data_path + "/test", tfidf_transformer, counts_vect)
23-
24-
svm_classifier = train_svm_classifier(training_emails, training_labels)
25-
19+
training_emails, training_labels, = load_dataset(data_path + "/training")
20+
test_emails, test_labels = load_dataset(data_path + "/test")
21+
22+
svm_classifier = Pipeline([
23+
('count_vectorizer', CountVectorizer(stop_words="english")),
24+
('tfidf_transformer', TfidfTransformer()),
25+
('svm_classifier', svm.SVC(gamma="scale",
26+
C=1,
27+
verbose=False))
28+
])
29+
30+
# Train the svm
31+
svm_classifier.fit(training_emails, training_labels)
32+
33+
num_correctly_classified = 0
34+
num_wrongly_classified = 0
35+
# Test it
2636
results = svm_classifier.predict(test_emails)
27-
print(results)
37+
for i in range(len(results)):
38+
if (results[i] == test_labels[i]):
39+
num_correctly_classified += 1
40+
else:
41+
num_wrongly_classified += 1
42+
43+
print("Accuracy of SVM classifier was " + str("{0:.3%}").format(numpy.mean(results == test_labels)))
2844

2945

3046
def parse_command_line_arguments():
@@ -37,7 +53,7 @@ def parse_command_line_arguments():
3753
return parser.parse_args()
3854

3955

40-
def load_training_dataset(data_path):
56+
def load_dataset(data_path):
4157
print("Loading training dataset at " + data_path)
4258

4359
email_data = []
@@ -52,46 +68,9 @@ def load_training_dataset(data_path):
5268
with open(file, 'r') as email_file:
5369
email_data.append(email_file.read())
5470

55-
count_vectorizer = CountVectorizer(stop_words='english')
56-
tfidf_transformer = TfidfTransformer(use_idf=False)
57-
email_feature_data = tfidf_transformer.fit_transform(count_vectorizer.fit_transform(email_data))
58-
59-
print(email_feature_data)
6071
print("Loaded " + str(len(email_labels)) + " training emails")
6172

62-
return email_feature_data, email_labels, tfidf_transformer, count_vectorizer
63-
64-
65-
def load_test_dataset(data_path, tfidf_transformer, counts_vect):
66-
print("Loading test dataset at " + data_path)
67-
68-
email_data = []
69-
email_labels = []
70-
71-
os.chdir(data_path)
72-
for file in os.listdir():
73-
if (file.startswith("sp")):
74-
email_labels.append("spam")
75-
else:
76-
email_labels.append("ham")
77-
with open(file, 'r') as email_file:
78-
email_data.append(email_file.read())
79-
80-
email_feature_data = tfidf_transformer.transform(counts_vect.transform(email_data))
81-
82-
print(email_feature_data)
83-
print("Loaded " + str(len(email_labels)) + " test emails")
84-
85-
return email_feature_data, email_labels
86-
87-
88-
def train_svm_classifier(emails, labels):
89-
classifier = svm.SVC(
90-
gamma='scale',
91-
C=1,
92-
verbose=True)
93-
classifier.fit(emails, labels)
94-
return classifier
73+
return email_data, email_labels
9574

9675

9776
if __name__ == "__main__":

0 commit comments

Comments
 (0)