22import numpy as np
33import pandas as pd
44import os
5+ import sys
56
67from sklearn .model_selection import cross_val_score
78from sklearn .model_selection import GridSearchCV , KFold , StratifiedKFold
@@ -16,14 +17,14 @@ def evaluate_embedding(embeddings, labels):
1617
1718 labels = preprocessing .LabelEncoder ().fit_transform (labels )
1819 x , y = np .array (embeddings ), np .array (labels )
20+ print (x .shape , y .shape )
1921
2022 kf = StratifiedKFold (n_splits = 10 , shuffle = True , random_state = None )
2123 accuracies = []
2224 for train_index , test_index in kf .split (x , y ):
2325
2426 x_train , x_test = x [train_index ], x [test_index ]
2527 y_train , y_test = y [train_index ], y [test_index ]
26- # x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1)
2728 search = True
2829 if search :
2930 params = {'C' :[0.001 , 0.01 ,0.1 ,1 ,10 ,100 ,1000 ]}
@@ -40,11 +41,10 @@ def evaluate_embedding(embeddings, labels):
4041
4142 x_train , x_test = x [train_index ], x [test_index ]
4243 y_train , y_test = y [train_index ], y [test_index ]
43- # x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1)
4444 search = True
4545 if search :
46- classifier = GridSearchCV (LinearSVC (), params , cv = 5 , scoring = 'accuracy' , verbose = 0 )
4746 params = {'C' :[0.001 , 0.01 ,0.1 ,1 ,10 ,100 ,1000 ]}
47+ classifier = GridSearchCV (LinearSVC (), params , cv = 5 , scoring = 'accuracy' , verbose = 0 )
4848 else :
4949 classifier = LinearSVC (C = 10 )
5050 classifier .fit (x_train , y_train )
@@ -69,16 +69,12 @@ def evaluate_embedding(embeddings, labels):
if __name__ == '__main__':
    # Script entry point: load the embeddings and the labels of one dataset
    # (named by the first command-line argument) and evaluate them.
    #
    # Inputs, both plain text with one record per line:
    #   data/results/<dataset>_output.txt  whitespace-separated floats
    #   ../data/<dataset>_label.txt        one integer label
    if len(sys.argv) < 2:
        # Fail with a usage message instead of a bare IndexError.
        sys.exit('usage: {} DATASET'.format(sys.argv[0]))
    dataset = sys.argv[1]

    # NOTE(review): the scraped source shows this file name as
    # "{}_output .txt" with a space before the extension; that is taken to
    # be an extraction artifact -- confirm the real file name.
    with open('data/results/{}_output.txt'.format(dataset), 'r') as f:
        # Blank lines are skipped so they do not become empty rows.
        emb = [[float(tok) for tok in line.split()]
               for line in f if line.strip()]

    # NOTE(review): this path is relative to the parent directory while the
    # embeddings path above is relative to the working directory -- confirm
    # that both are meant to resolve from the same place.
    with open('../data/{}_label.txt'.format(dataset), 'r') as f:
        # int() tolerates surrounding whitespace; blank lines are skipped so
        # a trailing newline does not raise ValueError.
        y = [int(line) for line in f if line.strip()]

    if len(emb) != len(y):
        # Report the mismatch here rather than deep inside the evaluation.
        sys.exit('{} embeddings but {} labels for dataset {!r}'.format(
            len(emb), len(y), dataset))

    evaluate_embedding(emb, y)
0 commit comments