"""
 Copyright 2024 Google LLC

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      https://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """

import os
import functools
import io
import base64
from PIL import Image

from fastapi import FastAPI, Request

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.sharding import PositionalSharding
from jax.experimental.compilation_cache import compilation_cache as cc
from flax.linen import partitioning as nn_partitioning

from maxdiffusion import (
  FlaxStableDiffusionXLPipeline,
  pyconfig,
)
from maxdiffusion.image_processor import VaeImageProcessor
from maxdiffusion.max_utils import (
  create_device_mesh,
  get_dtype,
  get_states,
  device_put_replicated,
)

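# Cache compiled XLA programs on disk so subsequent server starts skip
# recompilation of the UNet and VAE graphs.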
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

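# Load the base SDXL config and override the serving defaults inline.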
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
pyconfig.initialize([None, os.path.join(THIS_DIR, "src/maxdiffusion", "configs", "base_xl.yml"),
                     "pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
                     "revision=refs/pr/95", "dtype=bfloat16", "resolution=1024",
                     "prompt=A magical castle in the middle of a forest, artistic drawing",
                     "negative_prompt=purple, red", "guidance_scale=9",
                     "num_inference_steps=20", "seed=47", "per_device_batch_size=1",
                     "run_name=sdxl-inference-test", "split_head_dim=True"])

config = pyconfig.config

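# Build the device mesh that the UNet and VAE weights will be sharded over.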
rng = jax.random.PRNGKey(config.seed)
devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

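# Pre-generated initial latents are loaded from disk and tiled to the global
# batch size, so every request denoises from the same starting noise.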
batch_size = config.per_device_batch_size * jax.device_count()
_latents = np.load(f"{THIS_DIR}/latents.npy")
_latents = jnp.array([_latents[0]] * batch_size)
weight_dtype = get_dtype(config)

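# Download (or load from cache) the SDXL pipeline and its Flax weights.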
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
  config.pretrained_model_name_or_path,
  revision=config.revision,
  dtype=weight_dtype,
  split_head_dim=config.split_head_dim
)

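# Cast all weights to the serving dtype; the scheduler state is popped first
# so it keeps its original precision.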
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
params["scheduler"] = scheduler_state

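# Batch inputs are sharded across the mesh according to the config's
# data-sharding spec.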
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))

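# The text encoders are small relative to the UNet, so their weights are
# simply replicated on every device.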
sharding = PositionalSharding(devices_array).replicate()
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding)
params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"])
params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"])

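# Shard the UNet and VAE params across the mesh, then drop the host copies
# so they can be garbage-collected.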
unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(
  mesh, None, rng, config, pipeline, params["unet"], params["vae"], training=False
)
del params["vae"]
del params["unet"]

def image_to_base64(image: Image.Image) -> str:
  """Convert a PIL image to a base64 string."""
  buffer = io.BytesIO()
  image.save(buffer, format="JPEG")
  image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
  return image_str

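# One denoising step. The latents are duplicated so the unconditional and
# text-conditioned noise predictions come out of a single UNet call, then
# combined with classifier-free guidance.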
def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale):
  latents, scheduler_state, state = args
  latents_input = jnp.concatenate([latents] * 2)

  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
  timestep = jnp.broadcast_to(t, latents_input.shape[0])

  latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t)
  noise_pred = model.apply(
    {"params": state.params},
    jnp.array(latents_input),
    jnp.array(timestep, dtype=jnp.int32),
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs=added_cond_kwargs
  ).sample

  noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
  noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

  latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

  return latents, scheduler_state, state

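# SDXL micro-conditioning: original size, crop coordinates and target size
# are packed into one time-ids vector per sample.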
def get_add_time_ids(original_size, crops_coords_top_left, target_size, bs, dtype):
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
  add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype)
  return add_time_ids

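# Run both CLIP text encoders: their penultimate hidden states are
# concatenated for cross-attention, and the pooled output of the second
# encoder becomes the added text embedding.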
def get_embeddings(prompt_ids, pipeline, params):
  te_1_inputs = prompt_ids[:, 0, :]
  te_2_inputs = prompt_ids[:, 1, :]

  prompt_embeds = pipeline.text_encoder(
    te_1_inputs, params=params["text_encoder"], output_hidden_states=True
  )
  prompt_embeds = prompt_embeds["hidden_states"][-2]
  prompt_embeds_2_out = pipeline.text_encoder_2(
    te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True
  )
  prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
  text_embeds = prompt_embeds_2_out["text_embeds"]
  prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
  return prompt_embeds, text_embeds

def tokenize(prompt, pipeline):
  # Tokenize with both tokenizers and stack to shape (batch, 2, seq_len).
  inputs = []
  for _tokenizer in [pipeline.tokenizer, pipeline.tokenizer_2]:
    text_inputs = _tokenizer(
      prompt,
      padding="max_length",
      max_length=_tokenizer.model_max_length,
      truncation=True,
      return_tensors="np"
    )
    inputs.append(text_inputs.input_ids)
  inputs = jnp.stack(inputs, axis=1)
  return inputs

def get_prompt_ids(prompts, batch_size):
  # Pad with the first prompt so the batch always matches the compiled shape.
  if len(prompts) < batch_size:
    prompts += [prompts[0]] * (batch_size - len(prompts))
  prompt_ids = tokenize(prompts, pipeline)
  return prompt_ids

def get_unet_inputs(rng, config, batch_size, pipeline, params, prompt_ids):
  negative_prompt_ids = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"] * batch_size
  negative_prompt_ids = tokenize(negative_prompt_ids, pipeline)
  guidance_scale = config.guidance_scale
  num_inference_steps = config.num_inference_steps
  height = config.resolution
  width = config.resolution
  prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, params)
  batch_size = prompt_embeds.shape[0]
  negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, params)
  add_time_ids = get_add_time_ids(
    (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
  )

  # Stack the negative (unconditional) and positive halves for
  # classifier-free guidance.
  prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
  add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
  add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
  # Keep the guidance scale in `float32` so the guided noise prediction stays
  # `float32` going into the scheduler.
  guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)

  scheduler_state = pipeline.scheduler.set_timesteps(
    params["scheduler"],
    num_inference_steps=num_inference_steps,
    shape=_latents.shape
  )

  latents = _latents * scheduler_state.init_noise_sigma

  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
  # Place activations on the mesh before entering the compiled loop.
  latents = jax.device_put(latents, data_sharding)
  prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
  guidance_scale = jax.device_put(guidance_scale, PositionalSharding(devices_array).replicate())
  added_cond_kwargs["text_embeds"] = jax.device_put(added_cond_kwargs["text_embeds"], data_sharding)
  added_cond_kwargs["time_ids"] = jax.device_put(added_cond_kwargs["time_ids"], data_sharding)

  return latents, prompt_embeds, added_cond_kwargs, guidance_scale, scheduler_state

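# Decode latents back to images: undo the VAE scaling factor, decode, and
# rescale from [-1, 1] to [0, 1] in NHWC layout.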
def vae_decode(latents, state, pipeline):
  latents = 1 / pipeline.vae.config.scaling_factor * latents
  image = pipeline.vae.apply(
    {"params": state.params},
    latents,
    method=pipeline.vae.decode
  ).sample
  image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
  return image

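# The full generation pass: prepare inputs, run the denoising loop under the
# mesh so the logical axis rules apply, then decode with the VAE.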
def run_inference(unet_state, vae_state, params, prompt_ids, rng, config, batch_size, pipeline):
  (latents,
   prompt_embeds,
   added_cond_kwargs,
   guidance_scale,
   scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params, prompt_ids)

  loop_body_p = functools.partial(loop_body, model=pipeline.unet,
                                  pipeline=pipeline,
                                  added_cond_kwargs=added_cond_kwargs,
                                  prompt_embeds=prompt_embeds,
                                  guidance_scale=guidance_scale)
  vae_decode_p = functools.partial(vae_decode, pipeline=pipeline)

  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps,
                                      loop_body_p, (latents, scheduler_state, unet_state))
    images = vae_decode_p(latents, vae_state)
  return images

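# JIT-compile the generation pass. The model states keep their mesh
# shardings; params and prompt ids are left to default placement.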
p_run_inference = jax.jit(
  functools.partial(run_inference, rng=rng, config=config, batch_size=batch_size, pipeline=pipeline),
  in_shardings=(unet_state_mesh_shardings, vae_state_mesh_shardings, None, None),
  out_shardings=None
)

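# Warm-up call at import time so the first real request doesn't pay the
# compilation cost.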
prompt_ids = get_prompt_ids([config.prompt], batch_size)
images = p_run_inference(unet_state, vae_state, params, prompt_ids)

app = FastAPI()

@app.get("/health", status_code=200)
def health():
  return {}

@app.post("/predict")
async def predict(request: Request):
  body = await request.json()
  instances = body["instances"]
  retval = []
  for instance in instances:
    prompt = instance["prompt"]  # a list of prompt strings
    prompt_ids = get_prompt_ids(prompt, batch_size)
    images = p_run_inference(unet_state, vae_state, params, prompt_ids)
    images = VaeImageProcessor.numpy_to_pil(np.array(images))

    retval_images = []
    for image in images:
      retval_images.append(image_to_base64(image))

    retval.append({"instance": instance, "images": retval_images})
  return {"predictions": retval}
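
# Example request against this server (hypothetical host/port; how the app is
# served depends on deployment, e.g. `uvicorn <module>:app` with this file's
# module name):
#
#   curl -X POST http://localhost:8000/predict \
#     -H "Content-Type: application/json" \
#     -d '{"instances": [{"prompt": ["A watercolor lighthouse at dawn"]}]}'
#
# The response mirrors the request shape:
# {"predictions": [{"instance": {...}, "images": ["<base64 JPEG>", ...]}]}.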