MegDiffusion

MegEngine implementation of Diffusion Models (in early development).

Current maintainer: @MegChai

Usage

Infer with pre-trained models

Now users can use megengine.hub to get pre-trained models directly:

import megengine

repo_info = "MegEngine/MegDiffusion:main"
megengine.hub.list(repo_info)   # list the entry points available in the repository

pretrained_model = "ddpm_cifar10_ema_converted"
megengine.hub.help(repo_info, pretrained_model)   # show the documentation for this model
model = megengine.hub.load(repo_info, pretrained_model, pretrained=True)
model.eval()

Note that megengine.hub will download the whole repository from its host, or use the local cache if it has already been downloaded.

If you have downloaded or installed MegDiffusion, you can get pre-trained models from the pretrain module.

from megdiffusion.model import pretrain

model = pretrain.ddpm_cifar10_ema_converted(pretrained=True)
model.eval()

The sample script shows how to generate 64 CIFAR10-like images and make a grid of them:

python3 -m megdiffusion.pipeline.ddpm.sample
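
Under the hood, the sample pipeline runs the standard DDPM ancestral-sampling loop, with the loaded network predicting the injected noise at each timestep. The following is only a minimal sketch of that loop using the hub model loaded above; the 1000-step linear beta schedule mirrors the original DDPM paper, and the call signature model(x, t) is an assumption about this repository's network, so consult the pipeline source for the exact interface.

import numpy as np
import megengine
import megengine.functional as F

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # linear beta schedule from the DDPM paper (assumed here)
alphas = 1.0 - betas
alphas_bar = np.cumprod(alphas)

# Start from pure Gaussian noise: 64 images of shape 3x32x32 (CIFAR10).
x = megengine.Tensor(np.random.randn(64, 3, 32, 32).astype("float32"))

for t in reversed(range(T)):
    t_batch = megengine.Tensor(np.full((64,), t, dtype="int32"))
    eps = model(x, t_batch)   # assumed signature: the network predicts the added noise

    # Posterior mean of x_{t-1} given x_t and the predicted noise.
    coef = float((1.0 - alphas[t]) / np.sqrt(1.0 - alphas_bar[t]))
    mean = (x - coef * eps) / float(np.sqrt(alphas[t]))

    if t > 0:   # add noise with variance beta_t at every step except the last
        z = megengine.Tensor(np.random.randn(64, 3, 32, 32).astype("float32"))
        x = mean + float(np.sqrt(betas[t])) * z
    else:
        x = mean

samples = F.clip(x, -1.0, 1.0)   # pixel values in [-1, 1]; rescale before saving as images

The actual script additionally arranges the 64 samples into a grid, as described above.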

Train from scratch

  • Take DDPM CIFAR10 for example:

    python3 -m megdiffusion.pipeline.ddpm.train \
        --config ./configs/ddpm/cifar10.yaml

  • [Optional] Override arguments:

    python3 -m megdiffusion.pipeline.ddpm.train \
        --config ./configs/ddpm/cifar10.yaml \
        --logdir ./path/to/logdir \
        --parallel --resume

See python3 -m megdiffusion.pipeline.ddpm.train --help for more information. For other options such as batch_size, we recommend modifying them in the YAML file and keeping a backup of that file.
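
For reference, the training pipeline optimizes the simple noise-prediction objective of DDPM: sample a random timestep, diffuse the clean image to that timestep in closed form, and regress the network output against the injected noise. The sketch below is only an illustration under assumptions; the linear beta schedule, the model(x_t, t) signature, and the plain MSE without EMA, warmup, or gradient clipping may all differ from the actual trainer and config.

import numpy as np
import megengine
import megengine.functional as F

def simple_ddpm_loss(model, x0_np, T=1000):
    """Simple DDPM loss on a batch x0_np: numpy array (N, C, H, W) scaled to [-1, 1]."""
    betas = np.linspace(1e-4, 0.02, T)        # assumed linear schedule
    alphas_bar = np.cumprod(1.0 - betas)

    n = x0_np.shape[0]
    t = np.random.randint(0, T, size=(n,))    # one random timestep per image
    noise = np.random.randn(*x0_np.shape)

    # Closed-form forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
    x_t = (np.sqrt(alphas_bar[t]).reshape(n, 1, 1, 1) * x0_np
           + np.sqrt(1.0 - alphas_bar[t]).reshape(n, 1, 1, 1) * noise)

    x_t = megengine.Tensor(x_t.astype("float32"))
    noise = megengine.Tensor(noise.astype("float32"))
    eps_pred = model(x_t, megengine.Tensor(t.astype("int32")))   # assumed signature

    return F.mean((eps_pred - noise) ** 2)    # MSE between true and predicted noise

In the actual pipeline, the schedule, model architecture, and optimization settings all come from the YAML config passed via --config.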

If you want to sample with a model you trained yourself (instead of the pre-trained model):

python3 -m megdiffusion.pipeline.ddpm.sample --nopretrain \
    --logdir ./path/to/logdir \
    --config ./configs/ddpm/cifar10.yaml   # could be your customized config file

Development

python3 -m pip install -r requirements.txt
python3 -m pip install -v -e .

Develop this project on a new local branch, and remember to add the necessary test code. When finished, submit a Pull Request against the main branch and wait for review.

Acknowledgment

Several open-source projects were referenced in this work.

Thanks to people including @gaohuazuo, @xxr3376, @P2Oileen and other contributors for their support of this project. The R&D platform and the resources required for the experiments were provided by MEGVII Inc. The deep learning framework used in this project is MegEngine -- a magic weapon.

Citations

@article{ho2020denoising,
    title         = {Denoising Diffusion Probabilistic Models},
    author        = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
    year          = {2020},
    eprint        = {2006.11239},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}

@article{nichol2021improved,
    title         = {Improved Denoising Diffusion Probabilistic Models},
    author        = {Alex Nichol and Prafulla Dhariwal},
    year          = {2021},
    url           = {https://arxiv.org/abs/2102.09672},
    eprint        = {2102.09672},
    archivePrefix = {arXiv}
}
