Skip to content

Commit 253e793

Browse files
authored
Merge pull request AI-Hypercomputer#1 from entrpn/add_sd_2_base
Add sd 2 base
2 parents 0221228 + fdbace3 commit 253e793

File tree

14 files changed

+508
-13
lines changed

14 files changed

+508
-13
lines changed

.gitignore

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Initially taken from Github's Python gitignore file
2+
3+
# Byte-compiled / optimized / DLL files
4+
__pycache__/
5+
*.py[cod]
6+
*$py.class
7+
8+
# C extensions
9+
*.so
10+
11+
# tests and logs
12+
tests/fixtures/cached_*_text.txt
13+
logs/
14+
lightning_logs/
15+
lang_code_data/
16+
image_*.png
17+
train-smoke-test/
18+
19+
# Distribution / packaging
20+
.Python
21+
build/
22+
develop-eggs/
23+
dist/
24+
downloads/
25+
eggs/
26+
.eggs/
27+
lib/
28+
lib64/
29+
parts/
30+
sdist/
31+
var/
32+
wheels/
33+
*.egg-info/
34+
.installed.cfg
35+
*.egg
36+
MANIFEST
37+
38+
# PyInstaller
39+
# Usually these files are written by a python script from a template
40+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
41+
*.manifest
42+
*.spec
43+
44+
# Installer logs
45+
pip-log.txt
46+
pip-delete-this-directory.txt
47+
48+
# Unit test / coverage reports
49+
htmlcov/
50+
.tox/
51+
.nox/
52+
.coverage
53+
.coverage.*
54+
.cache
55+
nosetests.xml
56+
coverage.xml
57+
*.cover
58+
.hypothesis/
59+
.pytest_cache/
60+
61+
# Translations
62+
*.mo
63+
*.pot
64+
65+
# Django stuff:
66+
*.log
67+
local_settings.py
68+
db.sqlite3
69+
70+
# Flask stuff:
71+
instance/
72+
.webassets-cache
73+
74+
# Scrapy stuff:
75+
.scrapy
76+
77+
# Sphinx documentation
78+
docs/_build/
79+
80+
# PyBuilder
81+
target/
82+
83+
# Jupyter Notebook
84+
.ipynb_checkpoints
85+
86+
# IPython
87+
profile_default/
88+
ipython_config.py
89+
90+
# pyenv
91+
.python-version
92+
93+
# celery beat schedule file
94+
celerybeat-schedule
95+
96+
# SageMath parsed files
97+
*.sage.py
98+
99+
# Environments
100+
.env
101+
.venv
102+
env/
103+
venv/
104+
ENV/
105+
env.bak/
106+
venv.bak/
107+
108+
# Spyder project settings
109+
.spyderproject
110+
.spyproject
111+
112+
# Rope project settings
113+
.ropeproject
114+
115+
# mkdocs documentation
116+
/site
117+
118+
# mypy
119+
.mypy_cache/
120+
.dmypy.json
121+
dmypy.json
122+
123+
# Pyre type checker
124+
.pyre/
125+
126+
# vscode
127+
.vs
128+
.vscode
129+
130+
# Pycharm
131+
.idea
132+
133+
# TF code
134+
tensorflow_code
135+
136+
# Models
137+
proc_data
138+
139+
# examples
140+
runs
141+
/runs_old
142+
/wandb
143+
/examples/runs
144+
/examples/**/*.args
145+
/examples/rag/sweep
146+
147+
# data
148+
/data
149+
serialization_dir
150+
151+
# emacs
152+
*.*~
153+
debug.env
154+
155+
# vim
156+
.*.swp
157+
158+
#ctags
159+
tags
160+
161+
# pre-commit
162+
.pre-commit*
163+
164+
# .lock
165+
*.lock
166+
167+
# DS_Store (MacOS)
168+
.DS_Store
169+
# RL pipelines may produce mp4 outputs
170+
*.mp4
171+
172+
# dependencies
173+
/transformers
174+
175+
# ruff
176+
.ruff_cache
177+
178+
wandb

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ MaxDiffusion is a Latent Diffusion model written in pure Python/Jax and targetin
2424
We encourage users to start by experimenting with MaxDiffusion out of the box and then fork and modify MaxDiffusion to meet their needs.
2525

2626
MaxDiffusion supports
27+
* Stable Diffusion 2 base (training and inference)
2728
* Stable Diffusion 2.1 (training and inference)
2829
* Stable Diffusion XL (inference).
2930

