
Commit b1b235a

sdxl docker inference
1 parent b571906 commit b1b235a

File tree

6 files changed, +346 −0 lines changed
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
FROM python:3.10

USER root

ENV infer_port=8080

# Install libraries
ENV PIP_ROOT_USER_ACTION=ignore
RUN python3 -m pip install --upgrade pip
RUN pip install "jax[tpu]>=0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip install fastapi
RUN pip install uvicorn

# Install wget
RUN apt-get update && apt-get install -y wget

# Install maxdiffusion from source, pinned to the inf_mlperf branch.
RUN git clone https://github.com/google/maxdiffusion && \
    cd maxdiffusion && \
    pip install -r requirements.txt && \
    git checkout inf_mlperf && \
    pip3 install -e .

WORKDIR maxdiffusion

# Copy the model loader and download model weights into the image at build time.
COPY docker/sdxl_inference/model_loader.py .
RUN JAX_PLATFORMS='' python3 model_loader.py

COPY docker/sdxl_inference/entrypoint.sh .
COPY docker/sdxl_inference/handler.py .
COPY docker/sdxl_inference/latents.npy .

EXPOSE ${infer_port}

CMD ["./entrypoint.sh"]
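
The COPY paths above are relative to the Docker build context, so the image is meant to be built from the repository root (the directory that contains docker/sdxl_inference).
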
docker/sdxl_inference/entrypoint.sh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
#!/bin/bash
export PORT=$AIP_HTTP_PORT
uvicorn handler:app --proxy-headers --host 0.0.0.0 --port $PORT
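
PORT is read from AIP_HTTP_PORT, the environment variable that Vertex AI Prediction sets for custom serving containers; when the container is run outside Vertex AI, AIP_HTTP_PORT has to be provided by hand (for example 8080, the port the request script further below targets).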

docker/sdxl_inference/handler.py

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import functools
from absl import app
import io
import base64
from PIL import Image

from fastapi import FastAPI, Request

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.experimental.compilation_cache import compilation_cache as cc
from flax.linen import partitioning as nn_partitioning
from jax.sharding import PositionalSharding

from maxdiffusion import (
  FlaxStableDiffusionXLPipeline
)

from maxdiffusion import pyconfig
from maxdiffusion.image_processor import VaeImageProcessor
from maxdiffusion.max_utils import (
  create_device_mesh,
  get_dtype,
  get_states,
  device_put_replicated
)

# Persistent compilation cache so jitted programs are reused across container restarts.
cc.initialize_cache(os.path.expanduser("~/jax_cache"))

# Fixed SDXL inference configuration (model, resolution, steps, batch size).
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
pyconfig.initialize([None, os.path.join(THIS_DIR, 'src/maxdiffusion', 'configs', 'base_xl.yml'),
                     "pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
                     "revision=refs/pr/95", "dtype=bfloat16", "resolution=1024",
                     "prompt=A magical castle in the middle of a forest, artistic drawing",
                     "negative_prompt=purple, red", "guidance_scale=9",
                     "num_inference_steps=20", "seed=47", "per_device_batch_size=1",
                     "run_name=sdxl-inference-test", "split_head_dim=True"])

config = pyconfig.config

rng = jax.random.PRNGKey(config.seed)
devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

# Fixed initial latents are loaded from disk so every request starts from the same noise.
batch_size = config.per_device_batch_size * jax.device_count()
_latents = np.load(f"{THIS_DIR}/latents.npy")
_latents = jnp.array([_latents[0]] * batch_size)
weight_dtype = get_dtype(config)

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
  config.pretrained_model_name_or_path,
  revision=config.revision,
  dtype=weight_dtype,
  split_head_dim=config.split_head_dim
)

scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(weight_dtype), params)
params["scheduler"] = scheduler_state

data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))

# Replicate the text encoders on every device; UNet and VAE states are sharded across the mesh.
sharding = PositionalSharding(devices_array).replicate()
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=sharding)
params["text_encoder"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder"])
params["text_encoder_2"] = jax.tree_util.tree_map(partial_device_put_replicated, params["text_encoder_2"])

unet_state, unet_state_mesh_shardings, vae_state, vae_state_mesh_shardings = get_states(
  mesh, None, rng, config, pipeline, params["unet"], params["vae"], training=False)
del params["vae"]
del params["unet"]

def image_to_base64(image: Image.Image) -> str:
  """Convert a PIL image to a base64 string."""
  buffer = io.BytesIO()
  image.save(buffer, format="JPEG")
  image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
  return image_str

def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale):
  # One denoising step with classifier-free guidance over a doubled (uncond, cond) batch.
  latents, scheduler_state, state = args
  latents_input = jnp.concatenate([latents] * 2)

  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
  timestep = jnp.broadcast_to(t, latents_input.shape[0])

  latents_input = pipeline.scheduler.scale_model_input(scheduler_state, latents_input, t)
  noise_pred = model.apply(
    {"params" : state.params},
    jnp.array(latents_input),
    jnp.array(timestep, dtype=jnp.int32),
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs=added_cond_kwargs
  ).sample

  noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
  noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

  latents, scheduler_state = pipeline.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

  return latents, scheduler_state, state

def get_add_time_ids(original_size, crops_coords_top_left, target_size, bs, dtype):
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
  add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype)
  return add_time_ids

def get_embeddings(prompt_ids, pipeline, params):
  te_1_inputs = prompt_ids[:, 0, :]
  te_2_inputs = prompt_ids[:, 1, :]

  prompt_embeds = pipeline.text_encoder(
    te_1_inputs, params=params["text_encoder"], output_hidden_states=True
  )
  prompt_embeds = prompt_embeds["hidden_states"][-2]
  prompt_embeds_2_out = pipeline.text_encoder_2(
    te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True
  )
  prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
  text_embeds = prompt_embeds_2_out["text_embeds"]
  prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
  return prompt_embeds, text_embeds

def tokenize(prompt, pipeline):
  inputs = []
  for _tokenizer in [pipeline.tokenizer, pipeline.tokenizer_2]:
    text_inputs = _tokenizer(
      prompt,
      padding="max_length",
      max_length=_tokenizer.model_max_length,
      truncation=True,
      return_tensors="np"
    )
    inputs.append(text_inputs.input_ids)
  inputs = jnp.stack(inputs, axis=1)
  return inputs

def get_prompt_ids(prompts, batch_size):
  # Pad with the first prompt if the request does not fill the batch.
  if len(prompts) != batch_size:
    prompts += [prompts[0]] * (batch_size - len(prompts))
  prompt_ids = tokenize(prompts, pipeline)
  return prompt_ids

def get_unet_inputs(rng, config, batch_size, pipeline, params, prompt_ids):
  # pad with first element if it doesn't fill batch_size
  # if len(prompts) != batch_size:
  #   prompts = [prompts[0]] * (batch_size - len(prompts))
  # prompt_ids = tokenize(prompt_ids, pipeline)
  negative_prompt_ids = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"] * batch_size
  negative_prompt_ids = tokenize(negative_prompt_ids, pipeline)
  guidance_scale = config.guidance_scale
  num_inference_steps = config.num_inference_steps
  height = config.resolution
  width = config.resolution
  prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, params)
  batch_size = prompt_embeds.shape[0]
  negative_prompt_embeds, negative_pooled_embeds = get_embeddings(negative_prompt_ids, pipeline, params)
  add_time_ids = get_add_time_ids(
    (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
  )

  prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
  add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
  add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
  # Ensure model output will be `float32` before going into the scheduler
  guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)

  scheduler_state = pipeline.scheduler.set_timesteps(
    params["scheduler"],
    num_inference_steps=num_inference_steps,
    shape=_latents.shape
  )

  latents = _latents * scheduler_state.init_noise_sigma

  added_cond_kwargs = {"text_embeds" : add_text_embeds, "time_ids" : add_time_ids}
  latents = jax.device_put(latents, data_sharding)
  prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
  guidance_scale = jax.device_put(guidance_scale, PositionalSharding(devices_array).replicate())
  added_cond_kwargs['text_embeds'] = jax.device_put(added_cond_kwargs['text_embeds'], data_sharding)
  added_cond_kwargs['time_ids'] = jax.device_put(added_cond_kwargs['time_ids'], data_sharding)

  return latents, prompt_embeds, added_cond_kwargs, guidance_scale, scheduler_state


def vae_decode(latents, state, pipeline):
  latents = 1 / pipeline.vae.config.scaling_factor * latents
  image = pipeline.vae.apply(
    {"params" : state.params},
    latents,
    method=pipeline.vae.decode
  ).sample
  image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
  return image

def run_inference(unet_state, vae_state, params, prompt_ids, rng, config, batch_size, pipeline):

  (latents,
   prompt_embeds,
   added_cond_kwargs,
   guidance_scale,
   scheduler_state) = get_unet_inputs(rng, config, batch_size, pipeline, params, prompt_ids)

  loop_body_p = functools.partial(loop_body, model=pipeline.unet,
                                  pipeline=pipeline,
                                  added_cond_kwargs=added_cond_kwargs,
                                  prompt_embeds=prompt_embeds,
                                  guidance_scale=guidance_scale)
  vae_decode_p = functools.partial(vae_decode, pipeline=pipeline)

  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    latents, _, _ = jax.lax.fori_loop(0, config.num_inference_steps,
                                      loop_body_p, (latents, scheduler_state, unet_state))
    images = vae_decode_p(latents, vae_state)
  return images

p_run_inference = jax.jit(
  functools.partial(run_inference, rng=rng, config=config, batch_size=batch_size, pipeline=pipeline),
  in_shardings=(unet_state_mesh_shardings, vae_state_mesh_shardings, None, None),
  out_shardings=None
)

# Run one generation at import time to trigger jit compilation before serving traffic.
prompt_ids = get_prompt_ids([config.prompt], batch_size)
images = p_run_inference(unet_state, vae_state, params, prompt_ids)

app = FastAPI()

@app.get("/health", status_code=200)
def health():
  return {}

@app.post("/predict")
async def predict(request: Request):
  body = await request.json()
  instances = body["instances"]
  retval = []
  for instance in instances:
    prompt = instance["prompt"]  # list
    prompt_ids = get_prompt_ids(prompt, batch_size)
    images = p_run_inference(unet_state, vae_state, params, prompt_ids)
    images = VaeImageProcessor.numpy_to_pil(np.array(images))

    retval_images = []
    for image in images:
      retval_images.append(image_to_base64(image))

    retval.append({"instance" : instance, "images" : retval_images})
  return {"predictions" : retval}

docker/sdxl_inference/latents.npy

256 KB
Binary file not shown.
docker/sdxl_inference/model_loader.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from maxdiffusion import FlaxStableDiffusionXLPipeline

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
  "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
)
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import requests
import time
r = requests.get("http://127.0.0.1:8080/health")
print(r.status_code, r.reason)
iters = 10
for _ in range(iters):
  s = time.time()
  r = requests.post("http://127.0.0.1:8080/predict",
                    json={"instances": [
                        {
                          "prompt" : ["a dog walking a cat"],
                          "query_id" : ["1"]
                        },
                        {
                          "prompt" : ["a dog walking a cat"],
                          "query_id" : ["1"]
                        },
                        {
                          "prompt" : ["a dog walking a cat"],
                          "query_id" : ["1"]
                        }
                      ]
                    },
                    headers={"Content-Type": "application/json"},
                    )
  print(r.status_code, r.reason)
  print("request time: ", (time.time() - s))

with open("response.json", "w") as f:
  f.write(r.text)
print("total time: ", (time.time() - s))
