Skip to content

Commit 7a15eeb

Browse files
committed
Renaming variables to make it clear where the code vectors are
1 parent 6b327d9 commit 7a15eeb

File tree

5 files changed

+44
-13
lines changed

5 files changed

+44
-13
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,16 @@ python3
190190
The above Python commands will result in the closest name to both "equals" and "to|lower", which is "equals|ignore|case".
191191
Note: In embeddings that were exported manually using the "--save_w2v" or "--save_t2v" flags, the input tokens and target words are saved using the symbol "|" as a subtoken delimiter ("*toLower*" is saved as: "*to|lower*"). In the embeddings that are available to download (which are the same as in the paper), the "|" symbol is not used, thus "*toLower*" is saved as "*tolower*".
192192

193+
### Exporting the code vectors for the given code examples
194+
The `--export_code_vectors` flag exports the code vectors for the given examples.
195+
196+
If used with the `--test <TEST_FILE>` flag,
197+
a file named `<TEST_FILE>.vectors` will be saved in the same directory as `<TEST_FILE>`.
198+
Each row in the saved file is the code vector of the code snippet in the corresponding row of `<TEST_FILE>`.
199+
200+
If used with the `--predict` flag, the code vector will be printed to the console.
201+
202+
193203
## Extending to other languages
194204
In order to extend code2vec to work with languages other than Java, a new extractor (similar to the [JavaExtractor](JavaExtractor))
195205
should be implemented, and be called by [preprocess.sh](preprocess.sh).

code2vec.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
help="save word (token) vectors in word2vec format")
2525
parser.add_argument('--save_t2v', dest='save_t2v', required=False,
2626
help="save target vectors in word2vec format")
27+
parser.add_argument('--export_code_vectors', action='store_true', required=False,
28+
help="export code vectors for the given examples")
2729
parser.add_argument('--release', action='store_true',
2830
help='if specified and loading a trained model, release the loaded model for a lower model '
2931
'size.')

common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def get_default_config(args):
2727
config.SAVE_PATH = args.save_path
2828
config.LOAD_PATH = args.load_path
2929
config.RELEASE = args.release
30+
config.EXPORT_CODE_VECTORS = args.export_code_vectors
3031
return config
3132

3233
def __init__(self):
@@ -48,7 +49,7 @@ def __init__(self):
4849
self.LOAD_PATH = ''
4950
self.MAX_TO_KEEP = 0
5051
self.RELEASE = False
51-
52+
self.EXPORT_CODE_VECTORS = False
5253

5354
class common:
5455
noSuchWord = "NoSuchWord"

interactive_predict.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,16 @@ def predict(self):
4040
except ValueError as e:
4141
print(e)
4242
continue
43-
results = self.model.predict(predict_lines)
43+
results, code_vectors = self.model.predict(predict_lines)
4444
prediction_results = common.parse_results(results, hash_to_string_dict, topk=SHOW_TOP_CONTEXTS)
45-
for method_prediction in prediction_results:
45+
for i, method_prediction in enumerate(prediction_results):
4646
print('Original name:\t' + method_prediction.original_name)
4747
for name_prob_pair in method_prediction.predictions:
4848
print('\t(%f) predicted: %s' % (name_prob_pair['probability'], name_prob_pair['name']))
4949
print('Attention:')
5050
for attention_obj in method_prediction.attention_paths:
5151
print('%f\tcontext: %s,%s,%s' % (
5252
attention_obj['score'], attention_obj['token1'], attention_obj['path'], attention_obj['token2']))
53+
if self.config.EXPORT_CODE_VECTORS:
54+
print('Code vector:')
55+
print(' '.join(map(str, code_vectors[i])))

model.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, config):
2121

2222
self.eval_placeholder = None
2323
self.predict_placeholder = None
24-
self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op = None, None, None
24+
self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, self.eval_code_vectors = None, None, None, None
2525
self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op = None, None, None
2626

2727
if config.LOAD_PATH:
@@ -130,7 +130,7 @@ def evaluate(self):
130130
target_word_to_index=self.target_word_to_index,
131131
config=self.config, is_evaluating=True)
132132
self.eval_placeholder = self.eval_queue.get_input_placeholder()
133-
self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, _, _, _, _ = \
133+
self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, _, _, _, _, self.eval_code_vectors = \
134134
self.build_test_graph(self.eval_queue.get_filtered_batches())
135135
self.saver = tf.train.Saver()
136136

@@ -149,15 +149,17 @@ def evaluate(self):
149149
print('Done loading test data')
150150

151151
with open('log.txt', 'w') as output_file:
152+
if self.config.EXPORT_CODE_VECTORS:
153+
code_vectors_file = open(self.config.TEST_PATH + '.vectors', 'w')
152154
num_correct_predictions = np.zeros(self.topk)
153155
total_predictions = 0
154156
total_prediction_batches = 0
155157
true_positive, false_positive, false_negative = 0, 0, 0
156158
start_time = time.time()
157159

