'''
- sdd.py: Stable Diffusion daemon
+ sdd.py: Stable Diffusion daemon. Pre-load the model and serve image prompts via FastAPI.

- Pre-load the model and serve image prompts via FastAPI.
+ This fetches SD from Hugging Face, so run huggingface-cli login first.

- Reduces rendering time from 1:20 to about 3.5 seconds on an RTX 2080.
+ Reduces rendering time to about 3.5 seconds for 40 steps on an RTX 2080.
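+
+ Serve with, for example: uvicorn sdd:app (example invocation; any ASGI server works).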
'''
import random


import torch

- from fastapi import FastAPI, Query, HTTPException
+ from fastapi import FastAPI, Query
from fastapi.responses import StreamingResponse

- from transformers import CLIPTextModel, CLIPTokenizer
- from transformers import logging
+ from transformers import CLIPTextModel, CLIPTokenizer, AutoFeatureExtractor, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
from tqdm.auto import tqdm
from torch import autocast
from PIL import Image

app = FastAPI()

- # Every GPU device that can be used for image generation
- GPUS = {
-     "0": {"name": "RTX 2080 Ti", "lock": Lock()},
- }
-
- # MODEL = "CompVis/stable-diffusion-v1-4"
MODELS = {
    "unet": {
31+ # "name" = "CompVis/stable-diffusion-v1-4",
        "name": "runwayml/stable-diffusion-v1-5",
-         "subfolder": "unet"
+         "sub": "unet"
    },
    "vae": {
        "name": "stabilityai/sd-vae-ft-ema",
-         "subfolder": ""
+         "sub": ""
    },
    "tokenizer": {
        "name": "openai/clip-vit-large-patch14",
-         "subfolder": ""
+         "sub": ""
    },
    "text_encoder": {
        "name": "openai/clip-vit-large-patch14",
-         "subfolder": ""
+         "sub": ""
+     },
+     "safety": {
+         "name": "CompVis/stable-diffusion-safety-checker",
+         "sub": ""
    }
}

+ # One lock for each available GPU (only one supported for now)
+ GPUS = {}
+ for i in range(torch.cuda.device_count()):
+     GPUS[i] = Lock()
+
+ if not GPUS:
+     raise RuntimeError("No GPUs detected. Check your config and try again.")
+
# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

- if not torch.cuda.is_available():
-     raise RuntimeError('No CUDA device available, exiting.')
-
# Load the autoencoder model which will be used to decode the latents into image space.
- vae = AutoencoderKL.from_pretrained(MODELS["vae"]["name"], subfolder=MODELS["vae"]["subfolder"], use_auth_token=True)
+ vae = AutoencoderKL.from_pretrained(MODELS["vae"]["name"], subfolder=MODELS["vae"]["sub"], use_auth_token=True)

# Load the tokenizer and text encoder to tokenize and encode the text.
- tokenizer = CLIPTokenizer.from_pretrained(MODELS["tokenizer"]["name"], subfolder=MODELS["tokenizer"]["subfolder"])
- text_encoder = CLIPTextModel.from_pretrained(MODELS["text_encoder"]["name"], subfolder=MODELS["text_encoder"]["subfolder"])
+ tokenizer = CLIPTokenizer.from_pretrained(MODELS["tokenizer"]["name"], subfolder=MODELS["tokenizer"]["sub"])
+ text_encoder = CLIPTextModel.from_pretrained(MODELS["text_encoder"]["name"], subfolder=MODELS["text_encoder"]["sub"])

# The UNet model for generating the latents.
- unet = UNet2DConditionModel.from_pretrained(MODELS["unet"]["name"], subfolder=MODELS["unet"]["subfolder"], use_auth_token=True)
+ unet = UNet2DConditionModel.from_pretrained(MODELS["unet"]["name"], subfolder=MODELS["unet"]["sub"], use_auth_token=True)
+
+ # The CompVis safety model.
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(MODELS["safety"]["name"], subfolder=MODELS["safety"]["sub"])
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODELS["safety"]["name"], subfolder=MODELS["safety"]["sub"])

# The noise scheduler
scheduler = LMSDiscreteScheduler(
text_encoder = text_encoder.to('cuda')
unet = unet.to('cuda')

+ def naughty(image):
+     ''' Returns True if naughty bits are detected, else False. '''
+     safety_checker_input = safety_feature_extractor([image], return_tensors="pt")
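+     # The checker returns (filtered images, per-image nsfw flags); only the flag matters here.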
+     _, has_nsfw_concept = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values)
+     return has_nsfw_concept[0]
+
def wait_for_gpu():
    ''' Return the device name of the first available GPU. Blocks until one is available and acquires its lock. '''
    while True:
        gpu = random.choice(list(GPUS))
-         if GPUS[gpu]['lock'].acquire(timeout=1):
+         if GPUS[gpu].acquire(timeout=0.5):
            return gpu

def generate_image(prompt, seed, steps, width=512, height=512, guidance=7.5):
-     ''' Generate an image. Returns a FastAPI StreamingResponse. '''
-     generator = torch.manual_seed(seed)
-     batch_size = 1
-
-     # Prep text
-     text_input = tokenizer(
-         [prompt],
-         padding="max_length",
-         max_length=tokenizer.model_max_length,
-         truncation=True,
-         return_tensors="pt"
-     )
-     with torch.no_grad():
-         text_embeddings = text_encoder(text_input.input_ids.to('cuda'))[0]
-     max_length = text_input.input_ids.shape[-1]
-     uncond_input = tokenizer(
-         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-     )
-     with torch.no_grad():
-         uncond_embeddings = text_encoder(uncond_input.input_ids.to('cuda'))[0]
-     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])  # pylint: disable=no-member
-
-     # Prep Scheduler
-     scheduler.set_timesteps(steps)
-
-     # Prep latents
-     latents = torch.randn(  # pylint: disable=no-member
-         (batch_size, unet.in_channels, height // 8, width // 8),
-         generator=generator,
-     )
-     latents = latents.to('cuda')
-     latents = latents * scheduler.init_noise_sigma
+     ''' Generate and return an image array using the first available GPU. '''

-     # Loop
-     with autocast("cuda"):
-         for _, ts in tqdm(enumerate(scheduler.timesteps)):
-             # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-             latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), ts)  # pylint: disable=no-member
+     gpu = wait_for_gpu()
+     try:
+         # Prep text
+         text_input = tokenizer(
+             [prompt],
+             padding="max_length",
+             max_length=tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt"
+         )
+         with torch.no_grad():
+             text_embeddings = text_encoder(text_input.input_ids.to('cuda'))[0]
+         max_length = text_input.input_ids.shape[-1]
+         uncond_input = tokenizer(
+             [""], padding="max_length", max_length=max_length, return_tensors="pt"
+         )
+         with torch.no_grad():
+             uncond_embeddings = text_encoder(uncond_input.input_ids.to('cuda'))[0]
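+         # Order matters: noise_pred.chunk(2) in the loop below assumes [uncond, cond]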
+         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])  # pylint: disable=no-member

-             # predict the noise residual
-             with torch.no_grad():
-                 noise_pred = unet(latent_model_input, ts, encoder_hidden_states=text_embeddings).sample
+         # Prep Scheduler
+         scheduler.set_timesteps(steps)

-             # perform guidance
-             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-             noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+         # Prep latents
+         latents = torch.randn(  # pylint: disable=no-member
+             (1, unet.in_channels, height // 8, width // 8),
+             generator=torch.manual_seed(seed),
+         )
+         latents = latents.to('cuda')
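+         # LMSDiscreteScheduler expects the initial noise to be scaled by its init_noise_sigma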
+         latents = latents * scheduler.init_noise_sigma
+
+         # Loop
+         with autocast("cuda"):
+             for _, ts in tqdm(enumerate(scheduler.timesteps)):
+                 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+                 latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), ts)  # pylint: disable=no-member
+
+                 # predict the noise residual
+                 with torch.no_grad():
+                     noise_pred = unet(latent_model_input, ts, encoder_hidden_states=text_embeddings).sample
+
+                 # perform guidance
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 # latents = scheduler.step(noise_pred, i, latents)["prev_sample"]  # Diffusers 0.3 and below
+                 latents = scheduler.step(noise_pred, ts, latents).prev_sample
+
+         # scale and decode the image latents with vae
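+         # (0.18215 is the latent scaling factor from the original SD training configuration)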
+         latents = 1 / 0.18215 * latents
+         with torch.no_grad():
+             image = vae.decode(latents).sample
+
+         # Display
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+         images = (image * 255).round().astype("uint8")
+
+         return images[0]
+     finally:
+         GPUS[gpu].release()

-             # compute the previous noisy sample x_t -> x_t-1
-             # latents = scheduler.step(noise_pred, i, latents)["prev_sample"]  # Diffusers 0.3 and below
-             latents = scheduler.step(noise_pred, ts, latents).prev_sample
+ def safe_generate_image(prompt, seed, steps, width=512, height=512, guidance=7.5, nsfw=False):
+     ''' Generate an image and check NSFW. Returns a FastAPI StreamingResponse. '''

-     # scale and decode the image latents with vae
-     latents = 1 / 0.18215 * latents
-     with torch.no_grad():
-         image = vae.decode(latents).sample
+     image = generate_image(prompt, seed, steps, width, height, guidance)

-     # Display
-     image = (image / 2 + 0.5).clamp(0, 1)
-     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-     images = (image * 255).round().astype("uint8")
+     if not nsfw and naughty(image):
+         print("🍆 detected!!!1!")
+         prompt = "An adorable teddy bear running through a grassy field, early morning volumetric lighting"
+         image = generate_image(prompt, seed, steps, width, height, guidance)

-     out = Image.fromarray(images[0])
+     out = Image.fromarray(image)

    # Set the EXIF data. See PIL.ExifTags.TAGS to map numbers to names.
    exif = out.getexif()
@@ -158,7 +187,6 @@ def generate_image(prompt, seed, steps, width=512, height=512, guidance=7.5):

    buf = BytesIO()
    out.save(buf, format="JPEG", quality=85, exif=exif)
-
    buf.seek(0)

    return StreamingResponse(buf, media_type="image/jpeg", headers={
@@ -177,22 +205,14 @@ async def generate(
    steps: Optional[int] = Query(40),
    width: Optional[int] = Query(512),
    height: Optional[int] = Query(512),
+     guidance: Optional[float] = Query(7.5),
+     nsfw: Optional[bool] = Query(False),
):
    ''' Generate an image with Stable Diffusion '''

-     if width * height > 287744:
-         raise HTTPException(
-             status_code=422,
-             detail='Out of GPU memory. Total width * height must be < 287744 pixels.'
-         )
-
    if seed < 0:
        seed = random.randint(0, 2**64 - 1)

    prompt = prompt.strip().replace('\n', ' ')

-     gpu = wait_for_gpu()
-     try:
-         return generate_image(prompt, seed, steps, width, height)
-     finally:
-         GPUS[gpu]['lock'].release()
+     return safe_generate_image(prompt, seed, steps, width, height, guidance, nsfw)