# A universal probabilistic spike count model reveals ongoing modulation of neural variability (NeurIPS 2021)


## Overview

This is the code repository for this [paper](https://www.biorxiv.org/content/10.1101/2021.06.27.450063v2).
Models are implemented in Python, with dependencies on the libraries listed at the end.
We also include a neural data analysis library written for constructing scalable neural encoding models in a modern deep learning framework.
The baseline models, along with the universal count model proposed in our work, are implemented in this library and can be used to analyse other neural datasets.

<p align="center">
<img src="./plots/schematic.png" width="800"/>
</p>



## Reproducing the results in the paper


#### 1. cd into ./scripts/
This is where all the code for fitting models, analysis and plotting is located.


#### 2. (Optional) Run the synthetic_data notebook to generate data from synthetic populations
This generates the two synthetic populations and saves them into ./data/: the generated spike counts and behaviour, as well as the encoding models themselves.
Note that the population data used in the paper is included in ./data/; running this script will overwrite those files!
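If you wish to keep the included files, back them up first. A minimal sketch, assuming the repository root as the working directory (adjust the paths to your setup):
```python
# Back up the included population data before regenerating it.
import shutil

shutil.copytree("./data", "./data_backup")  # errors if ./data_backup already exists
```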


#### 3. Run the scripts to fit models

##### Command line format
Run commands in the following format from the command line:
```
python3 validation.py --cv -1 2 5 8 --gpu 0 --modes 0 --datatype 0 --ncvx 2 --lr 1e-2 --lr_2 1e-3 --batchsize 10000
```
This runs a model of mode 0 on synthetic data, with `--cv` indicating which cross-validation fold to leave out for validation (-1 indicates using all data) and `--gpu` indicating the GPU device to run on (if available).
Line 188 in validation.py gives the definition of all modes (numbered 0 to 8); in particular, the likelihood (1st element of the tuple) and the input space (2nd element of the tuple) are specified.
Note that there is a 10-fold split of the data, hence the cv fold numbers can go from -1 to 9.
`lr` and `lr_2` indicate the learning rates, with `lr_2` used for kernel and variational standard deviations (lower for latent models, as described in the paper).
The flag `--ncvx` sets the number of optimization runs (the best-fitting model is saved after completion).
One can also specify `--batchsize`; larger batches can speed up training, depending on the memory capacity of the hardware used.
For validation.py, the flag `--datatype` can be 0 (heteroscedastic Conway-Maxwell-Poisson) or 1 (modulated Poisson).
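For intuition, the fold indexing can be pictured as follows; this is an illustrative sketch only, not the repository's actual data-splitting code (fold boundaries are assumed contiguous here):
```python
# Illustrative 10-fold cross-validation indexing mirroring the --cv flag:
# cv = -1 trains on all time bins, cv = 0..9 holds out one fold for validation.
import numpy as np

def split_fold(n_bins, cv, n_folds=10):
    """Return (train, validation) index arrays for a given cv fold."""
    idx = np.arange(n_bins)
    if cv == -1:  # use all data
        return idx, np.array([], dtype=int)
    folds = np.array_split(idx, n_folds)
    train = np.concatenate([f for i, f in enumerate(folds) if i != cv])
    return train, folds[cv]

train_idx, val_idx = split_fold(10000, cv=2)
```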
```
python3 HDC.py --cv -1 1 3 6 --gpu 0 --modes 0 --ncvx 2 --lr 1e-2 --lr_2 1e-3 --binsize 40
```
Similarly, this runs a model of mode 0 on head direction cell data, binned into 40 ms bins as set by `--binsize`.
Line 108 in HDC.py gives the definition of all modes for the head direction cell models (numbered 0 to 11).
All possible flags and their default values can be seen in the validation.py and HDC.py scripts.
The file models.py defines the encoding models and uses the library code (neuroprob) to implement and run these probabilistic models.
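As an aside, binning spike times into the count arrays these models consume amounts to a per-neuron histogram; a minimal sketch with made-up spike times (the repository's own preprocessing may differ):
```python
# Bin spike times (in ms) into spike counts per time bin.
import numpy as np

binsize = 40.0  # ms, as set by --binsize
t_end = 1000.0  # recording length in ms
spike_times = np.array([3.0, 55.2, 58.9, 120.4, 900.0])  # one neuron, hypothetical

bin_edges = np.arange(0.0, t_end + binsize, binsize)
counts, _ = np.histogram(spike_times, bins=bin_edges)  # counts per 40 ms bin
```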

In terms of neural data, the synthetic population data used in the paper and the head direction cell data are included in the ./data/ folder.
The modes required by the analysis notebooks can be seen in the notebook code where trained models are loaded.
Note that there are separate notebooks for synthetic (validation) and real (HDC) datasets.
All trained models are stored in the ./checkpoint/ folder.


##### Experiments in the paper
- Synthetic data

`python3 validation.py --cv -1 2 5 8 --gpu 0 --modes 0 1 2 3 --datatype 0 --ncvx 2 --lr 1e-2` (regression models)

`python3 validation.py --cv -1 2 5 8 --gpu 0 --modes 4 5 6 7 --datatype 0 --ncvx 3 --lr 1e-2 --lr_2 1e-3` (latent variable models)

`python3 validation.py --cv -1 2 5 8 --gpu 0 --modes 0 2 8 --datatype 1 --ncvx 2 --lr 1e-2` (capturing noise correlations and single neuron variability)

- Head direction cell data

`python3 HDC.py --cv -1 1 2 3 5 6 8 --gpu 0 --modes 0 1 4 --ncvx 2 --lr 1e-2 --binsize 40` (regression with different likelihoods)

`python3 HDC.py --cv -1 1 2 3 5 6 8 --gpu 0 --modes 2 3 --ncvx 2 --lr 1e-2 --binsize 40` (regression with different regressors)

`python3 HDC.py --cv -1 1 2 3 5 6 8 --gpu 0 --modes 5 6 7 8 --ncvx 3 --lr 1e-2 --lr_2 1e-3 --binsize 40` (joint latent-observed models)

`python3 HDC.py --cv -1 1 2 3 5 6 8 --gpu 0 --modes 9 10 11 --ncvx 3 --lr 3e-2 --lr_2 5e-3 --binsize 100` (latent variable models)

If you wish to run the modes or cross-validation folds grouped together above in parallel, run the command several times with a single mode or cv fold each time, as in the example below.
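For instance, the first regression command above decomposes into runs like

`python3 HDC.py --cv -1 --gpu 0 --modes 0 --ncvx 2 --lr 1e-2 --binsize 40`

`python3 HDC.py --cv 1 --gpu 0 --modes 0 --ncvx 2 --lr 1e-2 --binsize 40`

and so on for the remaining modes and folds.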


#### 4. Run the analysis notebooks to analyze the data
Running the analysis notebooks reproduces the data behind the figures in the paper.
Intermediate (pickled) files will be stored in the ./checkpoint/ folder.


#### 5. Run the plotting notebooks
These load the analysis results and plot the paper figures in .pdf and .svg formats, exported to the ./output/ folder.



## Neural data analysis library

Here we give a short description of a preliminary version of the neural data analysis library (called neuroprob) used to facilitate constructing neural encoding models.
The list below shows what has been implemented so far; see the models.py file for example code utilizing the library.


#### Primitives

There are three kinds of objects that form the building blocks:
1. Input group *p(X,Z)* and *q(Z)*
2. Mapping *p(F|X,Z)*
3. Likelihood *p(Y|F)*

The overall generative model, along with the variational posterior, is specified through these primitives.
Input groups can contain observed and latent variables, with various priors that can be placed on the latent variables.
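To make the decomposition concrete, here is a toy generative model composed of these three primitives; a purely illustrative numpy sketch, not the neuroprob API (see models.py for actual usage):
```python
# Toy generative model p(Y|F) p(F|X): an observed input X (e.g. head direction),
# a deterministic tuning-curve mapping F = f(X), and a Poisson count likelihood.
import numpy as np

rng = np.random.default_rng(0)

# 1. Input group: an observed covariate X (no latent Z in this toy example)
X = rng.uniform(0.0, 2 * np.pi, size=500)  # head direction per time bin

# 2. Mapping p(F|X): a von Mises-like tuning curve giving the firing rate
def mapping(x, pref=np.pi, amp=2.0, width=1.0):
    return amp * np.exp(np.cos(x - pref) / width**2)

F = mapping(X)

# 3. Likelihood p(Y|F): spike counts drawn per bin from a Poisson distribution
Y = rng.poisson(F)
```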


#### Models implemented

* Linear-nonlinear and GP mappings
* RNNs
* LVMs
  - Toroidal latent space priors ([Manifold GPLVM](https://arxiv.org/abs/2006.07429))
  - AR(1) temporal prior on latents
* GLM filters
  - spike-history couplings
  - spike-spike couplings
  - stimulus history
* Inhomogeneous renewal point processes
  - Gamma
  - Inverse Gaussian
  - Log Normal
* Count process likelihoods
  - Poisson
  - Zero-inflated Poisson
  - Negative binomial
  - Conway-Maxwell-Poisson
  - Universal (this work; see the sketch after this list)
* Gaussian likelihoods
  - Univariate
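To give a flavour of the universal count likelihood, the core idea is to parameterize a full categorical distribution over counts 0, ..., K via a softmax; a minimal sketch under that reading of the paper, with the GP mapping and basis expansion that produce the scores left out:
```python
# Minimal sketch of a universal count distribution: unconstrained scores for
# counts 0..K are pushed through a softmax, giving flexible count statistics.
# Illustrative only; in the model the scores come from the mapping p(F|X,Z).
import numpy as np

K = 10  # assumed maximum count
scores = np.random.default_rng(1).normal(size=K + 1)  # one score per count value

probs = np.exp(scores - scores.max())
probs /= probs.sum()  # categorical distribution over counts 0..K

counts = np.arange(K + 1)
mean = counts @ probs               # moments follow in closed form
var = (counts**2) @ probs - mean**2
fano = var / mean                   # variability not tied to the mean as in Poisson
```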



## Dependencies
- [PyTorch](https://pytorch.org/) version >= 1.7
- [NumPy](https://numpy.org/)
- [SciPy](https://scipy.org/)
- [tqdm](https://tqdm.github.io/) for visualizing fitting/training progress
- [Daft](https://docs.daft-pgm.org/en/latest/) to plot graphical model components