Skip to content

Commit 79f08a1

Browse files
linter and tests
1 parent d899be4 commit 79f08a1

File tree

4 files changed

+28
-27
lines changed

4 files changed

+28
-27
lines changed

src/maxdiffusion/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,4 +196,4 @@ def main(argv: Sequence[str]) -> None:
196196
run(pyconfig.config)
197197

198198
if __name__ == "__main__":
199-
app.run(main)
199+
app.run(main)

src/maxdiffusion/models/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,8 @@ def train(config):
147147
weight_dtype = max_utils.get_dtype(config)
148148
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
149149
config.pretrained_model_name_or_path,revision=config.revision, dtype=weight_dtype,
150-
safety_checker=None, feature_extractor=None, from_pt=config.from_pt
150+
safety_checker=None, feature_extractor=None, from_pt=config.from_pt,
151+
split_head_dim=config.split_head_dim
151152
)
152153

153154
noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained(config.pretrained_model_name_or_path,
@@ -385,4 +386,4 @@ def main(argv: Sequence[str]) -> None:
385386
validate_train_config(config)
386387
train(config)
387388
if __name__ == "__main__":
388-
app.run(main)
389+
app.run(main)
3.22 KB
Loading

src/maxdiffusion/tests/train_smoke_test.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -65,30 +65,30 @@ def test_sd21_config(self):
6565

6666
cleanup(output_dir)
6767

68-
# def test_sd_2_base_config(self):
69-
# output_dir="train-smoke-test"
70-
# train_main([None,os.path.join(THIS_DIR,'..','configs','base_2_base.yml'),
71-
# "run_name=sd2_base_smoke_test","max_train_steps=21","dataset_name=lambdalabs/pokemon-blip-captions",
72-
# "base_output_directory=gs://maxdiffusion-tests", f"output_dir={output_dir}"])
73-
74-
# img_url = os.path.join(THIS_DIR,'images','test_2_base.png')
75-
# base_image = np.array(Image.open(img_url)).astype(np.uint8)
76-
77-
# pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base2_base_inference.yml'),
78-
# f"pretrained_model_name_or_path={output_dir}",
79-
# "prompt=A magical castle in the middle of a forest, artistic drawing",
80-
# "negative_prompt=purple, red","guidance_scale=7.5",
81-
# "num_inference_steps=30","seed=47"])
82-
83-
# images = generate_run(pyconfig.config)
84-
# test_image = np.array(images[1]).astype(np.uint8)
85-
# ssim_compare = ssim(base_image, test_image,
86-
# multichannel=True, channel_axis=-1, data_range=255
87-
# )
88-
# assert base_image.shape == test_image.shape
89-
# assert ssim_compare >=0.70
90-
91-
# cleanup(output_dir)
68+
def test_sd_2_base_config(self):
69+
output_dir="train-smoke-test"
70+
train_main([None,os.path.join(THIS_DIR,'..','configs','base_2_base.yml'),
71+
"run_name=sd2_base_smoke_test","max_train_steps=21","dataset_name=lambdalabs/pokemon-blip-captions",
72+
"base_output_directory=gs://maxdiffusion-tests", f"output_dir={output_dir}"])
73+
74+
img_url = os.path.join(THIS_DIR,'images','test_2_base.png')
75+
base_image = np.array(Image.open(img_url)).astype(np.uint8)
76+
77+
pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base_2_base_inference.yml'),
78+
f"pretrained_model_name_or_path={output_dir}",
79+
"prompt=A magical castle in the middle of a forest, artistic drawing",
80+
"negative_prompt=purple, red","guidance_scale=7.5",
81+
"num_inference_steps=30","seed=47"])
82+
83+
images = generate_run(pyconfig.config)
84+
test_image = np.array(images[0]).astype(np.uint8)
85+
ssim_compare = ssim(base_image, test_image,
86+
multichannel=True, channel_axis=-1, data_range=255
87+
)
88+
assert base_image.shape == test_image.shape
89+
assert ssim_compare >=0.70
90+
91+
cleanup(output_dir)
9292

9393
if __name__ == '__main__':
9494
absltest.main()

0 commit comments

Comments
 (0)