Skip to content

Commit 0598346

Browse files
authored
Merge branch 'main' into stes/better-goodness-of-fit
2 parents c94f5ae + 7e74eda commit 0598346

File tree

6 files changed

+112
-0
lines changed

6 files changed

+112
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.pt
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
FROM python:3.12-slim AS base
2+
RUN pip install torch --index-url https://download.pytorch.org/whl/cpu
3+
RUN apt-get update && \
4+
apt-get install -y --no-install-recommends git && \
5+
rm -rf /var/lib/apt/lists/*
6+
7+
FROM base AS cebra-0.4.0-scikit-learn-1.4
8+
RUN pip install cebra==0.4.0 "scikit-learn<1.5"
9+
WORKDIR /app
10+
COPY create_model.py .
11+
RUN python create_model.py
12+
13+
FROM base AS cebra-0.4.0-scikit-learn-1.6
14+
RUN pip install cebra==0.4.0 "scikit-learn>=1.6"
15+
WORKDIR /app
16+
COPY create_model.py .
17+
RUN python create_model.py
18+
19+
FROM base AS cebra-rc-scikit-learn-1.4
20+
# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class.
21+
# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053
22+
RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn<1.5"
23+
WORKDIR /app
24+
COPY create_model.py .
25+
RUN python create_model.py
26+
27+
FROM base AS cebra-rc-scikit-learn-1.6
28+
# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class.
29+
# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053
30+
RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn>=1.6"
31+
WORKDIR /app
32+
COPY create_model.py .
33+
RUN python create_model.py
34+
35+
FROM scratch
36+
COPY --from=cebra-0.4.0-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.4.pt
37+
COPY --from=cebra-0.4.0-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.6.pt
38+
COPY --from=cebra-rc-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.4.pt
39+
COPY --from=cebra-rc-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.6.pt
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Helper script to build CEBRA checkpoints
2+
3+
This script builds CEBRA checkpoints for different versions of scikit-learn and CEBRA.
4+
To build all models, run:
5+
6+
```bash
7+
./generate.sh
8+
```
9+
10+
The models are currently also stored in git directly due to their small size.
11+
12+
Related issue: https://github.com/AdaptiveMotorControlLab/CEBRA/issues/207
13+
Related test: tests/test_sklearn_legacy.py
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import numpy as np
2+
3+
import cebra
4+
5+
neural_data = np.random.normal(0, 1, (1000, 30)) # 1000 samples, 30 features
6+
cebra_model = cebra.CEBRA(model_architecture="offset10-model",
7+
batch_size=512,
8+
learning_rate=1e-4,
9+
max_iterations=10,
10+
time_offsets=10,
11+
num_hidden_units=16,
12+
output_dimension=8,
13+
verbose=True)
14+
cebra_model.fit(neural_data)
15+
cebra_model.save("cebra_model.pt")
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/bash
2+
3+
DOCKER_BUILDKIT=1 docker build --output type=local,dest=. .

tests/test_sklearn_legacy.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pathlib
2+
import urllib.request
3+
4+
import numpy as np
5+
import pytest
6+
7+
from cebra.integrations.sklearn.cebra import CEBRA
8+
9+
MODEL_VARIANTS = [
10+
"cebra-0.4.0-scikit-learn-1.4", "cebra-0.4.0-scikit-learn-1.6",
11+
"cebra-rc-scikit-learn-1.4", "cebra-rc-scikit-learn-1.6"
12+
]
13+
14+
15+
@pytest.mark.parametrize("model_variant", MODEL_VARIANTS)
16+
def test_load_legacy_model(model_variant):
17+
"""Test loading a legacy CEBRA model."""
18+
19+
X = np.random.normal(0, 1, (1000, 30))
20+
21+
model_path = pathlib.Path(
22+
__file__
23+
).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt"
24+
25+
if not model_path.exists():
26+
url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt"
27+
urllib.request.urlretrieve(url, model_path)
28+
29+
loaded_model = CEBRA.load(model_path)
30+
31+
assert loaded_model.model_architecture == "offset10-model"
32+
assert loaded_model.output_dimension == 8
33+
assert loaded_model.num_hidden_units == 16
34+
assert loaded_model.time_offsets == 10
35+
36+
output = loaded_model.transform(X)
37+
assert isinstance(output, np.ndarray)
38+
assert output.shape[1] == loaded_model.output_dimension
39+
40+
assert hasattr(loaded_model, "state_dict_")
41+
assert hasattr(loaded_model, "n_features_")

0 commit comments

Comments
 (0)