@@ -178,9 +178,11 @@ def goodness_of_fit_history(model):
     return infonce_to_goodness_of_fit(infonce, model)


-def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
-                               model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
-    """Given a trained CEBRA model, return goodness of fit metric
+def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
+                               model: Optional[cebra_sklearn_cebra.CEBRA] = None,
+                               batch_size: Optional[int] = None,
+                               num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
+    """Given a trained CEBRA model, return goodness of fit metric.

     The goodness of fit ranges from 0 (lowest meaningful value)
     to a positive number with the unit "bits", the higher the
@@ -199,27 +201,41 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],

         S = \\log N - \\text{InfoNCE}

+    To use this function, either provide a trained CEBRA model or the
+    batch size and number of sessions.
+
     Args:
+        infonce: The InfoNCE loss, either a single value or an iterable of values.
         model: The trained CEBRA model
+        batch_size: The batch size used to train the model.
+        num_sessions: The number of sessions used to train the model.

     Returns:
         Numpy array containing the goodness of fit values, measured in bits

     Raises:
         RuntimeError: If the provided model is not fit to data.
+        ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided.
     """
-    if not hasattr(model, "state_dict_"):
-        raise RuntimeError("Fit the CEBRA model first.")
-    if model.batch_size is None:
-        raise ValueError(
-            "Computing the goodness of fit is not yet supported for "
-            "models trained on the full dataset (batchsize = None). "
-        )
+    if model is not None:
+        if batch_size is not None or num_sessions is not None:
+            raise ValueError("batch_size and num_sessions should not be provided if model is provided.")
+        if not hasattr(model, "state_dict_"):
+            raise RuntimeError("Fit the CEBRA model first.")
+        if model.batch_size is None:
+            raise ValueError(
+                "Computing the goodness of fit is not yet supported for "
+                "models trained on the full dataset (batchsize = None). "
+            )
+        batch_size = model.batch_size
+        num_sessions = model.num_sessions_
+        if num_sessions is None:
+            num_sessions = 1
+    else:
+        if batch_size is None or num_sessions is None:
+            raise ValueError("batch_size and num_sessions should be provided if model is not provided.")

     nats_to_bits = np.log2(np.e)
-    num_sessions = model.num_sessions_
-    if num_sessions is None:
-        num_sessions = 1
-    chance_level = np.log(model.batch_size * num_sessions)
+    chance_level = np.log(batch_size * num_sessions)
     return (chance_level - infonce) * nats_to_bits

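For reference, a minimal sketch of the conversion the function performs, using illustrative values (batch size 512, a single session, an InfoNCE loss of 5.5 nats) rather than numbers taken from this change:

```python
import numpy as np

# Illustrative values, not taken from this patch.
batch_size = 512
num_sessions = 1
infonce = 5.5  # InfoNCE loss in nats

# Same conversion as above: an uninformative model has an InfoNCE loss of
# log(N) nats at chance level; the difference to the measured loss is
# converted from nats to bits via log2(e).
chance_level = np.log(batch_size * num_sessions)
gof_bits = (chance_level - infonce) * np.log2(np.e)
print(gof_bits)  # roughly 1.07 bits above chance

# Equivalent call through the new optional arguments (no fitted model needed):
# gof_bits = infonce_to_goodness_of_fit(infonce=5.5, batch_size=512, num_sessions=1)
```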