@@ -65,30 +65,30 @@ def test_sd21_config(self):
6565
6666 cleanup (output_dir )
6767
68- # def test_sd_2_base_config(self):
69- # output_dir="train-smoke-test"
70- # train_main([None,os.path.join(THIS_DIR,'..','configs','base_2_base.yml'),
71- # "run_name=sd2_base_smoke_test","max_train_steps=21","dataset_name=lambdalabs/pokemon-blip-captions",
72- # "base_output_directory=gs://maxdiffusion-tests", f"output_dir={output_dir}"])
73-
74- # img_url = os.path.join(THIS_DIR,'images','test_2_base.png')
75- # base_image = np.array(Image.open(img_url)).astype(np.uint8)
76-
77- # pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base2_base_inference .yml'),
78- # f"pretrained_model_name_or_path={output_dir}",
79- # "prompt=A magical castle in the middle of a forest, artistic drawing",
80- # "negative_prompt=purple, red","guidance_scale=7.5",
81- # "num_inference_steps=30","seed=47"])
82-
83- # images = generate_run(pyconfig.config)
84- # test_image = np.array(images[1 ]).astype(np.uint8)
85- # ssim_compare = ssim(base_image, test_image,
86- # multichannel=True, channel_axis=-1, data_range=255
87- # )
88- # assert base_image.shape == test_image.shape
89- # assert ssim_compare >=0.70
90-
91- # cleanup(output_dir)
68+ def test_sd_2_base_config (self ):
69+ output_dir = "train-smoke-test"
70+ train_main ([None ,os .path .join (THIS_DIR ,'..' ,'configs' ,'base_2_base.yml' ),
71+ "run_name=sd2_base_smoke_test" ,"max_train_steps=21" ,"dataset_name=lambdalabs/pokemon-blip-captions" ,
72+ "base_output_directory=gs://maxdiffusion-tests" , f"output_dir={ output_dir } " ])
73+
74+ img_url = os .path .join (THIS_DIR ,'images' ,'test_2_base.png' )
75+ base_image = np .array (Image .open (img_url )).astype (np .uint8 )
76+
77+ pyconfig .initialize ([None ,os .path .join (THIS_DIR ,'..' ,'configs' ,'base_2_base_inference .yml' ),
78+ f"pretrained_model_name_or_path={ output_dir } " ,
79+ "prompt=A magical castle in the middle of a forest, artistic drawing" ,
80+ "negative_prompt=purple, red" ,"guidance_scale=7.5" ,
81+ "num_inference_steps=30" ,"seed=47" ])
82+
83+ images = generate_run (pyconfig .config )
84+ test_image = np .array (images [0 ]).astype (np .uint8 )
85+ ssim_compare = ssim (base_image , test_image ,
86+ multichannel = True , channel_axis = - 1 , data_range = 255
87+ )
88+ assert base_image .shape == test_image .shape
89+ assert ssim_compare >= 0.70
90+
91+ cleanup (output_dir )
9292
9393if __name__ == '__main__' :
9494 absltest .main ()
0 commit comments