1111import io
1212from dataclasses import dataclass
1313import os
14- import importlib
1514import sys
1615
1716MAX_TEXT_LEN = 77
1817
19- USE_DUMMY_INTERFACE = False
2018
2119dir_path = os .path .dirname (os .path .realpath (__file__ ))
2220
23- sys .path .append (os .path .join (dir_path , "../model_converter" ))
2421from tdict import TDict
2522
26- # get the model interface form the environ
27- if not USE_DUMMY_INTERFACE :
28- model_interface_path = os .environ .get ('MODEL_INTERFACE_PATH' ) or "../stable_diffusion_tf_models"
29-
30- print ("model_interface_path" , model_interface_path )
31-
32- if model_interface_path [- 1 ] == "/" :
33- model_interface_path = model_interface_path [:- 1 ]
34-
35- module_name = model_interface_path .split ("/" )[- 1 ]
36- module_path = "/" .join (model_interface_path .split ("/" )[:- 1 ])
37-
38-
39- sys .path .append ( os .path .join (dir_path , module_path ) )
40-
41-
42- ModelInterface = importlib .import_module ( module_name + ".interface" ).ModelInterface
43- else :
44- from fake_interface import ModelInterface
45-
4623from schedulers .scheduling_ddim import DDIMScheduler
4724from schedulers .scheduling_lms_discrete import LMSDiscreteScheduler
4825from schedulers .scheduling_pndm import PNDMScheduler
@@ -145,8 +122,9 @@ def dummy_callback(state="" , progress=-1):
145122 pass
146123
147124class StableDiffusion :
148- def __init__ (self , tdict_path , model_name = "sd_1x" , callback = None ):
125+ def __init__ (self , ModelInterfaceClass , tdict_path , model_name = "sd_1x" , callback = None ):
149126
127+ self .ModelInterfaceClass = ModelInterfaceClass
150128
151129 if callback is None :
152130 callback = dummy_callback
@@ -158,9 +136,9 @@ def __init__(self , tdict_path , model_name="sd_1x", callback=None ):
158136
159137 self .current_model_name = model_name
160138 self .current_tdict_path = tdict_path
161- self .current_dtype = ModelInterface .default_float_type
139+ self .current_dtype = self . ModelInterfaceClass .default_float_type
162140
163- self .model = ModelInterface ( TDict (self .current_tdict_path ), dtype = self .current_dtype , model_name = self .current_model_name )
141+ self .model = self . ModelInterfaceClass ( TDict (self .current_tdict_path ), dtype = self .current_dtype , model_name = self .current_model_name )
164142
165143
166144 def prepare_model_interface (self , sd_run = None ):
@@ -177,7 +155,7 @@ def prepare_model_interface(self , sd_run=None ):
177155 print ("Creating model interface" )
178156 assert tdict_path is not None
179157 self .model .destroy ()
180- self .model = ModelInterface (TDict (tdict_path ) , dtype = dtype , model_name = model_name )
158+ self .model = self . ModelInterfaceClass (TDict (tdict_path ) , dtype = dtype , model_name = model_name )
181159 self .current_tdict_path = tdict_path
182160 self .current_dtype = dtype
183161 self .current_model_name = model_name
0 commit comments