We propose a Weight Diffusion (W-Diff) approach, specialized for evolving domain generalization (EDG) in the domain-incremental setting. W-Diff capitalizes on the strong modeling ability of diffusion models to capture the evolving pattern of optimized classifiers across domains.
The code is implemented with Python 3.7.16 and run on an NVIDIA GeForce RTX 4090. To try out this project, it is recommended to set up a virtual environment first.

```shell
# Step-by-step installation
conda create --name wdiff python=3.7.16
conda activate wdiff

# this installs the right pip and dependencies for the fresh python
conda install -y ipython pip

# install torch, torchvision and torchaudio
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

# this installs required packages
pip install -r requirements.txt
```
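After installation, a quick sanity check can confirm that the core dependencies are importable before training. The snippet below is a hypothetical helper (not part of the repository) that uses only the standard library:

```python
# sanity_check.py -- hypothetical helper, not shipped with this repository
import importlib.util

# Core packages the installation steps above are expected to provide.
REQUIRED = ["torch", "torchvision", "torchaudio"]

def missing_packages(names):
    """Return the subset of `names` that cannot be imported."""
    return [n for n in names if importlib.util.find_spec(n) is None]

if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All core packages found.")
```

Run it inside the activated `wdiff` environment; an empty result means the pip installs above succeeded.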
- Download yearbook.pkl
- Download fmow.pkl and fmow_v1.1.tar.gz
- Download huffpost.pkl
- Download arxiv.pkl
- ONP and 2-Moons are provided in the "datasets" folder.
- rmnist will be downloaded automatically when running the code.
The data folder should be structured as follows:
```
├── datasets/
│   ├── yearbook/
│   │   ├── yearbook.pkl
│   ├── rmnist/
│   │   ├── MNIST/
│   │   ├── rmnist.pkl
│   ├── ONP/
│   │   ├── processed/
│   ├── Moons/
│   │   ├── processed/
│   ├── huffpost/
│   │   ├── huffpost.pkl
│   ├── fMoW/
│   │   ├── fmow_v1.1/
│   │   │   ├── images/
│   │   ├── fmow.pkl
│   ├── arxiv/
│   │   ├── arxiv.pkl
```
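Before training, a small script can report which of the downloaded files are missing from the layout above. This is a hypothetical helper (not part of the repository); the expected paths follow the tree above, omitting the fMoW image folder and the auto-downloaded rmnist/MNIST data:

```python
# check_datasets.py -- hypothetical helper, not shipped with this repository
from pathlib import Path

# Expected files under datasets/, following the tree above
# (fMoW images and auto-downloaded rmnist/MNIST are omitted).
EXPECTED = [
    "yearbook/yearbook.pkl",
    "huffpost/huffpost.pkl",
    "fMoW/fmow.pkl",
    "arxiv/arxiv.pkl",
]

def missing_files(root="datasets"):
    """Return the expected paths that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]

if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing dataset files:", ", ".join(missing))
    else:
        print("All expected dataset files found.")
```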
Training and testing together:
```shell
# running for yearbook dataset:
python3 main.py --cfg ./configs/eval_fix/cfg_yearbook.yaml device 0

# running for rmnist dataset:
python3 main.py --cfg ./configs/eval_fix/cfg_rmnist.yaml device 1

# running for fmow dataset:
python3 main.py --cfg ./configs/eval_fix/cfg_fmow.yaml device 2

# running for 2-Moons dataset:
python3 main.py --cfg ./configs/eval_fix/cfg_moons.yaml device 3

# running for ONP dataset:
python3 main.py --cfg ./configs/eval_fix/cfg_onp.yaml device 4

# running for huffpost dataset:
python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5

# running for arxiv dataset:
python3 main.py --cfg ./configs/eval_fix/cfg_arxiv.yaml device 6
```
If you encounter the error `OSError: Can't load tokenizer for 'bert-base-uncased'.` when running the code on the HuffPost and arXiv datasets, you can try adding `HF_ENDPOINT=https://hf-mirror.com` before the python command. For example:

```shell
HF_ENDPOINT=https://hf-mirror.com python3 main.py --cfg ./configs/eval_fix/cfg_huffpost.yaml device 5
```
Testing with saved model checkpoints:
You can download the models trained by W-Diff here and put them into `<root_dir>/checkpoints/`.

```shell
# evaluating on ONP dataset
python3 main_test_only.py --cfg ./configs/eval_fix/cfg_onp.yaml --model_path 'abs_path_of_onp_model.pkl' device 5
```
This project is mainly based on the open-source projects Wild-Time, EvoS, and LDM. We thank the authors for making their source code publicly available.
If you find this work helpful to your research, please consider citing the paper:
```
@inproceedings{xie2024wdiff,
  title={Weight Diffusion for Future: Learn to Generalize in Non-Stationary Environments},
  author={Mixue Xie and Shuang Li and Binhui Xie and Chi Harold Liu and Jian Liang and Zixun Sun and Ke Feng and Chengwei Zhu},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2024}
}
```