@@ -88,6 +88,7 @@ def __init__(
8888 clip_sample : bool = True ,
8989 set_alpha_to_one : bool = True ,
9090 tensor_format : str = "pt" ,
91+ prediction_type : str = "epsilon"
9192 ):
9293 if trained_betas is not None :
9394 self .betas = np .asarray (trained_betas )
@@ -115,6 +116,7 @@ def __init__(
115116 self .clip_sample = clip_sample
116117 self .set_alpha_to_one = set_alpha_to_one
117118 self .tensor_format = tensor_format
119+ self .prediction_type = prediction_type
118120
119121 # At every step in ddim, we are looking into the previous alphas_cumprod
120122 # For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -217,8 +219,14 @@ def step(
217219
218220 # 3. compute predicted original sample from predicted noise also called
219221 # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
220- pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
221-
222+ if self .config .prediction_type == "epsilon" :
223+ pred_original_sample = (sample - beta_prod_t ** (0.5 ) * model_output ) / alpha_prod_t ** (0.5 )
224+ pred_epsilon = model_output
225+ elif self .config .prediction_type == "v_prediction" :
226+ pred_original_sample = (alpha_prod_t ** 0.5 ) * sample - (beta_prod_t ** 0.5 ) * model_output
227+ pred_epsilon = (alpha_prod_t ** 0.5 ) * model_output + (beta_prod_t ** 0.5 ) * sample
228+ else :
229+ raise ValueError ("Unknown prediction_type" )
222230 # 4. Clip "predicted x_0"
223231 if self .config .clip_sample :
224232 pred_original_sample = self .clip (pred_original_sample , - 1 , 1 )
@@ -230,10 +238,10 @@ def step(
230238
231239 if use_clipped_model_output :
232240 # the pred_epsilon is always re-derived from the clipped x_0 in Glide
233- model_output = (sample - alpha_prod_t ** (0.5 ) * pred_original_sample ) / beta_prod_t ** (0.5 )
241+ pred_epsilon = (sample - alpha_prod_t ** (0.5 ) * pred_original_sample ) / beta_prod_t ** (0.5 )
234242
235243 # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
236- pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2 ) ** (0.5 ) * model_output
244+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2 ) ** (0.5 ) * pred_epsilon
237245
238246 # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
239247 prev_sample = alpha_prod_t_prev ** (0.5 ) * pred_original_sample + pred_sample_direction
0 commit comments