Skip to content

Commit 3771990

Browse files
committed
fix tests
1 parent fd8e7cd commit 3771990

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

cebra/integrations/sklearn/metrics.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
116116
"""Compute the goodness of fit score on a *single session* dataset on the model.
117117
118118
This function uses the :func:`infonce_loss` function to compute the InfoNCE loss
119-
for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
119+
for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
120120
to derive the goodness of fit from the InfoNCE loss.
121121
122122
Args:
@@ -180,10 +180,11 @@ def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
180180
return infonce_to_goodness_of_fit(infonce, model)
181181

182182

183-
def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
184-
model: Optional[cebra_sklearn_cebra.CEBRA] = None,
185-
batch_size: Optional[int] = None,
186-
num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
183+
def infonce_to_goodness_of_fit(
184+
infonce: Union[float, np.ndarray],
185+
model: Optional[cebra_sklearn_cebra.CEBRA] = None,
186+
batch_size: Optional[int] = None,
187+
num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
187188
"""Given a trained CEBRA model, return goodness of fit metric.
188189
189190
The goodness of fit ranges from 0 (lowest meaningful value)
@@ -208,7 +209,7 @@ def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
208209
209210
Args:
210211
infonce: The InfoNCE loss, either a single value or an iterable of values.
211-
model: The trained CEBRA model.
212+
model: The trained CEBRA model.
212213
batch_size: The batch size used to train the model.
213214
num_sessions: The number of sessions used to train the model.
214215
@@ -221,27 +222,32 @@ def infonce_to_goodness_of_fit(infonce: Union[float, np.ndarray],
221222
"""
222223
if model is not None:
223224
if batch_size is not None or num_sessions is not None:
224-
raise ValueError("batch_size and num_sessions should not be provided if model is provided.")
225+
raise ValueError(
226+
"batch_size and num_sessions should not be provided if model is provided."
227+
)
225228
if not hasattr(model, "state_dict_"):
226229
raise RuntimeError("Fit the CEBRA model first.")
227230
if model.batch_size is None:
228231
raise ValueError(
229232
"Computing the goodness of fit is not yet supported for "
230-
"models trained on the full dataset (batchsize = None). "
231-
)
233+
"models trained on the full dataset (batchsize = None). ")
232234
batch_size = model.batch_size
233235
num_sessions = model.num_sessions_
234236
if num_sessions is None:
235237
num_sessions = 1
238+
239+
if model.batch_size is None:
240+
raise ValueError(
241+
"Computing the goodness of fit is not yet supported for "
242+
"models trained on the full dataset (batchsize = None). ")
236243
else:
237244
if batch_size is None or num_sessions is None:
238245
raise ValueError(
239-
f"batch_size ({batch_size}) and num_sessions ({num_sessions})"
240-
f"should be provided if model is not provided."
241-
)
246+
f"batch_size ({batch_size}) and num_sessions ({num_sessions})"
247+
f"should be provided if model is not provided.")
242248

243249
nats_to_bits = np.log2(np.e)
244-
chance_level = np.log(model.batch_size * num_sessions)
250+
chance_level = np.log(batch_size * num_sessions)
245251
return (chance_level - infonce) * nats_to_bits
246252

247253

tests/test_sklearn_metrics.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,14 @@ def test_infonce_to_goodness_of_fit(seed):
493493
num_sessions=1)
494494

495495
# Test with unfitted model
496-
unfitted_model = cebra_sklearn_cebra.CEBRA()
496+
unfitted_model = cebra_sklearn_cebra.CEBRA(max_iterations=5)
497497
with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
498498
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
499499
model=unfitted_model)
500500

501501
# Test with model having batch_size=None
502-
none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None)
502+
none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None,
503+
max_iterations=5)
503504
none_batch_model.fit(X)
504505
with pytest.raises(ValueError, match="Computing the goodness of fit"):
505506
cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,

0 commit comments

Comments
 (0)