@@ -116,7 +116,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
116116 """Compute the goodness of fit score on a *single session* dataset on the model.
117117
118118 This function uses the :func:`infonce_loss` function to compute the InfoNCE loss
119- for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
119+ for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
120120 to derive the goodness of fit from the InfoNCE loss.
121121
122122 Args:
@@ -180,10 +180,11 @@ def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
180180 return infonce_to_goodness_of_fit(infonce, model)
181181
182182
183- def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
184-                                model: Optional[cebra_sklearn_cebra.CEBRA] = None,
185-                                batch_size: Optional[int] = None,
186-                                num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
183+ def infonce_to_goodness_of_fit(
184+         infonce: Union[float, np.ndarray],
185+         model: Optional[cebra_sklearn_cebra.CEBRA] = None,
186+         batch_size: Optional[int] = None,
187+         num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
187188 """Given a trained CEBRA model, return goodness of fit metric.
188189
189190 The goodness of fit ranges from 0 (lowest meaningful value)
@@ -208,7 +209,7 @@ def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
208209
209210 Args:
210211 infonce: The InfoNCE loss, either a single value or an iterable of values.
211- model: The trained CEBRA model.
212+ model: The trained CEBRA model.
212213 batch_size: The batch size used to train the model.
213214 num_sessions: The number of sessions used to train the model.
214215
@@ -221,27 +222,32 @@ def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
221222 """
222223     if model is not None:
223224         if batch_size is not None or num_sessions is not None:
224-             raise ValueError("batch_size and num_sessions should not be provided if model is provided.")
225+             raise ValueError(
226+                 "batch_size and num_sessions should not be provided if model is provided."
227+             )
225228         if not hasattr(model, "state_dict_"):
226229             raise RuntimeError("Fit the CEBRA model first.")
227230         if model.batch_size is None:
228231             raise ValueError(
229232                 "Computing the goodness of fit is not yet supported for "
230-                 "models trained on the full dataset (batchsize = None). "
231-             )
233+                 "models trained on the full dataset (batchsize = None). ")
232234         batch_size = model.batch_size
233235         num_sessions = model.num_sessions_
234236         if num_sessions is None:
235237             num_sessions = 1
238+
239+         if model.batch_size is None:
240+             raise ValueError(
241+                 "Computing the goodness of fit is not yet supported for "
242+                 "models trained on the full dataset (batchsize = None). ")
236243     else:
237244         if batch_size is None or num_sessions is None:
238245             raise ValueError(
239-                 f"batch_size ({batch_size}) and num_sessions ({num_sessions})"
240-                 f"should be provided if model is not provided."
241-             )
246+                 f"batch_size ({batch_size}) and num_sessions ({num_sessions})"
247+                 f"should be provided if model is not provided.")
242248
243249     nats_to_bits = np.log2(np.e)
244-     chance_level = np.log(model.batch_size * num_sessions)
250+     chance_level = np.log(batch_size * num_sessions)
245251     return (chance_level - infonce) * nats_to_bits
246252
247253
0 commit comments