Skip to content

Commit 5e86c32

Browse files
committed
cleaned the backend and made it modular, minor changes in frontend
1 parent cc34ec0 commit 5e86c32

30 files changed

+985
-1760
lines changed
File renamed without changes.
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import ftfy
77
import regex as re
88

9-
import tensorflow as tf
9+
# import tensorflow as tf
10+
# from tensorflow import keras
11+
1012

1113

1214
@lru_cache()
@@ -24,10 +26,6 @@ def default_bpe():
2426
if os.path.exists(p2):
2527
return p2
2628
assert False
27-
return tf.keras.utils.get_file(
28-
"bpe_simple_vocab_16e6.txt.gz",
29-
"https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true",
30-
)
3129

3230

3331
@lru_cache()
@@ -37,7 +35,7 @@ def bytes_to_unicode():
3735
The reversible bpe codes work on unicode strings.
3836
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
3937
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
40-
This is a signficant percentage of your normal, say, 32K bpe vocab.
38+
This is a significant percentage of your normal, say, 32K bpe vocab.
4139
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
4240
And avoids mapping to whitespace/control characters the bpe code barfs on.
4341
"""
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
print("backend")
2+
3+
from stdin_input import is_avail, get_input
4+
import argparse
5+
from PIL import Image
6+
import json
7+
import random
8+
import multiprocessing
9+
from downloader import ProgressBarDownloader
10+
import sys
11+
import copy
12+
import math
13+
import time
14+
import traceback
15+
from stable_diffusion import StableDiffusion
16+
17+
# b2py t2im {"prompt": "sun glasses" , "W":640 , "H" : 640 , "num_imgs" : 10 , "input_image":"/Users/divamgupta/Downloads/inn.png" , "mask_image" : "/Users/divamgupta/Downloads/maa.png" , "is_inpaint":true }
18+
19+
20+
21+
from pathlib import Path
22+
import os
23+
24+
# App data lives under the user's home directory.
home_path = Path.home()

projects_root_path = os.path.join(home_path, ".diffusionbee")

# makedirs(exist_ok=True) is race-free, unlike the isdir()+mkdir() pair:
# two processes starting at once can no longer crash on the second mkdir.
os.makedirs(projects_root_path, exist_ok=True)

# Generated images are written here.
# (Name kept as-is, misspelling included -- other modules reference it.)
defualt_data_root = os.path.join(projects_root_path, "images")

os.makedirs(defualt_data_root, exist_ok=True)
37+
38+
39+
40+
class Unbuffered(object):
    """Stream proxy that flushes after every write.

    stdout is wrapped in this so each protocol line the backend prints
    reaches the frontend immediately instead of sitting in a buffer.
    """

    def __init__(self, stream):
        # The wrapped underlying stream (e.g. sys.stdout).
        self.stream = stream

    def write(self, data):
        # Forward the payload, then force it out right away.
        self.stream.write(data)
        self.stream.flush()

    def writelines(self, datas):
        # Same eager-flush policy for batched writes.
        self.stream.writelines(datas)
        self.stream.flush()

    def __getattr__(self, attr):
        # Any attribute we don't define is delegated, so the proxy is a
        # drop-in replacement for the wrapped stream.
        return getattr(self.stream, attr)
54+
55+
56+
sys.stdout = Unbuffered(sys.stdout)
57+
58+
59+
def download_weights():
    """Resolve the model weight files, retrying up to 10 times.

    Sets the module globals ``p_14`` (SD 1.4 base weights) and ``p_14_np``
    (SD 1.5 inpainting weights) and returns them as a tuple as well
    (backward-compatible: previous version returned None, which no caller
    used).  Raises ValueError, chained to the last failure, when every
    attempt fails.

    NOTE(review): despite the "Downloading" progress messages, the paths
    are hard-coded local files -- presumably a development stub; confirm
    before shipping.
    """
    global p_14, p_14_np

    print("sdbk mltl Loading Model")

    last_error = None
    for _ in range(10):
        try:
            print("sdbk mltl Downloading Model 1/2")

            p_14 = "/Users/divamgupta/Downloads/sd-v1-4.tdict"

            print("sdbk mltl Downloading Model 2/2")

            p_14_np = "/Users/divamgupta/Downloads/sd-v1-5-inpainting.tdict"

            return (p_14, p_14_np)
        except Exception as e:
            # Keep the failure instead of silently discarding it, then
            # back off before the next attempt.
            last_error = e
            time.sleep(10)

    raise ValueError(
        "Unable to download the model weights. Please try again and make sure you have free space and a working internet connection."
    ) from last_error
84+
85+
86+
87+
88+
89+
def process_opt(d, generator):
    """Generate images for one request dict ``d`` and save them to disk.

    ``d`` is the incoming JSON request merged over ``default_d`` (see
    ``main``); ``generator`` is a StableDiffusion instance.  Emits a
    ``sdbk nwim <path>`` protocol line for every image written.  Returns
    early (None) if the generator aborts and yields no images.
    """
    batch_size = int(d['batch_size'])
    # Number of generator calls needed to produce num_imgs images.
    n_batches = math.ceil(d['num_imgs'] / batch_size)

    if d['model_id'] == 1:
        # Dedicated SD 1.5 inpainting model.
        model_mode = "inpaint_15"
        tdict_path = p_14_np
        print("sdbk mdvr 1.5_inp")
    else:
        tdict_path = p_14
        print("sdbk mdvr 1.4")
        # img2img when an input image was supplied, plain txt2img otherwise.
        # NOTE(review): this check is nested in the else branch so it does
        # not clobber the "inpaint_15" mode chosen above -- the scraped
        # indentation was ambiguous; confirm against the original file.
        if "input_image" in d and d['input_image'] is not None and d['input_image'] != "":
            model_mode = "img2img"
        else:
            model_mode = "txt2img"

    if d['model_id'] == -1:
        # User-supplied .tdict model overrides the bundled ones.
        cust_model_path = d['custom_model_path']
        tdict_path = cust_model_path
        print("sdbk mdvr custom_" + cust_model_path.split(os.sep)[-1].split(".")[0])

    for batch_idx in range(n_batches):
        # dict.get is equivalent to the original in/else dance: None when absent.
        seed = d.get('seed')
        soft_seed = d.get('soft_seed')

        img = generator.generate(
            d['prompt'],
            img_height=d["H"], img_width=d["W"],
            num_steps=d['ddim_steps'],
            guidance_scale=d['scale'],
            temperature=1,
            batch_size=batch_size,
            seed=seed,
            soft_seed=soft_seed,
            img_id=batch_idx,
            negative_prompt=d['negative_prompt'],
            input_image=d['input_image'],
            tdict_path=tdict_path,
            mode=model_mode,
            mask_image=d['mask_image'],
            input_image_strength=(float(d['img_strength'])),
        )
        if img is None:
            # Generation was stopped/aborted via the callback.
            return

        # Renamed from "i": the original shadowed the outer loop variable.
        for j in range(len(img)):
            # Filename: first 30 alphanumeric chars of the prompt + random id.
            s = ''.join(filter(str.isalnum, str(d['prompt'])[:30]))
            fpath = os.path.join(defualt_data_root, "%s_%d.png" % (s, random.randint(0, 100000000)))

            Image.fromarray(img[j]).save(fpath)
            print("sdbk nwim %s" % (fpath))
147+
148+
149+
150+
151+
def main():
    """Backend entry point: load the model, then serve requests from stdin.

    Stdout protocol (all lines prefixed ``sdbk``): ``mltl <msg>`` loading
    status, ``mdld`` model loaded, ``inrd`` ready for input, ``inwk``
    working, ``dnpr <p>`` progress, ``gnms <s>`` generation state,
    ``nwim <path>`` new image written, ``errr <msg>`` error.  Requests
    arrive on stdin as ``b2py t2im <json>``.
    """
    global p_14, p_14_np
    download_weights()
    print("sdbk mltl Loading Model")

    def callback(state="", progress=-1):
        # Forward progress to the frontend; it may answer with a stop marker.
        print("sdbk dnpr " + str(progress))
        if state != "Generating":
            print("sdbk gnms " + state)

        if is_avail():
            if "__stop__" in get_input():
                return "stop"

    generator = StableDiffusion(p_14, model_name="sd_1x", callback=callback)

    # Per-request defaults; keys in the incoming JSON override these.
    default_d = {"W": 512, "H": 512, "num_imgs": 1, "ddim_steps": 25,
                 "scale": 7.5, "batch_size": 1, "input_image": None, "img_strength": 0.5,
                 "negative_prompt": "", "mask_image": None, "model_id": 0, "custom_model_path": None}

    print("sdbk mdld")

    while True:
        print("sdbk inrd")  # input ready

        inp_str = get_input()

        if inp_str.strip() == "":
            continue

        # Only handle generation requests; ignore stray stop markers here.
        if "b2py t2im" not in inp_str or "__stop__" in inp_str:
            continue
        inp_str = inp_str.replace("b2py t2im", "").strip()
        try:
            d_ = json.loads(inp_str)
            d = copy.deepcopy(default_d)
            d.update(d_)
            print("sdbk inwk")  # working on the input

            process_opt(d, generator)

        except Exception as e:
            # Report the failure to the frontend but keep serving requests.
            traceback.print_exc()
            print("sdbk errr %s" % (str(e)))
            # "eror" kept as-is: it is a protocol token the frontend matches.
            print("py2b eror " + str(e))


if __name__ == "__main__":
    multiprocessing.freeze_support()  # required for pyinstaller frozen builds
    main()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
2+
import numpy as np
3+
import time
4+
5+
class ModelInterface:
    """Dummy model backend for development: every call sleeps to mimic
    model latency and returns zero/copy tensors of the expected shape.
    Real backends implement the same method set.
    """

    # Activations/weights default to half precision.
    default_float_type = 'float16'

    def __init__(self, *args, **k):
        pass

    def run_unet(self, time_emb, text_emb, unet_inp):
        """Fake UNet step: returns an unchanged copy of the latent input."""
        time.sleep(1.4)
        return np.copy(unet_inp)

    def run_dec(self, unet_out):
        """Fake VAE decoder: latents (b, h, w, c) -> image (b, 8h, 8w, c)."""
        time.sleep(1.4)
        return np.zeros((unet_out.shape[0], unet_out.shape[1] * 8, unet_out.shape[2] * 8, unet_out.shape[3]))

    def run_text_enc(self, tokens, pos_ids):
        """Fake text encoder: (b, n) tokens -> (b, 77, 768) embeddings."""
        time.sleep(1.4)
        return np.zeros((tokens.shape[0], 77, 768))

    def run_enc(self, inp):
        """Fake VAE encoder: image (b, h, w, c) -> latents (b, h/8, w/8, c).

        Bug fix: the original body referenced the undefined name
        ``unet_out`` instead of the parameter ``inp``, so every call
        raised NameError.
        """
        time.sleep(1.4)
        return np.zeros((inp.shape[0], inp.shape[1] // 8, inp.shape[2] // 8, inp.shape[3]))

    def destroy(self):
        pass

    def load_from_tdict(self, tdict_path):
        pass

0 commit comments

Comments
 (0)