Merged
8 changes: 4 additions & 4 deletions README.md

@@ -18,7 +18,7 @@ and for experimenting with new ideas in machine learning for code tasks.
 By default, it learns Java source code and predicts Java method names, but it can be easily extended to other languages,
 since the TensorFlow network is agnostic to the input programming language (see [Extending to other languages](#extending-to-other-languages)).
 Contributions are welcome.
-This repo actually contains two model implementations. The 1st uses pure TensorFlow and the 2nd uses TensorFlow's Keras.
+This repo actually contains two model implementations. The 1st uses pure TensorFlow and the 2nd uses TensorFlow's Keras ([more details](#choosing-implementation-to-use)).

 <center style="padding: 40px"><img width="70%" src="https://github.com/tech-srl/code2vec/raw/master/images/network.png" /></center>

@@ -103,7 +103,7 @@ This model weights more than twice than the stripped version, and it is recommen
 To train a model from scratch:
 * Edit the file [train.sh](train.sh) to point it to the right preprocessed data. By default,
 it points to our "java14m" dataset that was preprocessed in the previous step.
-* Before training, you can edit the configuration hyper-parameters in the file [common.py](common.py),
+* Before training, you can edit the configuration hyper-parameters in the file [config.py](config.py),
 as explained in [Configuration](#configuration).
 * Run the [train.sh](train.sh) script:
 ```
@@ -114,7 +114,7 @@ source train.sh
 1. By default, the network is evaluated on the validation set after every training epoch.
 2. The newest 10 versions are kept (older are deleted automatically). This can be changed, but will be more space consuming.
 3. By default, the network is training for 20 epochs.
-These settings can be changed by simply editing the file [common.py](common.py).
+These settings can be changed by simply editing the file [config.py](config.py).
 Training on a Tesla v100 GPU takes about 50 minutes per epoch.
 Training on Tesla K80 takes about 4 hours per epoch.
@@ -137,7 +137,7 @@ method or code snippet, and examine the model's predictions and attention scores

 ## Configuration
 Changing hyper-parameters is possible by editing the file
-[common.py](common.py).
+[config.py](config.py).

 Here are some of the parameters and their description:
 #### config.NUM_TRAIN_EPOCHS = 20
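The README hunks above repoint the documentation from common.py to config.py, where the training hyper-parameters live. As a hedged sketch, the parameters this diff mentions could look roughly like the following; only `NUM_TRAIN_EPOCHS`, `SAVE_EVERY_EPOCHS`, and `MODEL_SAVE_PATH` appear in the diff itself, and modeling them as plain class attributes (plus the `MAX_TO_KEEP` name and path value) is an assumption, not the repo's actual config.py:

```python
# Hypothetical sketch of a config module like the one the README references.
class Config:
    NUM_TRAIN_EPOCHS = 20    # README: the network trains for 20 epochs by default
    SAVE_EVERY_EPOCHS = 1    # used in tensorflow_model.py to derive checkpoint epochs
    MAX_TO_KEEP = 10         # README: "the newest 10 versions are kept"
    MODEL_SAVE_PATH = 'models/java14m/saved_model'  # assumed example value
```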
2 changes: 1 addition & 1 deletion model_base.py

@@ -103,7 +103,7 @@ def save(self, model_save_path=None):
             model_save_path = self.config.MODEL_SAVE_PATH
         model_save_dir = '/'.join(model_save_path.split('/')[:-1])
         if not os.path.isdir(model_save_dir):
-            os.mkdir(model_save_dir)
+            os.makedirs(model_save_dir, exist_ok=True)
         self.vocabs.save(self.config.get_vocabularies_path_from_model_path(model_save_path))
         self._save_inner_model(model_save_path)
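The switch from `os.mkdir` to `os.makedirs` matters when `MODEL_SAVE_PATH` nests more than one directory level deep: `os.mkdir` raises `FileNotFoundError` if an intermediate directory is missing, while `os.makedirs` creates the whole chain, and `exist_ok=True` additionally makes the call safe against the race where the directory appears between the `isdir` check and the call. A standalone illustration (directory names are arbitrary):

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    nested = os.path.join(tmp, 'models', 'java14m')

    # os.mkdir cannot create the missing intermediate 'models' directory.
    try:
        os.mkdir(nested)
    except FileNotFoundError:
        print('os.mkdir fails on missing intermediate directories')

    # os.makedirs creates the full chain; exist_ok=True makes it idempotent.
    os.makedirs(nested, exist_ok=True)
    os.makedirs(nested, exist_ok=True)  # second call is a no-op, not an error
    print(os.path.isdir(nested))  # True
```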
6 changes: 3 additions & 3 deletions tensorflow_model.py

@@ -89,9 +89,9 @@ def train(self):
             multi_batch_start_time = time.time()
             if batch_num % num_batches_to_save_and_eval == 0:
                 epoch_num = int((batch_num / num_batches_to_save_and_eval) * self.config.SAVE_EVERY_EPOCHS)
-                save_path = self.config.MODEL_SAVE_PATH + '_iter' + str(epoch_num)
-                self._save_inner_model(save_path)
-                self.log('Saved after %d epochs in: %s' % (epoch_num, save_path))
+                model_save_path = self.config.MODEL_SAVE_PATH + '_iter' + str(epoch_num)
+                self.save(model_save_path)
+                self.log('Saved after %d epochs in: %s' % (epoch_num, model_save_path))
                 evaluation_results = self.evaluate()
                 evaluation_results_str = (str(evaluation_results).replace('topk', 'top{}'.format(
                     self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
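The behavioral point of this hunk: the periodic checkpoint previously called `_save_inner_model` directly, bypassing the vocabulary save that the base class's `save()` performs (visible in the model_base.py hunk above), so mid-training checkpoints had weights but no vocabularies. A minimal stand-in sketch of the two call paths (class and method names mirror the diff; the bodies just record which artifacts each path produces and are not the real implementation):

```python
class ModelBase:
    """Stand-in: tracks which artifacts each save path writes."""

    def __init__(self):
        self.saved_artifacts = []

    def save(self, model_save_path):
        # Mirrors model_base.save(): vocabularies first, then the inner model.
        self.saved_artifacts.append(model_save_path + '.vocab')
        self._save_inner_model(model_save_path)

    def _save_inner_model(self, model_save_path):
        self.saved_artifacts.append(model_save_path + '.weights')

model = ModelBase()
model._save_inner_model('ckpt_iter1')  # old behavior: weights only, vocab skipped
model.save('ckpt_iter2')               # new behavior: vocab + weights
print(model.saved_artifacts)
# ['ckpt_iter1.weights', 'ckpt_iter2.vocab', 'ckpt_iter2.weights']
```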