
Conversation

@gonlairo gonlairo commented Feb 17, 2025

xCEBRA

eXplainable CEBRA 🔎🦓

This PR adds the following features:

  • multiobjective solver -> fit multiple subspaces with a new API
  • attribution methods (via captum), including our new method, inverted neuron gradient
  • regularized contrastive learning using Jacobian regularization (required for identifiable attribution maps, but also useful for regularizing training more generally), a.k.a. xCEBRA
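As a rough illustration of the last point (not the implementation added in this PR), a Jacobian regularizer can be written as a penalty on the Frobenius norm of the encoder's input-output Jacobian, estimated via random projections; all names below are illustrative:

```python
import torch

def jacobian_penalty(encoder: torch.nn.Module,
                     x: torch.Tensor,
                     n_projections: int = 1) -> torch.Tensor:
    """Hutchinson-style estimate of E[||J_f(x)||_F^2] over a batch x."""
    x = x.detach().requires_grad_(True)
    z = encoder(x)
    penalty = x.new_zeros(())
    for _ in range(n_projections):
        # for v ~ N(0, I), E[||v^T J||^2] equals ||J||_F^2
        v = torch.randn_like(z)
        (vjp,) = torch.autograd.grad(z, x, grad_outputs=v, create_graph=True)
        penalty = penalty + (vjp ** 2).sum() / x.shape[0]
    return penalty / n_projections

# illustrative training step:
# total_loss = contrastive_loss + lam * jacobian_penalty(encoder, batch)
```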

This code supports the following paper:

https://openreview.net/forum?id=aGrCXoTB4P

@inproceedings{schneider2025timeseries,
  title={Time-series attribution maps with regularized contrastive learning},
  author={Steffen Schneider and Rodrigo Gonz{\'a}lez Laiz and Anastasiia Filippova and Markus Frey and Mackenzie W Mathis},
  booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
  year={2025},
  url={https://openreview.net/forum?id=aGrCXoTB4P}
}

Abstract:

Gradient-based attribution methods aim to explain decisions of deep learning models, but so far lack identifiability guarantees. Here, we propose a method to generate attribution maps with identifiability guarantees by developing a regularized contrastive learning algorithm (RegCL) trained on time-series data. We show theoretically that RegCL has favorable properties for identifying the Jacobian matrix of the data generating process. Empirically, we demonstrate robust approximation of zero vs. non-zero entries in the ground-truth attribution map on synthetic datasets, and significant improvements across previous attribution methods based on feature ablation, Shapley values, and other gradient-based methods. Our work constitutes a first example of identifiable inference of time-series attribution maps, and opens avenues for better understanding of time-series data, such as for neural dynamics and decision processes within neural networks.
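For intuition, the regularized objective combines a contrastive (InfoNCE-style) loss with a penalty on the encoder Jacobian. The exact generalized InfoNCE formulation and regularizer are given in the paper; the form below is only a schematic sketch under that assumption:

$$
\mathcal{L}(f) = -\,\mathbb{E}\left[\log \frac{e^{\phi(f(x_t),\, f(x_{t+\Delta}))}}{\sum_{x^-} e^{\phi(f(x_t),\, f(x^-))}}\right] + \lambda\,\mathbb{E}_x\!\left[\lVert J_f(x)\rVert_F^2\right],
$$

where $f$ is the encoder, $\phi$ a similarity function, $x_{t+\Delta}$ a positive sample, $x^-$ the negative samples, and $J_f(x)$ the Jacobian of $f$ at $x$.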

Outline of the Method:

Figure 1: Identifiable attribution maps for time-series data. Using time-series data (such as neural data recorded during navigation, as depicted), our inference framework estimates the ground-truth Jacobian matrix Jg (i.e., x is the observed neural data linked to latents z and c, where c is the explicit [auxiliary] behavioral variable that would be linked to grid cells) by identifying the inverse data generation process up to a linear indeterminacy L. Then, we estimate the Jacobian Jf of the encoder model f by minimizing a generalized InfoNCE objective. Inverting this Jacobian (Jf+), which approximates Jg, allows us to construct the attributions.
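To make the last step of the caption concrete, here is a minimal sketch (not the attribution API added in this PR) that computes the encoder Jacobian for a single sample and takes its pseudo-inverse; `encoder` is assumed to map one input vector of dimension n to a latent of dimension d:

```python
import torch

def attribution_from_encoder(encoder: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Return pinv(Jf) at sample x; per the caption, this approximates Jg up to a linear indeterminacy."""
    Jf = torch.autograd.functional.jacobian(encoder, x)  # shape (d, n): Jacobian of the encoder at x
    return torch.linalg.pinv(Jf)                          # shape (n, d)
```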

gonlairo and others added 6 commits February 17, 2025 21:16
* Add multiobjective solver and regularized training * Add example for multiobjective training * Add jacobian regularizer and SAM * update license headers * add api draft for multiobjective training * add all necessary modules to run the complete xcebra pipeline * add notebooks to reproduce xcebra pipeline * add first working notebook * add notebook with hybrid learning * add notebook with creation of synthetic data * add notebook with hybrid training * add plot with R2 for different parts of the embedding * add new API * update api wrapper with more checks and messages * add tests and notebook with new api * merge xcebra into attribution * separate xcebra dataset from cebra * some minor refactoring of cebra dataset * separate xcebra loader from cebra * remove xcebra distributions from cebra * minor refactoring with distributions * separate xcebra criterions from cebra * minor refactoring on criterion * separate xcebra models/criterions/layers from cebra * refactoring multiobjective * more refactoring... * separate xcebra solvers from cebra * more refactoring * move xcebra to its own package * move more files into xcebra package * more files and remove changes with the registry * remove unncessary import * add folder structure * move back distributions * add missing init * remove wrong init * make loader and dataset run with new imports * making it run! * make attribution run * Run pre-commit * move xcebra repo one level up * update gitignore and add __init__ from data * add init to distributions * add correct init for attribution pacakge * add correct init for model package * fix remaining imports * fix tests * add examples back to xcebra repo * update imports from graphs_xcebra * add setup.py to create a package * update imports of graph_xcebra * update notebooks * Formatting code for submission Co-authored-by: Rodrigo Gonzalez <gonlairo@gmail.com> * move test into xcebra * Add README * move distributions back to main package * clean up examples * adapt tests * Add LICENSE * add train/eval notebook again * add notebook with clean results * rm synthetic data * change name from xcebra to regcl * change names of modules and adapt imports * change name from graphs_xcebra to synthetic_data * Integrate into CEBRA * Fix remaining imports and make notebook runnable * Add dependencies, add version flag * Remove synthetic data files * reset dockerfile, move vmf * apply pre-commit * Update notice * add some docstrings * Apply license headers * add new scd notebook * add notebook with scd --------- Co-authored-by: Steffen Schneider <stes@hey.com>
* bump version * update dockerfile * fix progress bar * remove outdated test * rename models
@stes stes changed the title from "Aistats2025" to "Add xCEBRA implementation (AISTATS 2025)" Feb 17, 2025

@MMathisLab MMathisLab left a comment


I left comments throughout, thank you!!

@cla-bot cla-bot bot added the CLA signed label Feb 18, 2025

@stes stes left a comment


I made a few edits:

  • The old MultiobjectiveSolver is accessible again (important for the hybrid model in Fig. 2 of the CEBRA paper); it is now called LegacyMultiobjectiveSolver. End users of the sklearn API with hybrid=True can still use it (see the sketch after this list).
  • Reverted some changes from the research code base that are not important for release
  • added additional tests, incl. integration tests from the notebooks
  • (resolved some additional review comments)
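A minimal usage sketch of the hybrid mode through the existing sklearn estimator (data shapes and hyperparameters below are illustrative, not taken from this PR):

```python
import numpy as np
from cebra import CEBRA

# toy data: 1000 time steps, 50 neurons, 2 continuous behavior variables
neural = np.random.randn(1000, 50).astype("float32")
behavior = np.random.randn(1000, 2).astype("float32")

# hybrid=True jointly optimizes the behavior- and time-contrastive objectives,
# the code path that is now backed by LegacyMultiobjectiveSolver
model = CEBRA(model_architecture="offset10-model",
              batch_size=512,
              max_iterations=100,
              output_dimension=8,
              hybrid=True)
model.fit(neural, behavior)
embedding = model.transform(neural)
```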

stes commented Feb 19, 2025

next:

  • check docstring coverage and write docstrings
  • remove the ratinabox==1.8 and ephysiopy==1.9.62 deps for xcebra? I think they are not needed for core functionality; when dropping them, they could be mentioned in the docs somewhere
  • fix and consolidate naming of some of the newly added classes
  • cebra.distributions.DeltaVMFDistribution seems missing, check!

stes commented Feb 19, 2025

tests build!


@stes stes mentioned this pull request Apr 17, 2025
@stes stes deleted the branch AdaptiveMotorControlLab:main April 18, 2025 11:32
@stes stes closed this Apr 18, 2025
@stes stes reopened this Apr 18, 2025
@stes stes changed the base branch from stes/upgrade-docs-rebased to main April 18, 2025 11:49

stes commented Apr 18, 2025

Merged #241 now, diff should be clean against main.


@MMathisLab MMathisLab left a comment


a few more comments regarding naming and compatibility

python -m pip install --upgrade pip setuptools wheel
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev,datasets,integrations]'
pip install '.[dev,datasets,integrations,xcebra]'


xcebra should not be its own sub-install


done

Dockerfile Outdated
WORKDIR /build
COPY --from=wheel /build/dist/${WHEEL} .
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]'
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets,xcebra]'

remove xcebra; should be in main cebra package


done

setup.cfg Outdated
xcebra =
captum
cvxpy
ratinabox==1.8

remove xcebra and move to integrations


@stes stes left a comment


fixed some/most of the review comments

python -m pip install --upgrade pip setuptools wheel
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev,datasets,integrations]'
pip install '.[dev,datasets,integrations,xcebra]'

done

Dockerfile Outdated
WORKDIR /build
COPY --from=wheel /build/dist/${WHEEL} .
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]'
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets,xcebra]'

done

class SingleSessionHybridSolver(abc_.MultiobjectiveSolver):
    """Single session training, contrasting neural data against behavior."""

    log: Dict = dataclasses.field(default_factory=lambda: ({

removed, unclear why added here

setup.cfg Outdated
xcebra =
captum
cvxpy
ratinabox==1.8

done

setup.cfg Outdated
cvxpy
ratinabox==1.8
scikit-image
ephysiopy==1.9.62

done


stes commented Apr 20, 2025

ok, remaining thing is better/more docs. I'll make some progress on that tomorrow!
(the failing test is docs coverage)

- setting coverage threshold to 80% to not delay good code being made public. In the near future this can be fixed and raised again to 90%.
@MMathisLab MMathisLab merged commit 92c8b1f into AdaptiveMotorControlLab:main Apr 23, 2025
11 checks passed
CeliaBenquet pushed a commit to CeliaBenquet/CEBRA that referenced this pull request Apr 23, 2025
* Add multiobjective solver and regularized training (#783) * Add multiobjective solver and regularized training * Add example for multiobjective training * Add jacobian regularizer and SAM * update license headers * add api draft for multiobjective training * add all necessary modules to run the complete xcebra pipeline * add notebooks to reproduce xcebra pipeline * add first working notebook * add notebook with hybrid learning * add notebook with creation of synthetic data * add notebook with hybrid training * add plot with R2 for different parts of the embedding * add new API * update api wrapper with more checks and messages * add tests and notebook with new api * merge xcebra into attribution * separate xcebra dataset from cebra * some minor refactoring of cebra dataset * separate xcebra loader from cebra * remove xcebra distributions from cebra * minor refactoring with distributions * separate xcebra criterions from cebra * minor refactoring on criterion * separate xcebra models/criterions/layers from cebra * refactoring multiobjective * more refactoring... * separate xcebra solvers from cebra * more refactoring * move xcebra to its own package * move more files into xcebra package * more files and remove changes with the registry * remove unncessary import * add folder structure * move back distributions * add missing init * remove wrong init * make loader and dataset run with new imports * making it run! * make attribution run * Run pre-commit * move xcebra repo one level up * update gitignore and add __init__ from data * add init to distributions * add correct init for attribution pacakge * add correct init for model package * fix remaining imports * fix tests * add examples back to xcebra repo * update imports from graphs_xcebra * add setup.py to create a package * update imports of graph_xcebra * update notebooks * Formatting code for submission Co-authored-by: Rodrigo Gonzalez <gonlairo@gmail.com> * move test into xcebra * Add README * move distributions back to main package * clean up examples * adapt tests * Add LICENSE * add train/eval notebook again * add notebook with clean results * rm synthetic data * change name from xcebra to regcl * change names of modules and adapt imports * change name from graphs_xcebra to synthetic_data * Integrate into CEBRA * Fix remaining imports and make notebook runnable * Add dependencies, add version flag * Remove synthetic data files * reset dockerfile, move vmf * apply pre-commit * Update notice * add some docstrings * Apply license headers * add new scd notebook * add notebook with scd --------- Co-authored-by: Steffen Schneider <stes@hey.com> * Fix tests * bump version * update dockerfile * fix progress bar * remove outdated test * rename models * Apply fixes to pass ruff tests * Fix typos * Update license headers, fix additional ruff errors * remove unused comment * rename regcl in codebase * change regcl name in dockerfile * Improve attribution module * Fix imports name naming * add basic integration test * temp disable of binary check * Add legacy multiobjective model for backward compat * add synth import back in * Fix docstrings and type annot in cebra/models/jacobian_regularizer.py * add xcebra to tests * add missing cvxpy dep * fix docstrings * more docstrings to fix attr error * Improve build setup for docs * update pydata theme options * Add README for docs folder * Fix demo notebook build * Finish build setup * update git workflow * Move demo notebooks to CEBRA-demos repo See AdaptiveMotorControlLab/CEBRA-demos#28 * revert 
unneeded changes in solver * formatting in solver * further minimize solver diff * Revert unneeded updates to the solver * fix citation * fix docs build, missing refs * remove file dependency from xcebra int test * remove unneeded change in registry * update gitignore * update docs * exclude some assets * include binary file check again * add timeout to workflow * add timeout also to docs build * switch build back to sphinx for gh actions * pin sphinx version in setup.cfg * attempt workflow fix * attempt to fix build workflow * update to sphinx-build * fix build workflow * fix indent error * fix build system * revert demos to main * adapt workflow for testing * bump version to 0.6.0rc1 * format imports * docs writing * enable build on dev branch * fix some review comments * extend multiobjective docs * Set version to alpha * make tempdir platform independent * Remove ratinabox and ephysiopy as deps * Apply review comments * Update Makefile - setting coverage threshold to 80% to not delay good code being made public. In the near future this can be fixed and raised again to 90%. --------- Co-authored-by: Steffen Schneider <stes@hey.com> Co-authored-by: Steffen Schneider <steffen.schneider@helmholtz-munich.de> Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu>
MMathisLab added a commit that referenced this pull request May 23, 2025
* first proposal for batching in tranform method * first running version of padding with batched inference * start tests * add pad_before_transform to fit function and add support for convolutional models in _transform * remove print statements * first passing test * add support for hybrid models * rewrite transform in sklearn API * baseline version of a torch.Datset * move batching logic outside solver * move functionality to base file in solver and separate in functions * add test_select_model for single session * add checks and test for _process_batch * add test_select_model for multisession * make self.num_sessions compatible with single session training * improve test_batched_transform_singlesession * make it work with small batches * make test with multisession work * change to torch padding * add argument to sklearn api * add torch padding to _transform * convert to torch if numpy array as inputs * add distinction between pad with data and pad with zeros and modify test accordingly * differentiate between data padding and zero padding * remove float16 * change argument position * clean test * clean test * Fix warning * Improve modularity remove duplicate code and todos * Add tests to solver * Remove unused import in solver/utils * Fix test plot * Add some coverage * Fix save/load * Remove duplicate configure_for in multi dataset * Make save/load cleaner * Fix codespell errors * Fix docs compilation errors * Fix formatting * Fix extra docs errors * Fix offset in docs * Remove attribute ref * Add review updates * apply ruff auto-fixes * Concatenate last batches for batched inference (#200) * Concatenate last to batches for batched inference * Add test case * Fix linting errors in tests (#188) * apply auto-fixes * Fix linting errors in tests/ * Fix version check * Fix `scikit-learn` reference in conda environment files (#195) * Add support for new __sklearn_tags__ (#205) * Add support for new __sklearn_tags__ * fix inheritance order * Add more tests * fix added test * Update workflows to actions/setup-python@v5, actions/cache@v4 (#212) * Fix deprecation warning force_all_finite -> ensure_all_finite for sklearn>=1.6 (#206) * Add tests to check legacy model loading (#214) * Add improved goodness of fit implementation (#190) * Started implementing improved goodness of fit implementation * add tests and improve implementation * Fix examples * Fix docstring error * Handle batch size = None for goodness of fit computation * adapt GoF implementation * Fix docstring tests * Update docstring for goodness_of_fit_score Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * add annotations to goodness_of_fit_history Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * fix typo Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * improve err message Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * make numerical test less conversative * Add tests for exception handling * fix tests --------- Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * Support numpy 2, upgrade tests to support torch 2.6 (#221) * Drop numpy constraint * Implement workaround for pytables * better error message * pin numpy only for python 3.9 * update dependencies * Upgrade torch version * Fix based on python version * Add support for torch.load with weights_only=True * Implement safe loading for torch models starting in torch 2.6 * Fix windows specs * fix docstring * Revert changes to 
loading logic * Release 0.5.0rc1 (#189) * Make bump_version script runnable on MacOS * Bump version to 0.5.0rc1 * fix minor formatting issues * remove commented code --------- Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu> * Fix pypi action (#222) * force packaging upgrade to 24.2 for twine * Bump version to 0.5.0rc2 * remove universal compatibility option * revert tag * adapt files to new wheel name due to py3 * Update base.py (#224) This is a lazy solution to #223 * Change max consistency value to 100 instead of 99 (#227) * Change text consistency max from 99 to 100 * Update cebra/integrations/matplotlib.py --------- Co-authored-by: Mackenzie Mathis <mackenzie.mathis@epfl.ch> Co-authored-by: Steffen Schneider <steffen@bethgelab.org> * Update assets.py --> force check for parent dir (#230) Update assets.py - mkdir was failing in 0.5.0rc1; attempt to fix * User docs minor edit (#229) * user note added to usage.rst - link added * Update usage.rst - more detailed note on the effect of temp. * Update usage.rst - add in temp to demo model - testout put thanks @stes * Update docs/source/usage.rst Co-authored-by: Steffen Schneider <stes@hey.com> * Update docs/source/usage.rst Co-authored-by: Steffen Schneider <stes@hey.com> * Update docs/source/usage.rst Co-authored-by: Steffen Schneider <stes@hey.com> --------- Co-authored-by: Steffen Schneider <stes@hey.com> * General Doc refresher (#232) * Update installation.rst - python 3.9+ * Update index.rst * Update figures.rst * Update index.rst -typo fix * Update usage.rst - update suggestion on data split * Update docs/source/usage.rst Co-authored-by: Steffen Schneider <stes@hey.com> * Update usage.rst - indent error fixed * Update usage.rst - changed infoNCE to new GoF * Update usage.rst - finx numpy() doctest * Update usage.rst - small typo fix (label) * Update usage.rst --------- Co-authored-by: Steffen Schneider <stes@hey.com> * render plotly in our docs, show code/doc version (#231) * Update layout.html (#233) * Update conf.py (#234) - adding link to new notebook icon * Refactoring setup.cfg (#228) * Home page landing update (#235) * website refresh * v0.5.0 (#238) * Upgrade docs build (#241) * Improve build setup for docs * update pydata theme options * Add README for docs folder * Fix demo notebook build * Finish build setup * update git workflow * add timeout to workflow * add timeout also to docs build * switch build back to sphinx for gh actions * attempt to fix build workflow * update to sphinx-build * fix build workflow * fix indent error * fix build system * revert demos to main * increase timeout to 30 * Allow indexing of the cebra docs (#242) * Allow indexing of the cebra docs * Fix docs workflow * Fix broken docs coverage workflows (#246) * Add xCEBRA implementation (AISTATS 2025) (#225) * Add multiobjective solver and regularized training (#783) * Add multiobjective solver and regularized training * Add example for multiobjective training * Add jacobian regularizer and SAM * update license headers * add api draft for multiobjective training * add all necessary modules to run the complete xcebra pipeline * add notebooks to reproduce xcebra pipeline * add first working notebook * add notebook with hybrid learning * add notebook with creation of synthetic data * add notebook with hybrid training * add plot with R2 for different parts of the embedding * add new API * update api wrapper with more checks and messages * add tests and notebook with new api * merge xcebra into attribution * separate xcebra dataset from cebra * 
some minor refactoring of cebra dataset * separate xcebra loader from cebra * remove xcebra distributions from cebra * minor refactoring with distributions * separate xcebra criterions from cebra * minor refactoring on criterion * separate xcebra models/criterions/layers from cebra * refactoring multiobjective * more refactoring... * separate xcebra solvers from cebra * more refactoring * move xcebra to its own package * move more files into xcebra package * more files and remove changes with the registry * remove unncessary import * add folder structure * move back distributions * add missing init * remove wrong init * make loader and dataset run with new imports * making it run! * make attribution run * Run pre-commit * move xcebra repo one level up * update gitignore and add __init__ from data * add init to distributions * add correct init for attribution pacakge * add correct init for model package * fix remaining imports * fix tests * add examples back to xcebra repo * update imports from graphs_xcebra * add setup.py to create a package * update imports of graph_xcebra * update notebooks * Formatting code for submission Co-authored-by: Rodrigo Gonzalez <gonlairo@gmail.com> * move test into xcebra * Add README * move distributions back to main package * clean up examples * adapt tests * Add LICENSE * add train/eval notebook again * add notebook with clean results * rm synthetic data * change name from xcebra to regcl * change names of modules and adapt imports * change name from graphs_xcebra to synthetic_data * Integrate into CEBRA * Fix remaining imports and make notebook runnable * Add dependencies, add version flag * Remove synthetic data files * reset dockerfile, move vmf * apply pre-commit * Update notice * add some docstrings * Apply license headers * add new scd notebook * add notebook with scd --------- Co-authored-by: Steffen Schneider <stes@hey.com> * Fix tests * bump version * update dockerfile * fix progress bar * remove outdated test * rename models * Apply fixes to pass ruff tests * Fix typos * Update license headers, fix additional ruff errors * remove unused comment * rename regcl in codebase * change regcl name in dockerfile * Improve attribution module * Fix imports name naming * add basic integration test * temp disable of binary check * Add legacy multiobjective model for backward compat * add synth import back in * Fix docstrings and type annot in cebra/models/jacobian_regularizer.py * add xcebra to tests * add missing cvxpy dep * fix docstrings * more docstrings to fix attr error * Improve build setup for docs * update pydata theme options * Add README for docs folder * Fix demo notebook build * Finish build setup * update git workflow * Move demo notebooks to CEBRA-demos repo See AdaptiveMotorControlLab/CEBRA-demos#28 * revert unneeded changes in solver * formatting in solver * further minimize solver diff * Revert unneeded updates to the solver * fix citation * fix docs build, missing refs * remove file dependency from xcebra int test * remove unneeded change in registry * update gitignore * update docs * exclude some assets * include binary file check again * add timeout to workflow * add timeout also to docs build * switch build back to sphinx for gh actions * pin sphinx version in setup.cfg * attempt workflow fix * attempt to fix build workflow * update to sphinx-build * fix build workflow * fix indent error * fix build system * revert demos to main * adapt workflow for testing * bump version to 0.6.0rc1 * format imports * docs writing * enable build on dev 
branch * fix some review comments * extend multiobjective docs * Set version to alpha * make tempdir platform independent * Remove ratinabox and ephysiopy as deps * Apply review comments * Update Makefile - setting coverage threshold to 80% to not delay good code being made public. In the near future this can be fixed and raised again to 90%. --------- Co-authored-by: Steffen Schneider <stes@hey.com> Co-authored-by: Steffen Schneider <steffen.schneider@helmholtz-munich.de> Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu> * start tests * remove print statements * first passing test * move functionality to base file in solver and separate in functions * add test_select_model for multisession * remove float16 * Improve modularity remove duplicate code and todos * Add tests to solver * Fix save/load * Fix extra docs errors * Add review updates * apply ruff auto-fixes * fix linting errors * Run isort, ruff, yapf * Fix gaussian mixture dataset import * Fix all tests but xcebra tests * Fix pytorch API usage example * Make xCEBRA compatible with the batched inference & padding in solver * Add some tests on transform() with xCEBRA * Add some docstrings and typings and clean unnecessary changes * Implement review comments * Fix sklearn test * Add name in NOTE Co-authored-by: Steffen Schneider <steffen@bethgelab.org> * Implement reviews on tests and typing * Fix import errors * Add select_model to aux solvers * Fix docs error * Add tests on the private functions in base solver * Update tests and duplicate code based on review --------- Co-authored-by: Rodrigo <gonlairo@gmail.com> Co-authored-by: Steffen Schneider <stes@hey.com> Co-authored-by: Mackenzie Mathis <mathis@rowland.harvard.edu> Co-authored-by: Steffen Schneider <steffen.schneider@helmholtz-munich.de> Co-authored-by: Ícaro <icarosadero@proton.me> Co-authored-by: Mackenzie Mathis <mackenzie.mathis@epfl.ch> Co-authored-by: Steffen Schneider <steffen@bethgelab.org> Co-authored-by: Rodrigo González Laiz <31796689+gonlairo@users.noreply.github.com>
stes pushed a commit that referenced this pull request Jun 5, 2025
* start tests * remove print statements * first passing test * move functionality to base file in solver and separate in functions * add test_select_model for multisession * remove float16 * Improve modularity remove duplicate code and todos * Add tests to solver * Fix save/load * Fix extra docs errors * Add review updates * apply ruff auto-fixes * fix linting errors * Run isort, ruff, yapf * Fix gaussian mixture dataset import * Fix all tests but xcebra tests * Fix pytorch API usage example * Make xCEBRA compatible with the batched inference & padding in solver * Add some tests on transform() with xCEBRA * Add some docstrings and typings and clean unnecessary changes * Implement review comments * Fix sklearn test * Initial pass at integrating unifiedCEBRA * Add name in NOTE * Implement reviews on tests and typing * Fix import errors * Add select_model to aux solvers * Fix tests * Add mask tests * Fix docs error * Remove masking init() * Remove shuffled neurons in unified dataset * Remove extra datasets * Add tests on the private functions in base solver * Update tests and duplicate code based on review * Fix quantized_embedding_norm undefined when `normalize=False` (#249) * Fix tests * Adapt unified code to get_model method * Update mask.py add headers to new files * Update masking.py - header * Update test_data_masking.py - header * Implement review comments and fix typos * Fix docs errors * Remove np.int typing error * Fix docstring warning * Fix indentation docstrings * Implement review comments * Fix circular import and abstract method * Add maskedmixin to __all__ * Implement extra review comments * Change masking kwargs as tuple and not dict in sklearn impl * Add integrations/decoders.py * Fix typo * minor simplification in solver --------- Note, some comments in this PR overlap with #168 and #225 which were developed in parallel.