158160
for batch in common.split_to_batches(self.eval_data_lines, self.config.TEST_BATCH_SIZE):
159-
top_words, top_scores, original_names = self.sess.run(
160-
[self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op],
161+
top_words, top_scores, original_names, code_vectors = self.sess.run(
162+
[self.eval_top_words_op, self.eval_top_values_op, self.eval_original_names_op, self.eval_code_vectors],
161163
feed_dict={self.eval_placeholder: batch})
162164
top_words, original_names = common.binary_to_string_matrix(top_words), common.binary_to_string_matrix(
163165
original_names)
@@ -172,21 +174,29 @@ def evaluate(self):
172174

173175
total_predictions += len(original_names)
174176
total_prediction_batches += 1
177+
if self.config.EXPORT_CODE_VECTORS:
178+
self.write_code_vectors(code_vectors_file, code_vectors)
175179
if total_prediction_batches % self.num_batches_to_log == 0:
176180
elapsed = time.time() - start_time
177181
# start_time = time.time()
178182
self.trace_evaluation(output_file, num_correct_predictions, total_predictions, elapsed, len(self.eval_data_lines))
179183

180184
print('Done testing, epoch reached')
181185
output_file.write(str(num_correct_predictions / total_predictions) + '\n')
182-
186+
if self.config.EXPORT_CODE_VECTORS:
187+
code_vectors_file.close()
188+
183189
elapsed = int(time.time() - eval_start_time)
184190
precision, recall, f1 = self.calculate_results(true_positive, false_positive, false_negative)
185191
print("Evaluation time: %sH:%sM:%sS" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))
186192
del self.eval_data_lines
187193
self.eval_data_lines = None
188194
return num_correct_predictions / total_predictions, precision, recall, f1
189195

196+
def write_code_vectors(self, file, code_vectors):
197+
for vec in code_vectors:
198+
file.write(' '.join(map(str, vec)) + '\n')
199+
190200
def update_per_subtoken_statistics(self, results, true_positive, false_positive, false_negative):
191201
for original_name, top_words in results:
192202
prediction = common.filter_impossible_names(top_words)[0]
@@ -342,7 +352,7 @@ def build_test_graph(self, input_tensors, normalize_scores=False):
342352
if normalize_scores:
343353
top_scores = tf.nn.softmax(top_scores)
344354

345-
return top_words, top_scores, original_words, attention_weights, source_string, path_string, path_target_string
355+
return top_words, top_scores, original_words, attention_weights, source_string, path_string, path_target_string, code_vectors
346356

347357
def predict(self, predict_data_lines):
348358
if self.predict_queue is None:
@@ -352,19 +362,20 @@ def predict(self, predict_data_lines):
352362
config=self.config, is_evaluating=True)
353363
self.predict_placeholder = self.predict_queue.get_input_placeholder()
354364
self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op, \
355-
self.attention_weights_op, self.predict_source_string, self.predict_path_string, self.predict_path_target_string = \
365+
self.attention_weights_op, self.predict_source_string, self.predict_path_string, self.predict_path_target_string, self.predict_code_vectors = \
356366
self.build_test_graph(self.predict_queue.get_filtered_batches(), normalize_scores=True)
357367

358368
self.initialize_session_variables(self.sess)
359369
self.saver = tf.train.Saver()
360370
self.load_model(self.sess)
361371

372+
code_vectors = []
362373
results = []
363374
for batch in common.split_to_batches(predict_data_lines, 1):
364-
top_words, top_scores, original_names, attention_weights, source_strings, path_strings, target_strings = self.sess.run(
375+
top_words, top_scores, original_names, attention_weights, source_strings, path_strings, target_strings, batch_code_vectors = self.sess.run(
365376
[self.predict_top_words_op, self.predict_top_values_op, self.predict_original_names_op,
366377
self.attention_weights_op, self.predict_source_string, self.predict_path_string,
367-
self.predict_path_target_string],
378+
self.predict_path_target_string, self.predict_code_vectors],
368379
feed_dict={self.predict_placeholder: batch})
369380
top_words, original_names = common.binary_to_string_matrix(top_words), common.binary_to_string_matrix(
370381
original_names)
@@ -373,7 +384,11 @@ def predict(self, predict_data_lines):
373384
attention_weights)
374385
original_names = [w for l in original_names for w in l]
375386
results.append((original_names[0], top_words[0], top_scores[0], attention_per_path))
376-
return results
387+
if self.config.EXPORT_CODE_VECTORS:
388+
code_vectors.append(batch_code_vectors)
389+
if len(code_vectors) > 0:
390+
code_vectors = np.vstack(code_vectors)
391+
return results, code_vectors
377392

378393
def get_attention_per_path(self, source_strings, path_strings, target_strings, attention_weights):
379394
attention_weights = np.squeeze(attention_weights) # (max_contexts, )

0 commit comments

Comments
 (0)