
Commit 01d1731

Merge pull request #27 from eladn/master
Add tf.keras model implementation
2 parents dcdc3c3 + 64bebf7 commit 01d1731

20 files changed: +2460 -808 lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -1,2 +1,7 @@
 *.class
 *.lst
+**/models/**
+**/data/**
+**/.idea/**
+*.tar.gz
+**/log.txt

PathContextReader.py

Lines changed: 0 additions & 156 deletions
This file was deleted.

README.md

Lines changed: 60 additions & 24 deletions
@@ -5,17 +5,20 @@ This is an official implementation of the model described in:
 [Uri Alon](http://urialon.cswp.cs.technion.ac.il), [Meital Zilberstein](http://www.cs.technion.ac.il/~mbs/), [Omer Levy](https://levyomer.wordpress.com) and [Eran Yahav](http://www.cs.technion.ac.il/~yahave/),
 "code2vec: Learning Distributed Representations of Code", POPL'2019 [[PDF]](https://urialon.cswp.cs.technion.ac.il/wp-content/uploads/sites/83/2018/12/code2vec-popl19.pdf)

-_**October 2018** - the paper was accepted to [POPL'2019](https://popl19.sigplan.org)_!
+_**October 2018** - The paper was accepted to [POPL'2019](https://popl19.sigplan.org)_!

 _**April 2019** - The talk video is available [here](https://www.youtube.com/watch?v=EJ8okcxL2Iw)_.

+_**July 2019** - Add `tf.keras` model implementation (see [here](#choosing-implementation-to-use))._
+
 An **online demo** is available at [https://code2vec.org/](https://code2vec.org/).

 This is a TensorFlow implementation, designed to be easy and useful in research,
 and for experimenting with new ideas in machine learning for code tasks.
 By default, it learns Java source code and predicts Java method names, but it can be easily extended to other languages,
 since the TensorFlow network is agnostic to the input programming language (see [Extending to other languages](#extending-to-other-languages).
 Contributions are welcome.
+This repo actually contains two model implementations. The 1st uses pure TensorFlow and the 2nd uses TensorFlow's Keras.

 <center style="padding: 40px"><img width="70%" src="https://github.com/tech-srl/code2vec/raw/master/images/network.png" /></center>

@@ -33,13 +36,18 @@ Table of Contents
 On Ubuntu:
 * [Python3](https://www.linuxbabe.com/ubuntu/install-python-3-6-ubuntu-16-04-16-10-17-04). To check if you have it:
 > python3 --version
-* TensorFlow - version 1.5 or newer ([install](https://www.tensorflow.org/install/install_linux)). To check TensorFlow version:
+* TensorFlow - version 2.0.0-beta1 ([install](https://www.tensorflow.org/install/install_linux)).
+To check TensorFlow version:
 > python3 -c 'import tensorflow as tf; print(tf.\_\_version\_\_)'
-* If you are using a GPU, you will need CUDA 9.0 ([download](https://developer.nvidia.com/cuda-90-download-archive))
+* If you are using a GPU, you will need CUDA 10.0
+([download](https://developer.nvidia.com/cuda-10.0-download-archive-base))
 as this is the version that is currently supported by TensorFlow. To check CUDA version:
 > nvcc --version
-* For GPU: cuDNN (>=7.0) ([download](http://developer.nvidia.com/cudnn))
-* For [creating a new dataset](#creating-and-preprocessing-a-new-java-dataset) or [manually examining a trained model](#step-4-manual-examination-of-a-trained-model) (any operation that requires parsing of a new code example) - [Java JDK](https://openjdk.java.net/install/)
+* For GPU: cuDNN (>=7.5) ([download](http://developer.nvidia.com/cudnn)) To check cuDNN version:
+> cat /usr/include/cudnn.h | grep CUDNN_MAJOR -A 2
+* For [creating a new dataset](#creating-and-preprocessing-a-new-java-dataset)
+or [manually examining a trained model](#step-4-manual-examination-of-a-trained-model)
+(any operation that requires parsing of a new code example) - [Java JDK](https://openjdk.java.net/install/)

 ## Quickstart
 ### Step 0: Cloning this repository
@@ -124,46 +132,74 @@ To manually examine a trained model, run:
 ```
 python3 code2vec.py --load models/java14_model/saved_model_iter8 --predict
 ```
-After the model loads, follow the instructions and edit the file Input.java and enter a Java
+After the model loads, follow the instructions and edit the file [Input.java](Input.java) and enter a Java
 method or code snippet, and examine the model's predictions and attention scores.

 ## Configuration
-Changing hyper-parameters is possible by editing the file [common.py](common
-.py).
+Changing hyper-parameters is possible by editing the file
+[common.py](common.py).

 Here are some of the parameters and their description:
-#### config.NUM_EPOCHS = 20
+#### config.NUM_TRAIN_EPOCHS = 20
 The max number of epochs to train the model. Stopping earlier must be done manually (kill).
 #### config.SAVE_EVERY_EPOCHS = 1
 After how many training iterations a model should be saved.
-#### config.BATCH_SIZE = 1024
+#### config.TRAIN_BATCH_SIZE = 1024
 Batch size in training.
-#### config.TEST_BATCH_SIZE = config.BATCH_SIZE
+#### config.TEST_BATCH_SIZE = config.TRAIN_BATCH_SIZE
 Batch size in evaluating. Affects only the evaluation speed and memory consumption, does not affect the results.
-#### config.READING_BATCH_SIZE = 1300 * 4
-The batch size of reading text lines to the queue that feeds examples to the network during training.
-#### config.NUM_BATCHING_THREADS = 2
-The number of threads enqueuing examples.
-#### config.BATCH_QUEUE_SIZE = 300000
-Max number of elements in the feeding queue.
-#### config.DATA_NUM_CONTEXTS = 200
-The number of contexts in a single example, as was created in preprocessing.
+#### config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
+Number of words with highest scores in $ y_hat $ to consider during prediction and evaluation.
+#### config.NUM_BATCHES_TO_LOG_PROGRESS = 100
+Number of batches (during training / evaluating) to complete between two progress-logging records.
+#### config.NUM_TRAIN_BATCHES_TO_EVALUATE = 100
+Number of training batches to complete between model evaluations on the test set.
+#### config.READER_NUM_PARALLEL_BATCHES = 4
+The number of threads enqueuing examples to the reader queue.
+#### config.SHUFFLE_BUFFER_SIZE = 10000
+Size of buffer in reader to shuffle example within during training.
+Bigger buffer allows better randomness, but requires more amount of memory and may harm training throughput.
+#### config.CSV_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB
+The buffer size (in bytes) of the CSV dataset reader.
+
 #### config.MAX_CONTEXTS = 200
 The number of contexts to use in each example.
-#### config.WORDS_VOCAB_SIZE = 1301136
+#### config.MAX_TOKEN_VOCAB_SIZE = 1301136
 The max size of the token vocabulary.
-#### config.TARGET_VOCAB_SIZE = 261245
+#### config.MAX_TARGET_VOCAB_SIZE = 261245
 The max size of the target words vocabulary.
-#### config.PATHS_VOCAB_SIZE = 911417
+#### config.MAX_PATH_VOCAB_SIZE = 911417
 The max size of the path vocabulary.
-#### config.EMBEDDINGS_SIZE = 128
-Embedding size for tokens and paths.
+#### config.DEFAULT_EMBEDDINGS_SIZE = 128
+Default embedding size to be used for token and path if not specified otherwise.
+#### config.TOKEN_EMBEDDINGS_SIZE = config.EMBEDDINGS_SIZE
+Embedding size for tokens.
+#### config.PATH_EMBEDDINGS_SIZE = config.EMBEDDINGS_SIZE
+Embedding size for paths.
+#### config.CODE_VECTOR_SIZE = config.PATH_EMBEDDINGS_SIZE + 2 * config.TOKEN_EMBEDDINGS_SIZE
+Size of code vectors.
+#### config.TARGET_EMBEDDINGS_SIZE = config.CODE_VECTOR_SIZE
+Embedding size for target words.
 #### config.MAX_TO_KEEP = 10
 Keep this number of newest trained versions during training.
+#### config.DROPOUT_KEEP_RATE = 0.75
+Dropout rate used during training.
+#### config.SEPARATE_OOV_AND_PAD = False
+Whether to treat `<OOV>` and `<PAD>` as two different special tokens whenever possible.

 ## Features
 Code2vec supports the following features:

+### Choosing implementation to use
+This repo comes with two model implementations:
+(i) uses pure TensorFlow (written in [tensorflow_model.py](tensorflow_model.py));
+(ii) uses TensorFlow's Keras (written in [keras_model.py](keras_model.py)).
+The default implementation used by `code2vec.py` is the pure TensorFlow.
+To explicitly choose the desired implementation to use, specify `--framework tensorflow` or `--framework keras`
+as an additional argument when executing the script `code2vec.py`.
+Particularly, this argument can be added to each one of the usage examples (of `code2vec.py`) detailed in this file.
+Note that in order to load a trained model (from file), one should use the same implementation used during its training.
+
 ### Releasing the model
 If you wish to keep a trained model for inference only (without the ability to continue training it) you can
 release the model using:
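For illustration only (not part of this commit): the renamed hyper-parameters documented in the Configuration section above are read as attributes of the `Config` object that the new `code2vec.py` builds, so a quick experiment could override a few of them in code. This is a minimal sketch; the README's supported route is still editing the config file, and whether values changed after construction propagate everywhere is an assumption.

```
# Sketch only: override a few of the hyper-parameters documented in the
# README's Configuration section. Assumes they are plain attributes of the
# Config object (as code2vec.py's usage of config.* suggests) and that
# post-construction changes are picked up by the model.
from config import Config

config = Config(set_defaults=True, load_from_args=True, verify=True)
config.NUM_TRAIN_EPOCHS = 40                          # train for more epochs
config.TRAIN_BATCH_SIZE = 512                         # smaller training batches
config.TEST_BATCH_SIZE = config.TRAIN_BATCH_SIZE      # keep eval batch size in sync
config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 5   # report fewer candidates
```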

code2vec.py

Lines changed: 27 additions & 45 deletions
@@ -1,56 +1,38 @@
-from common import Config, VocabType
-from argparse import ArgumentParser
+from vocabularies import VocabType
+from config import Config
 from interactive_predict import InteractivePredictor
-from model import Model
-import sys
+from model_base import Code2VecModelBase

-if __name__ == '__main__':
-    parser = ArgumentParser()
-    parser.add_argument("-d", "--data", dest="data_path",
-                        help="path to preprocessed dataset", required=False)
-    parser.add_argument("-te", "--test", dest="test_path",
-                        help="path to test file", metavar="FILE", required=False)

-    is_training = '--train' in sys.argv or '-tr' in sys.argv
-    parser.add_argument("-s", "--save", dest="save_path",
-                        help="path to save file", metavar="FILE", required=False)
-    parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
-                        help="path to save file", metavar="FILE", required=False)
-    parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
-                        help="path to save file", metavar="FILE", required=False)
-    parser.add_argument("-l", "--load", dest="load_path",
-                        help="path to save file", metavar="FILE", required=False)
-    parser.add_argument('--save_w2v', dest='save_w2v', required=False,
-                        help="save word (token) vectors in word2vec format")
-    parser.add_argument('--save_t2v', dest='save_t2v', required=False,
-                        help="save target vectors in word2vec format")
-    parser.add_argument('--export_code_vectors', action='store_true', required=False,
-                        help="export code vectors for the given examples")
-    parser.add_argument('--release', action='store_true',
-                        help='if specified and loading a trained model, release the loaded model for a lower model '
-                             'size.')
-    parser.add_argument('--predict', action='store_true')
-    args = parser.parse_args()
+def load_model_dynamically(config: Config) -> Code2VecModelBase:
+    assert config.DL_FRAMEWORK in {'tensorflow', 'keras'}
+    if config.DL_FRAMEWORK == 'tensorflow':
+        from tensorflow_model import Code2VecModel
+    elif config.DL_FRAMEWORK == 'keras':
+        from keras_model import Code2VecModel
+    return Code2VecModel(config)
+
+
+if __name__ == '__main__':
+    config = Config(set_defaults=True, load_from_args=True, verify=True)

-    config = Config.get_default_config(args)
+    model = load_model_dynamically(config)
+    config.log('Done creating code2vec model')

-    model = Model(config)
-    print('Created model')
-    if config.TRAIN_PATH:
+    if config.is_training:
         model.train()
-    if args.save_w2v is not None:
-        model.save_word2vec_format(args.save_w2v, source=VocabType.Token)
-        print('Origin word vectors saved in word2vec text format in: %s' % args.save_w2v)
-    if args.save_t2v is not None:
-        model.save_word2vec_format(args.save_t2v, source=VocabType.Target)
-        print('Target word vectors saved in word2vec text format in: %s' % args.save_t2v)
-    if config.TEST_PATH and not args.data_path:
+    if config.SAVE_W2V is not None:
+        model.save_word2vec_format(config.SAVE_W2V, VocabType.Token)
+        config.log('Origin word vectors saved in word2vec text format in: %s' % config.SAVE_W2V)
+    if config.SAVE_T2V is not None:
+        model.save_word2vec_format(config.SAVE_T2V, VocabType.Target)
+        config.log('Target word vectors saved in word2vec text format in: %s' % config.SAVE_T2V)
+    if config.is_testing and not config.is_training:
         eval_results = model.evaluate()
         if eval_results is not None:
-            results, precision, recall, f1 = eval_results
-            print(results)
-            print('Precision: ' + str(precision) + ', recall: ' + str(recall) + ', F1: ' + str(f1))
-    if args.predict:
+            config.log(
+                str(eval_results).replace('topk', 'top{}'.format(config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
+    if config.PREDICT:
         predictor = InteractivePredictor(config, model)
         predictor.predict()
     model.close_session()
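To tie the README's new `--framework` option to the code above, here is a hedged usage sketch (not part of this commit) that drives `load_model_dynamically` from another script. It assumes the `--framework` argument ends up in `config.DL_FRAMEWORK`, as the assertion in `load_model_dynamically` suggests.

```
# Sketch only: reuse the dynamic implementation selection added in this commit.
# Assumes the --framework CLI argument (README: "Choosing implementation to use")
# populates config.DL_FRAMEWORK.
from config import Config
from code2vec import load_model_dynamically

config = Config(set_defaults=True, load_from_args=True, verify=True)  # e.g. run with --framework keras
model = load_model_dynamically(config)  # tensorflow_model.Code2VecModel or keras_model.Code2VecModel
if config.is_training:
    model.train()
model.close_session()
```

As the README notes, a trained model saved with one implementation should be loaded with the same `--framework` value.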

0 commit comments
