PyTorch implementation of the code2seq model.
You can install the model with pip:
pip install code2seq
To prepare your own dataset in a storage format supported by this implementation, use one of the following:
- Original dataset preprocessing from vanilla repository
- astminer: the tool for mining path-based representation and more with multiple language support.
- PSIMiner: the tool for extracting PSI trees from IntelliJ Platform and creating datasets from them.
Dataset (with link) | Checkpoint | # epochs | F1-score | Precision | Recall | ChrF |
---|---|---|---|---|---|---|
Java-small | link | 11 | 41.49 | 54.26 | 33.59 | 30.21 |
Java-med | link | 10 | 48.17 | 58.87 | 40.76 | 42.32 |
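The README does not say how the F1-score column is computed. As a quick sanity check, the sketch below assumes the standard definition (the harmonic mean of precision and recall) and applies it to the precision and recall values in the table. The helper name `f1_score` is illustrative and is not part of the `code2seq` package.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (both on the same scale)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall values copied from the table above.
reported = {
    "Java-small": (54.26, 33.59),
    "Java-med": (58.87, 40.76),
}

for name, (precision, recall) in reported.items():
    print(name, round(f1_score(precision, recall), 2))
```

Under this assumption the computed values agree with the reported F1-scores to two decimal places.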
The model is fully configurable through a standalone YAML file. See the config directory for example configs.
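The training script in this README reads only a handful of top-level config entries: `data_folder`, `data`, `model`, `optimizer`, `train.n_epochs`, and `train.teacher_forcing`. The following is a minimal sketch, using a plain dictionary in place of the loaded YAML, of how one might check that a config provides them before training. The helper `missing_keys` and all example values are hypothetical; the contents of the `data`, `model`, and `optimizer` sections are not described here and are left empty.

```python
# Dotted keys that the training script in this README reads from the config.
REQUIRED_KEYS = [
    "data_folder",
    "data",
    "model",
    "optimizer",
    "train.n_epochs",
    "train.teacher_forcing",
]


def missing_keys(config: dict) -> list:
    """Return the required dotted keys that are absent from a nested dict."""
    missing = []
    for dotted in REQUIRED_KEYS:
        node = config
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


# Hypothetical values for illustration only; see the config directory
# for the real structure of each section.
example = {
    "data_folder": "path/to/dataset",
    "data": {},
    "model": {},
    "optimizer": {},
    "train": {"n_epochs": 10, "teacher_forcing": 1.0},
}

print(missing_keys(example))
```

For real configs, the example files in the config directory remain the authoritative reference for which fields each section expects.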
The model can be trained with the PyTorch Lightning trainer. See its documentation for more information.
from argparse import ArgumentParser

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model import Code2Seq


def train(config: DictConfig):
    # Define data module
    data_module = PathContextDataModule(config.data_folder, config.data)

    # Define model
    model = Code2Seq(
        config.model,
        config.optimizer,
        data_module.vocabulary,
        config.train.teacher_forcing
    )

    # Define hyper parameters
    trainer = Trainer(max_epochs=config.train.n_epochs)

    # Train model
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    __arg_parser = ArgumentParser()
    __arg_parser.add_argument("config", help="Path to YAML configuration file", type=str)
    __args = __arg_parser.parse_args()

    __config = OmegaConf.load(__args.config)
    train(__config)