@@ -187,7 +187,7 @@ def build(self, xz_shape, conditions_shape=None):
         self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
         self.c_huber2 = self.c_huber**2
 
-        ## Calculate discretization schedule in advance
+        # Calculate discretization schedule in advance
         # The Jax compiler requires fixed-size arrays, so we have
         # to store all the discretized_times in one matrix in advance
         # and later only access the relevant entries.
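The comment above is terse, so here is a minimal standalone sketch (not the library code) of the pattern it describes: every candidate schedule is computed eagerly and packed into one padded, fixed-size matrix, and the traced/compiled code only indexes into it. `discretize_time`, `max_steps`, and `index_of` are illustrative stand-ins for the class internals.

```python
import numpy as np

def discretize_time(n, eps=0.001, max_time=80.0, rho=7.0):
    # Assumed Karras-style spacing between eps and max_time, ascending.
    i = np.arange(n, dtype="float64")
    return (eps ** (1 / rho) + i / (n - 1) * (max_time ** (1 / rho) - eps ** (1 / rho))) ** rho

max_steps = 12
# One row per candidate step count n = 2..max_steps; unused tail entries stay zero.
times = np.zeros((max_steps - 1, max_steps))
index_of = np.full(max_steps + 1, -1, dtype="int32")
for row, n in enumerate(range(2, max_steps + 1)):
    times[row, :n] = discretize_time(n)
    index_of[n] = row

n = 5
print(times[index_of[n], :n])  # the fixed-shape lookup used when n steps are requested
```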
@@ -213,34 +213,24 @@ def build(self, xz_shape, conditions_shape=None):
             disc = ops.convert_to_numpy(self._discretize_time(n))
             discretized_times[i, : len(disc)] = disc
             discretization_map[n] = i
+
         # Finally, we convert the vectors to tensors
         self.discretized_times = ops.convert_to_tensor(discretized_times, dtype="float32")
         self.discretization_map = ops.convert_to_tensor(discretization_map)
 
-    def call(
-        self,
-        xz: Tensor,
-        conditions: Tensor = None,
-        inverse: bool = False,
-        **kwargs,
-    ):
-        if inverse:
-            return self._inverse(xz, conditions=conditions, **kwargs)
-        return self._forward(xz, conditions=conditions, **kwargs)
-
-    def _forward_train(self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
-        """Forward function for training. Calls consistency function with
-        noisy input
-        """
+    def _forward_train(
+        self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
+    ) -> Tensor:
+        """Forward function for training. Calls the consistency function with noisy input."""
         inp = x + t * noise
-        return self.consistency_function(inp, t, conditions=conditions, **kwargs)
+        return self.consistency_function(inp, t, conditions=conditions, training=training)
 
     def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         # Consistency Models only learn the direction from noise distribution
         # to target distribution, so we cannot implement this function.
         raise NotImplementedError("Consistency Models are not invertible")
 
-    def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
+    def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
         """Generate random draws from the approximate target distribution
         using the multistep sampling algorithm from [1], Algorithm 1.
 
@@ -249,7 +239,9 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         z : Tensor
             Samples from a standard normal distribution
         conditions : Tensor, optional, default: None
-            Conditions for a approximate conditional distribution
+            Conditions for the approximate conditional distribution
+        training : bool, optional, default: False
+            Whether internal layers (e.g., dropout) should behave in train or inference mode.
         **kwargs : dict, optional, default: {}
             Additional keyword arguments. Include `steps` (default: 10) to
             adjust the number of sampling steps.
@@ -263,15 +255,17 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         x = keras.ops.copy(z) * self.max_time
         discretized_time = keras.ops.flip(self._discretize_time(steps), axis=-1)
         t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
-        x = self.consistency_function(x, t, conditions=conditions)
+
+        x = self.consistency_function(x, t, conditions=conditions, training=training)
+
         for n in range(1, steps):
             noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
             x_n = x + keras.ops.sqrt(keras.ops.square(discretized_time[n]) - self.eps**2) * noise
             t = keras.ops.full_like(t, discretized_time[n])
-            x = self.consistency_function(x_n, t, conditions=conditions)
+            x = self.consistency_function(x_n, t, conditions=conditions, training=training)
         return x
 
-    def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
+    def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
         """Compute consistency function.
 
         Parameters
@@ -282,16 +276,16 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
             Vector of time samples in [eps, T]
         conditions : Tensor
             The conditioning vector
-        **kwargs : dict, optional, default: {}
-            Additional keyword arguments passed to the network.
+        training : bool, optional, default: False
+            Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """
 
         if conditions is not None:
             xtc = ops.concatenate([x, t, conditions], axis=-1)
         else:
            xtc = ops.concatenate([x, t], axis=-1)
 
-        f = self.output_projector(self.subnet(xtc, **kwargs))
+        f = self.output_projector(self.subnet(xtc, training=training))
 
         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
         # Thus, we can do a cross product with the time vector which is (batch_size, 1) for
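The truncated comment at the end of this hunk refers to the standard consistency-model parametrization, in which the raw network output is blended with the input through time-dependent skip and output scales so that the model reduces to the identity at t = eps. A hedged sketch of how those parts are typically computed (formulas follow Song et al., 2023; `sigma_data2` plays the role of `self.sigma2`, and the exact broadcasting is the library's detail):

```python
import numpy as np

def consistency_output(x, f_raw, t, sigma_data2, eps=0.001):
    # f(x, t) = c_skip(t) * x + c_out(t) * F(x, t)
    # t: (batch_size, 1); x, f_raw: (batch_size, input_dim); sigma_data2: (1, input_dim)
    skip = sigma_data2 / ((t - eps) ** 2 + sigma_data2)
    out = np.sqrt(sigma_data2) * (t - eps) / np.sqrt(sigma_data2 + t ** 2)
    return skip * x + out * f_raw

x = np.random.randn(4, 3)
f_raw = np.random.randn(4, 3)
t = np.full((4, 1), 0.001)  # at t = eps the skip weight is 1 and the output reduces to x
print(np.allclose(consistency_output(x, f_raw, t, np.ones((1, 3))), x))
```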
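For context, `_inverse` above implements multistep consistency sampling (Algorithm 1 in [1]): start from pure noise scaled to the maximum time, apply the consistency function once, then alternately re-noise to the next (smaller) discretized time and denoise again. A self-contained sketch of that loop, with `consistency_fn`, `schedule`, and the dummy model below as assumed stand-ins for the trained network and its time grid:

```python
import numpy as np

def multistep_sample(consistency_fn, shape, schedule, eps=0.001, rng=None):
    rng = rng or np.random.default_rng()
    t_max = schedule[0]                      # schedule is given in decreasing order
    x = rng.standard_normal(shape) * t_max   # start from noise at the largest time
    x = consistency_fn(x, np.full((shape[0], 1), t_max))
    for t_n in schedule[1:]:
        # re-noise the current sample to time t_n, then denoise in one jump
        noise = rng.standard_normal(shape)
        x_n = x + np.sqrt(t_n ** 2 - eps ** 2) * noise
        x = consistency_fn(x_n, np.full((shape[0], 1), t_n))
    return x

# Toy stand-in for the trained consistency function, just to make the loop runnable:
dummy_fn = lambda x, t: x / (1.0 + t)
samples = multistep_sample(dummy_fn, (8, 2), schedule=np.array([80.0, 20.0, 5.0, 1.0]))
print(samples.shape)  # (8, 2)
```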