
[NeurIPS 2024] official implementation of Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments


BIT-DA/W-Diff


Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments [NeurIPS 2024]


Overview

We propose Weight Diffusion (W-Diff), an approach specialized for evolving domain generalization (EDG) in the domain-incremental setting. W-Diff capitalizes on the strong modeling ability of diffusion models to capture the evolving pattern of optimized classifiers across domains.

Prerequisites Installation

  • The code is implemented with Python 3.7.16 and was run on an NVIDIA GeForce RTX 4090. To try out this project, we recommend setting up a virtual environment first.

    # Step-by-step installation
    conda create --name wdiff python=3.7.16
    conda activate wdiff
    # this installs the right pip and dependencies for the fresh python
    conda install -y ipython pip
    # install torch, torchvision and torchaudio
    pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
    # this installs required packages
    pip install -r requirements.txt
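After installation, it can help to confirm that the active environment picks up the intended interpreter and PyTorch build before launching long runs. The helper below is a minimal sketch and is not part of the repository; the function name `wdiff_env_report` is hypothetical. It assumes `python3` resolves to the `wdiff` conda interpreter, and it reports rather than fails when PyTorch is missing.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not in the W-Diff repo): print the Python version and,
# if importable, the PyTorch version and whether CUDA is visible.
wdiff_env_report() {
  local py="${1:-python3}"   # interpreter to inspect; defaults to python3
  "$py" - <<'EOF'
import sys
print("python:", ".".join(map(str, sys.version_info[:3])))
try:
    import torch
    print("torch:", torch.__version__)
    print("cuda available:", torch.cuda.is_available())
except ImportError:
    print("torch: not installed")
EOF
}
```

Running `wdiff_env_report` inside the environment set up above should, if the installation succeeded, report Python 3.7.16 and a `1.9.1+cu111` PyTorch build with CUDA available.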

Dataset Preparation

The data folder should be structured as follows:

```
datasets/
├── yearbook/
│   └── yearbook.pkl
├── rmnist/
│   ├── MNIST/
│   └── rmnist.pkl
├── ONP/
│   └── processed/
├── Moons/
│   └── processed/
├── huffpost/
│   └── huffpost.pkl
├── fMoW/
│   ├── fmow_v1.1/
│   │   └── images/
│   └── fmow.pkl
└── arxiv/
    └── arxiv.pkl
```
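Before training, a quick check that every file and folder in the tree above is in place can save a failed run. The sketch below is not part of the repository: `WDIFF_REQUIRED_PATHS` and `check_wdiff_datasets` are hypothetical names, the path list is transcribed from the tree above, and the check only tests that each path exists, not that the pickles or images are valid. It assumes bash (for arrays).

```bash
#!/usr/bin/env bash
# Paths relative to the datasets root, transcribed from the tree above.
WDIFF_REQUIRED_PATHS=(
  yearbook/yearbook.pkl
  rmnist/MNIST
  rmnist/rmnist.pkl
  ONP/processed
  Moons/processed
  huffpost/huffpost.pkl
  fMoW/fmow_v1.1/images
  fMoW/fmow.pkl
  arxiv/arxiv.pkl
)

# Hypothetical helper: print each missing path under <root>;
# return 1 if anything is missing, 0 otherwise.
check_wdiff_datasets() {
  local root="${1:-./datasets}" missing=0 p
  for p in "${WDIFF_REQUIRED_PATHS[@]}"; do
    if [ ! -e "$root/$p" ]; then
      echo "missing: $root/$p"
      missing=1
    fi
  done
  return "$missing"
}
```

Usage: `check_wdiff_datasets ./datasets`. If you only plan to run some of the datasets, the missing entries for the others can be ignored.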

Running the Code

  • Training and testing together:

    # running for yearbook dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_yearbook.yaml device 0
    # running for rmnist dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_rmnist.yaml device 1
    # running for fmow dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_fmow.yaml device 2
    # running for 2-Moons dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_moons.yaml device 3
    # running for ONP dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_onp.yaml device 4
    # running for huffpost dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5
    # running for arxiv dataset:
    python3 main.py --cfg ./configs/eval_fix/cfg_arxiv.yaml device 6
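The seven training-and-testing commands differ only in the config file name and the trailing `device` value, which appears to select the GPU index (0 to 6 in the examples). The helper below is a sketch that generates these commands; `WDIFF_DATASETS` and `wdiff_run` are hypothetical names, and the assumption that `device N` picks a GPU is inferred from the examples rather than confirmed by the repository. By default it only prints the command, so nothing runs until `--exec` is passed.

```bash
#!/usr/bin/env bash
# Dataset keys in the same order as the commands above; the i-th dataset
# is paired with GPU index i, mirroring the `device 0` ... `device 6` examples.
WDIFF_DATASETS=(yearbook rmnist fmow moons onp huffpost arxiv)

# Hypothetical helper: print (default) or execute the command for one dataset.
# Usage: wdiff_run <dataset> <gpu> [--exec]
wdiff_run() {
  local dataset="$1" gpu="$2" mode="${3:-}"
  local cmd=(python3 main.py --cfg "./configs/eval_fix/cfg_${dataset}.yaml" device "$gpu")
  if [ "$mode" = "--exec" ]; then
    "${cmd[@]}"
  else
    printf '%s\n' "${cmd[*]}"
  fi
}

# Dry run: print one command per dataset, one GPU each.
for i in "${!WDIFF_DATASETS[@]}"; do
  wdiff_run "${WDIFF_DATASETS[$i]}" "$i"
done
```

To launch a single run in the background with its own log, something like `wdiff_run yearbook 0 --exec > yearbook.log 2>&1 &` should work. Printing by default makes it easy to review the commands before spending GPU time.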

If you encounter "OSError: Can't load tokenizer for 'bert-base-uncased'." when running the code on the HuffPost and arXiv datasets, try prefixing the Python command with HF_ENDPOINT=https://hf-mirror.com, which redirects Hugging Face model downloads to a mirror. For example:

    HF_ENDPOINT=https://hf-mirror.com python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5
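If you run the HuffPost and arXiv experiments repeatedly, a small wrapper avoids retyping the prefix. This is a sketch, not part of the repository; `with_hf_mirror` is a hypothetical name. It uses the mirror only when `HF_ENDPOINT` is not already set, so an endpoint you configure yourself takes precedence.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper: run any command with HF_ENDPOINT set to the mirror,
# unless the caller has already set HF_ENDPOINT to a different endpoint.
with_hf_mirror() {
  HF_ENDPOINT="${HF_ENDPOINT:-https://hf-mirror.com}" "$@"
}
```

Usage: `with_hf_mirror python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5`. Alternatively, `export HF_ENDPOINT=https://hf-mirror.com` applies the mirror to every command for the rest of the shell session.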
  • Testing with saved model checkpoints:

    You can download the models trained by W-Diff here and put them into <root_dir>/checkpoints/.

    # evaluating on ONP dataset
    python3 main_test_only.py --cfg ./configs/eval_fix/cfg_onp.yaml --model_path 'abs_path_of_onp_model.pkl' device 5
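Since `--model_path` is given an absolute path in the example, a helper that resolves a relative checkpoint path and fails early when the file is missing can be convenient. The sketch below is not part of the repository: `wdiff_test` is a hypothetical name, the checkpoint file name in the usage line is a placeholder (the released file names may differ), and, as above, it only prints the command unless `--exec` is passed.

```bash
#!/usr/bin/env bash
# Hypothetical helper: print (default) or run the test-only command.
# Usage: wdiff_test <dataset> <checkpoint.pkl> <gpu> [--exec]
wdiff_test() {
  local dataset="$1" ckpt="$2" gpu="$3" mode="${4:-}"
  if [ ! -f "$ckpt" ]; then
    echo "checkpoint not found: $ckpt" >&2
    return 1
  fi
  # Resolve the checkpoint to an absolute path for --model_path.
  local abs
  abs="$(cd "$(dirname "$ckpt")" && pwd)/$(basename "$ckpt")"
  local cmd=(python3 main_test_only.py --cfg "./configs/eval_fix/cfg_${dataset}.yaml"
             --model_path "$abs" device "$gpu")
  if [ "$mode" = "--exec" ]; then
    "${cmd[@]}"
  else
    printf '%s\n' "${cmd[*]}"
  fi
}
```

Usage (placeholder file name): `wdiff_test onp ./checkpoints/onp_model.pkl 5 --exec`.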

Acknowledgments

This project is mainly based on the open-source projects Wild-Time, EvoS, and LDM. We thank the authors for making the source code publicly available.

Citation

If you find this work helpful to your research, please consider citing the paper:

    @inproceedings{xie2024wdiff,
      title={Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments},
      author={Mixue Xie and Shuang Li and Binhui Xie and Chi Harold Liu and Jian Liang and Zixun Sun and Ke Feng and Chengwei Zhu},
      booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
      year={2024}
    }