@@ -50,7 +51,7 @@ pip3 install -e .
5051
```
5152
4. After installation completes, run training with the command:
5253
```bash
53-
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base.yml run_name="my_run" base_output_directory="gs://your-bucket/"
54+
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name="my_run" base_output_directory="gs://your-bucket/"
5455
```
5556
5. If you want to generate images, you can do it as follows.
5657
- Stable Diffusion 2.1

src/maxdiffusion/configs/README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Model Configs
2+
3+
This directory contains model configuration for different Stable Diffusion models.
4+
5+
## Stable Diffusion 2.1
6+
7+
base21.yml - used for training and inference using [stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
8+
9+
## Stable Diffusion 2 Base
10+
11+
base_2_base.yml - used for training and inference using [stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base)
12+
13+
base_2_base_inference.yml - used for inference after running a training loop using the saved checkpoint in base_2_base.yml's config `output_dir`.
14+
15+
## Stable Diffusion XL
16+
17+
base_xl.yml - used to run inference using [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
File renamed without changes.
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This sentinel is a reminder to choose a real run name.
16+
run_name: ''
17+
18+
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
19+
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
20+
gcs_metrics: True
21+
log_period: 100
22+
23+
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
24+
revision: 'main'
25+
dtype: 'bfloat16'
26+
# Set true to load weights from pytorch
27+
from_pt: True
28+
split_head_dim: True
29+
30+
# Output directory
31+
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
32+
base_output_directory: ""
33+
34+
# Parallelism
35+
mesh_axes: ['data', 'fsdp', 'tensor']
36+
37+
# batch : batch dimension of data and activations
38+
# hidden :
39+
# embed : attention qkv dense layer hidden dim named as embed
40+
# heads : attention head dim = num_heads * head_dim
41+
# length : attention sequence length
42+
# temb_in : dense.shape[0] of resnet dense before conv
43+
# out_c : dense.shape[1] of resnet dense before conv
44+
# out_channels : conv.shape[-1] activation
45+
# keep_1 : conv.shape[0] weight
46+
# keep_2 : conv.shape[1] weight
47+
# conv_in : conv.shape[2] weight
48+
# conv_out : conv.shape[-1] weight
49+
logical_axis_rules: [
50+
['batch', 'data'],
51+
['activation_batch', 'data'],
52+
['activation_length', 'fsdp'],
53+
['out_channels', 'fsdp'],
54+
['conv_out', 'fsdp'],
55+
['length', 'fsdp']
56+
]
57+
data_sharding: [['data', 'fsdp', 'tensor']]
58+
59+
# One axis for each parallelism type may hold a placeholder (-1)
60+
# value to auto-shard based on available slices and devices.
61+
# By default, product of the DCN axes should equal number of slices
62+
# and product of the ICI axes should equal number of devices per slice.
63+
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
64+
dcn_fsdp_parallelism: 1
65+
dcn_tensor_parallelism: 1
66+
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
67+
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
68+
ici_tensor_parallelism: 1
69+
70+
# Dataset
71+
# Replace with dataset path or train_data_dir. One has to be set.
72+
dataset_name: 'lambdalabs/pokemon-blip-captions'
73+
train_data_dir: ''
74+
dataset_config_name: ''
75+
cache_dir: ''
76+
image_column: 'image'
77+
caption_column: 'text'
78+
resolution: 512
79+
center_crop: False
80+
random_flip: False
81+
# If cache_latents_text_encoder_outputs is True
82+
# the num_proc is set to 1
83+
tokenize_captions_num_proc: 4
84+
transform_images_num_proc: 4
85+
reuse_example_batch: False
86+
enable_data_shuffling: True
87+
88+
# Prepare image latents and text encoder outputs
89+
# during dataset creation to reduce memory consumption.
90+
cache_latents_text_encoder_outputs: True
91+
92+
93+
# Training loop
94+
learning_rate: 1.e-7
95+
scale_lr: False
96+
max_train_samples: -1
97+
# max_train_steps takes priority over num_train_epochs.
98+
max_train_steps: 800
99+
seed: 0
100+
output_dir: 'sd-model-finetuned'
101+
tensorboard_dir: 'gs://shahrokhi-maxdiffusion-v5'
102+
per_device_batch_size: 1
103+
104+
cosine_learning_rate_final_fraction: 0.1
105+
warmup_steps_fraction: 0.1
106+
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
107+
108+
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
109+
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
110+
111+
# AdamW optimizer parameters
112+
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
113+
adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients.
114+
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
115+
adam_weight_decay: 1.e-2 # AdamW Weight decay
116+
117+
max_grad_norm: 1.0
118+
119+
enable_profiler: True
120+
121+
# Generation parameters
122+
prompt: "A magical castle in the middle of a forest, artistic drawing"
123+
negative_prompt: "purple, red"
124+
guidance_scale: 7.5
125+
num_inference_steps: 30
126+
seed: 47

0 commit comments

Comments
 (0)