Commit 9410325

Reinstated the CompVis safety checker.
1 parent ef6f356 commit 9410325

1 file changed: sdd.py (111 additions, 91 deletions)
@@ -1,9 +1,9 @@
 '''
-sdd.py: Stable Diffusion daemon
+sdd.py: Stable Diffusion daemon. Pre-load the model and serve image prompts via FastAPI.

-Pre-load the model and serve image prompts via FastAPI.
+This fetches SD from Hugging Face, so huggingface-cli login first.

-Reduces rendering time from 1:20 to about 3.5 seconds on an RTX 2080.
+Reduces rendering time to about 3.5 seconds for 40 steps on an RTX 2080.
 '''
 import random

@@ -13,58 +13,67 @@

 import torch

-from fastapi import FastAPI, Query, HTTPException
+from fastapi import FastAPI, Query
 from fastapi.responses import StreamingResponse

-from transformers import CLIPTextModel, CLIPTokenizer
-from transformers import logging
+from transformers import CLIPTextModel, CLIPTokenizer, AutoFeatureExtractor, logging
 from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
 from tqdm.auto import tqdm
 from torch import autocast
 from PIL import Image

 app = FastAPI()

-# Every GPU device that can be used for image generation
-GPUS = {
-    "0": {"name": "RTX 2080 Ti", "lock": Lock()},
-}
-
-# MODEL = "CompVis/stable-diffusion-v1-4"
 MODELS = {
     "unet": {
+        # "name": "CompVis/stable-diffusion-v1-4",
         "name": "runwayml/stable-diffusion-v1-5",
-        "subfolder": "unet"
+        "sub": "unet"
     },
     "vae": {
         "name": "stabilityai/sd-vae-ft-ema",
-        "subfolder": ""
+        "sub": ""
     },
     "tokenizer": {
         "name": "openai/clip-vit-large-patch14",
-        "subfolder": ""
+        "sub": ""
     },
     "text_encoder": {
         "name": "openai/clip-vit-large-patch14",
-        "subfolder": ""
+        "sub": ""
+    },
+    "safety": {
+        "name": "CompVis/stable-diffusion-safety-checker",
+        "sub": ""
     }
 }

+# One lock for each available GPU (only one supported for now)
+GPUS = {}
+for i in range(torch.cuda.device_count()):
+    GPUS[i] = Lock()
+
+if not GPUS:
+    raise RuntimeError("No GPUs detected. Check your config and try again.")
+
 # Suppress some unnecessary warnings when loading the CLIPTextModel
 logging.set_verbosity_error()

-if not torch.cuda.is_available():
-    raise RuntimeError('No CUDA device available, exiting.')
-
 # Load the autoencoder model which will be used to decode the latents into image space.
-vae = AutoencoderKL.from_pretrained(MODELS["vae"]["name"], subfolder=MODELS["vae"]["subfolder"], use_auth_token=True)
+vae = AutoencoderKL.from_pretrained(MODELS["vae"]["name"], subfolder=MODELS["vae"]["sub"], use_auth_token=True)

 # Load the tokenizer and text encoder to tokenize and encode the text.
-tokenizer = CLIPTokenizer.from_pretrained(MODELS["tokenizer"]["name"], subfolder=MODELS["tokenizer"]["subfolder"])
-text_encoder = CLIPTextModel.from_pretrained(MODELS["text_encoder"]["name"], subfolder=MODELS["text_encoder"]["subfolder"])
+tokenizer = CLIPTokenizer.from_pretrained(MODELS["tokenizer"]["name"], subfolder=MODELS["tokenizer"]["sub"])
+text_encoder = CLIPTextModel.from_pretrained(MODELS["text_encoder"]["name"], subfolder=MODELS["text_encoder"]["sub"])

 # The UNet model for generating the latents.
-unet = UNet2DConditionModel.from_pretrained(MODELS["unet"]["name"], subfolder=MODELS["unet"]["subfolder"], use_auth_token=True)
+unet = UNet2DConditionModel.from_pretrained(MODELS["unet"]["name"], subfolder=MODELS["unet"]["sub"], use_auth_token=True)
+
+# The CompVis safety model.
+safety_feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS["safety"]["name"], subfolder=MODELS["safety"]["sub"])
+safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODELS["safety"]["name"], subfolder=MODELS["safety"]["sub"])

 # The noise scheduler
 scheduler = LMSDiscreteScheduler(
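
For reference, the reinstated safety checker can be exercised on its own, outside the daemon. This is a minimal standalone sketch mirroring the new naughty() helper below; the local file name test.jpg is hypothetical and not part of this commit:

    import numpy as np
    from PIL import Image
    from transformers import AutoFeatureExtractor
    from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

    # Load the CLIP feature extractor and the checker from the Hugging Face Hub.
    extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
    checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

    # "test.jpg" is a hypothetical local image, used only for illustration.
    image = np.array(Image.open("test.jpg"))

    # The checker takes CLIP pixel values plus the raw image arrays; it returns
    # the (possibly blacked-out) images and one NSFW flag per input image.
    inputs = extractor([image], return_tensors="pt")
    _, has_nsfw = checker(images=[image], clip_input=inputs.pixel_values)
    print("NSFW detected:", has_nsfw[0])
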
@@ -79,76 +88,96 @@
 text_encoder = text_encoder.to('cuda')
 unet = unet.to('cuda')

+def naughty(image):
+    ''' Returns True if naughty bits are detected, else False. '''
+    safety_checker_input = safety_feature_extractor([image], return_tensors="pt")
+    _, has_nsfw_concept = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values)
+    return has_nsfw_concept[0]
+
 def wait_for_gpu():
     ''' Return the device name of the first available GPU. Blocks until one is available and sets the lock. '''
     while True:
         gpu = random.choice(list(GPUS))
-        if GPUS[gpu]['lock'].acquire(timeout=1):
+        if GPUS[gpu].acquire(timeout=0.5):
             return gpu

 def generate_image(prompt, seed, steps, width=512, height=512, guidance=7.5):
-    ''' Generate an image. Returns a FastAPI StreamingResponse. '''
-    generator = torch.manual_seed(seed)
-    batch_size = 1
-
-    # Prep text
-    text_input = tokenizer(
-        [prompt],
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt"
-    )
-    with torch.no_grad():
-        text_embeddings = text_encoder(text_input.input_ids.to('cuda'))[0]
-    max_length = text_input.input_ids.shape[-1]
-    uncond_input = tokenizer(
-        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-    )
-    with torch.no_grad():
-        uncond_embeddings = text_encoder(uncond_input.input_ids.to('cuda'))[0]
-    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # pylint: disable=no-member
-
-    # Prep Scheduler
-    scheduler.set_timesteps(steps)
-
-    # Prep latents
-    latents = torch.randn( # pylint: disable=no-member
-        (batch_size, unet.in_channels, height // 8, width // 8),
-        generator=generator,
-    )
-    latents = latents.to('cuda')
-    latents = latents * scheduler.init_noise_sigma
+    ''' Generate and return an image array using the first available GPU. '''

-    # Loop
-    with autocast("cuda"):
-        for _, ts in tqdm(enumerate(scheduler.timesteps)):
-            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-            latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), ts) # pylint: disable=no-member
+    gpu = wait_for_gpu()
+    try:
+        # Prep text
+        text_input = tokenizer(
+            [prompt],
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt"
+        )
+        with torch.no_grad():
+            text_embeddings = text_encoder(text_input.input_ids.to('cuda'))[0]
+        max_length = text_input.input_ids.shape[-1]
+        uncond_input = tokenizer(
+            [""], padding="max_length", max_length=max_length, return_tensors="pt"
+        )
+        with torch.no_grad():
+            uncond_embeddings = text_encoder(uncond_input.input_ids.to('cuda'))[0]
+        text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # pylint: disable=no-member

-            # predict the noise residual
-            with torch.no_grad():
-                noise_pred = unet(latent_model_input, ts, encoder_hidden_states=text_embeddings).sample
+        # Prep Scheduler
+        scheduler.set_timesteps(steps)

-            # perform guidance
-            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+        # Prep latents
+        latents = torch.randn( # pylint: disable=no-member
+            (1, unet.in_channels, height // 8, width // 8),
+            generator=torch.manual_seed(seed),
+        )
+        latents = latents.to('cuda')
+        latents = latents * scheduler.init_noise_sigma
+
+        # Loop
+        with autocast("cuda"):
+            for _, ts in tqdm(enumerate(scheduler.timesteps)):
+                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+                latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), ts) # pylint: disable=no-member
+
+                # predict the noise residual
+                with torch.no_grad():
+                    noise_pred = unet(latent_model_input, ts, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                # latents = scheduler.step(noise_pred, i, latents)["prev_sample"] # Diffusers 0.3 and below
+                latents = scheduler.step(noise_pred, ts, latents).prev_sample
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        with torch.no_grad():
+            image = vae.decode(latents).sample
+
+        # Display
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+        images = (image * 255).round().astype("uint8")
+
+        return images[0]
+    finally:
+        GPUS[gpu].release()

-    # compute the previous noisy sample x_t -> x_t-1
-    # latents = scheduler.step(noise_pred, i, latents)["prev_sample"] # Diffusers 0.3 and below
-    latents = scheduler.step(noise_pred, ts, latents).prev_sample
+def safe_generate_image(prompt, seed, steps, width=512, height=512, guidance=7.5, nsfw=False):
+    ''' Generate an image and check NSFW. Returns a FastAPI StreamingResponse. '''

-    # scale and decode the image latents with vae
-    latents = 1 / 0.18215 * latents
-    with torch.no_grad():
-        image = vae.decode(latents).sample
+    image = generate_image(prompt, seed, steps, width, height, guidance)

-    # Display
-    image = (image / 2 + 0.5).clamp(0, 1)
-    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-    images = (image * 255).round().astype("uint8")
+    if not nsfw and naughty(image):
+        print("🍆 detected!!!1!")
+        prompt = "An adorable teddy bear running through a grassy field, early morning volumetric lighting"
+        image = generate_image(prompt, seed, steps, width, height, guidance)

-    out = Image.fromarray(images[0])
+    out = Image.fromarray(image)

     # Set the EXIF data. See PIL.ExifTags.TAGS to map numbers to names.
     exif = out.getexif()
@@ -158,7 +187,6 @@ def generate_image(prompt, seed, steps, width=512, height=512, guidance=7.5):

     buf = BytesIO()
     out.save(buf, format="JPEG", quality=85, exif=exif)
-
     buf.seek(0)

     return StreamingResponse(buf, media_type="image/jpeg", headers={
@@ -177,22 +205,14 @@ async def generate(
     steps: Optional[int] = Query(40),
     width: Optional[int] = Query(512),
     height: Optional[int] = Query(512),
+    guidance: Optional[float] = Query(7.5),
+    nsfw: Optional[bool] = Query(False),
 ):
     ''' Generate an image with Stable Diffusion '''

-    if width * height > 287744:
-        raise HTTPException(
-            status_code=422,
-            detail='Out of GPU memory. Total width * height must be < 287744 pixels.'
-        )
-
     if seed < 0:
         seed = random.randint(0,2**64-1)

     prompt = prompt.strip().replace('\n', ' ')

-    gpu = wait_for_gpu()
-    try:
-        return generate_image(prompt, seed, steps, width, height)
-    finally:
-        GPUS[gpu]['lock'].release()
+    return safe_generate_image(prompt, seed, steps, width, height, guidance, nsfw)
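
With this change the endpoint grows two query parameters, guidance and nsfw. A minimal client sketch, assuming the daemon listens on localhost:8000 and the route (declared outside this diff) is a GET at /generate; both are assumptions, not confirmed by the commit:

    import requests

    # Hypothetical host, port, and route; the route decorator is not shown in this diff.
    resp = requests.get(
        "http://localhost:8000/generate",
        params={
            "prompt": "a watercolor fox in the snow",
            "seed": -1,       # a negative seed asks the server to pick one at random
            "steps": 40,
            "guidance": 7.5,  # classifier-free guidance scale (new parameter)
            "nsfw": "false",  # when false, flagged images are regenerated as a teddy bear
        },
        timeout=300,
    )
    resp.raise_for_status()
    with open("fox.jpg", "wb") as out:
        out.write(resp.content)  # the server streams back a JPEG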
