PyTorch re-implementation of the Deep Markov Model (https://arxiv.org/abs/1609.09869)
```bibtex
@inproceedings{10.5555/3298483.3298543,
  author    = {Krishnan, Rahul G. and Shalit, Uri and Sontag, David},
  title     = {Structured Inference Networks for Nonlinear State Space Models},
  year      = {2017},
  publisher = {AAAI Press},
  booktitle = {Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence},
  pages     = {2101--2109},
  numpages  = {9},
  location  = {San Francisco, California, USA},
  series    = {AAAI'17}
}
```

Note:
- The metrics computed in `model/metrics.py` do not match those reported in the paper, most likely due to differences in parameter settings and metric calculations.
- The current implementation only supports the JSB polyphonic music dataset.
Refer to the branch `factorial-dmm` for a model described as a Factorial DMM. The other branch, `refractor`, aims to improve readability and add more model options (documentation not updated yet!).
Train the model with the default `config.json`:

```bash
python train.py -c config.json
```

Add the `-i` flag to give the experiment a specific name; results are saved under `saved/`.
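For example, assuming `-i` takes the experiment name as its value (the name below is arbitrary):

```bash
# "jsb_dmm_baseline" is an arbitrary run name; outputs are stored under saved/
python train.py -c config.json -i jsb_dmm_baseline
```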
`config.json` specifies the model, optimizer, and trainer settings; the inline comments in the default configuration below explain the key parameters. Careful fine-tuning of these parameters seems necessary to match the performance reported in the paper.
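Before the full default configuration, here is a minimal sketch of the gated transition that the `gated_transition`, `z_dim`, and `transition_dim` options below refer to. It follows the equations in the paper and is only an illustration (class and attribute names are made up), not necessarily the exact module in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedTransitionSketch(nn.Module):
    """p(z_t | z_{t-1}) as a gated mix of a linear and a non-linear map
    (the paper's gated transition; illustrative sketch only)."""

    def __init__(self, z_dim=100, transition_dim=200):
        super().__init__()
        # gating network: how much of the non-linear proposal to use per dimension
        self.gate = nn.Sequential(
            nn.Linear(z_dim, transition_dim), nn.ReLU(),
            nn.Linear(transition_dim, z_dim), nn.Sigmoid(),
        )
        # non-linear proposal for the mean
        self.proposed_mean = nn.Sequential(
            nn.Linear(z_dim, transition_dim), nn.ReLU(),
            nn.Linear(transition_dim, z_dim),
        )
        # linear part of the mean (the paper initializes this close to the identity)
        self.z_to_mu = nn.Linear(z_dim, z_dim)
        # scale is predicted from the non-linear proposal
        self.z_to_scale = nn.Linear(z_dim, z_dim)

    def forward(self, z_prev):
        gate = self.gate(z_prev)               # g_t in the paper
        proposed = self.proposed_mean(z_prev)  # h_t in the paper
        mu = (1.0 - gate) * self.z_to_mu(z_prev) + gate * proposed
        scale = F.softplus(self.z_to_scale(F.relu(proposed)))
        return mu, scale
```

With the `sample` option enabled, `z_t` would then be drawn by reparameterization, `z_t = mu + scale * eps` with `eps ~ N(0, I)`.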
{ "arch": { "type": "DeepMarkovModel", "args": { "input_dim": 88, "z_dim": 100, "emission_dim": 100, "transition_dim": 200, "rnn_dim": 600, "rnn_type": "lstm", "rnn_layers": 1, "rnn_bidirection": false, // condition z_t on both directions of inputs, // manually turn off `reverse_rnn_input` if True // (this is minor and could be quickly fixed) "use_embedding": true, // use extra linear layer before RNN "orthogonal_init": true, // orthogonal initialization for RNN "gated_transition": true, // use linear/non-linear gated transition "train_init": false, // make z0 trainble "mean_field": false, // use mean-field posterior q(z_t | x) "reverse_rnn_input": true, // condition z_t on future inputs "sample": true // sample during reparameterization } }, "optimizer": { "type": "Adam", "args":{ "lr": 0.0008, // default value from the author's source code "weight_decay": 0.0, // debugging stage indicates that 1.0 prevents training "amsgrad": true, "betas": [0.9, 0.999] } }, "trainer": { "epochs": 3000, "overfit_single_batch": false, // overfit one single batch for debug "save_dir": "saved/", "save_period": 500, "verbosity": 2, "monitor": "min val_loss", "early_stop": 100, "tensorboard": true, "min_anneal_factor": 0.0, "anneal_update": 5000 } }- Project template brought from the pytorch-template
- The original source code in Theano
- PyTorch implementation in the Pyro framework
- Another PyTorch implementation by @guxd
- fine-tune to match the performance reported in the paper
- correct errors (if any) in the metric calculation in `model/metric.py`
- optimize importance sampling (see the sketch below)
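Importance sampling in this context typically means estimating the log-likelihood as log p(x) ≈ log[(1/K) Σ_k p(x, z^(k)) / q(z^(k) | x)] with z^(k) ~ q(z | x). A minimal sketch of such an estimator is below; the function and argument names are hypothetical and do not correspond to `model/metric.py`.

```python
import math
import torch

def importance_sampled_nll(log_p_xz, log_q_z):
    """Estimate -log p(x) from K importance samples drawn from q(z | x).

    log_p_xz: (K, batch) tensor with log p(x, z^(k)) for each sample.
    log_q_z:  (K, batch) tensor with log q(z^(k) | x) for each sample.
    Returns a (batch,) tensor with the estimated negative log-likelihood.
    """
    k = log_p_xz.size(0)
    log_weights = log_p_xz - log_q_z  # log importance weights
    # stable log-mean-exp over the K samples
    log_px = torch.logsumexp(log_weights, dim=0) - math.log(k)
    return -log_px
```

With K = 1 this reduces to a single-sample estimate of the (negative) ELBO; larger K gives a tighter bound on the true log-likelihood.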