Skip to content

Conversation

@YgLK
Copy link
Contributor

@YgLK YgLK commented Jul 11, 2025

What does this PR do?

Fixes #20973


📚 Documentation preview 📚: https://pytorch-lightning--20975.org.readthedocs.build/en/20975/

@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Jul 11, 2025
@codecov
Copy link

codecov bot commented Jul 12, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 87%. Comparing base (a777069) to head (8c73bcb).
⚠️ Report is 3 commits behind head on master.

Additional details and impacted files
@@ Coverage Diff @@ ## master #20975 +/- ## ======================================= Coverage 87% 87% ======================================= Files 268 268 Lines 23324 23324 ======================================= + Hits 20316 20317 +1  + Misses 3008 3007 -1 
@SkafteNicki
Copy link
Collaborator

@Borda @YgLK here are a new version of https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/tests_fabric/accelerators/test_registry.py that includes relevant testing. In particular the test_registry_as_decorator fails on master but passes on this PR:

# Copyright The Lightning AI team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any import pytest import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator from lightning.fabric.accelerators.registry import _AcceleratorRegistry from lightning.fabric.utilities.exceptions import MisconfigurationException class TestAccelerator(Accelerator): """Helper accelerator class for testing.""" def __init__(self, param1=None, param2=None): self.param1 = param1 self.param2 = param2 super().__init__() def setup_device(self, device: torch.device) -> None: pass def teardown(self) -> None: pass @staticmethod def parse_devices(devices): return devices @staticmethod def get_parallel_devices(devices): return ["foo"] * devices @staticmethod def auto_device_count(): return 3 @staticmethod def is_available(): return True def test_accelerator_registry_with_new_accelerator(): accelerator_name = "custom_accelerator" accelerator_description = "Custom Accelerator" class CustomAccelerator(Accelerator): def __init__(self, param1, param2): self.param1 = param1 self.param2 = param2 super().__init__() def setup_device(self, device: torch.device) -> None: pass def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: pass @staticmethod def parse_devices(devices): return devices @staticmethod def get_parallel_devices(devices): return ["foo"] * devices @staticmethod def auto_device_count(): return 3 @staticmethod def is_available(): return True ACCELERATOR_REGISTRY.register( accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123 ) assert accelerator_name in ACCELERATOR_REGISTRY assert ACCELERATOR_REGISTRY[accelerator_name]["description"] == accelerator_description assert ACCELERATOR_REGISTRY[accelerator_name]["init_params"] == {"param1": "abc", "param2": 123} assert ACCELERATOR_REGISTRY[accelerator_name]["accelerator_name"] == accelerator_name assert isinstance(ACCELERATOR_REGISTRY.get(accelerator_name), CustomAccelerator) ACCELERATOR_REGISTRY.remove(accelerator_name) assert accelerator_name not in ACCELERATOR_REGISTRY def test_available_accelerators_in_registry(): assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"} def test_registry_as_decorator(): """Test that the registry can be used as a decorator.""" test_registry = _AcceleratorRegistry() # Test decorator usage @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42) class DecoratorAccelerator(TestAccelerator): pass # Verify registration worked assert "test_decorator" in test_registry assert test_registry["test_decorator"]["description"] == "Test decorator accelerator" assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42} assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator" # Test that we can instantiate the accelerator instance = test_registry.get("test_decorator") assert isinstance(instance, DecoratorAccelerator) assert instance.param1 == "value1" assert instance.param2 == 42 def test_registry_as_static_method(): """Test that the registry can be used as a static method call.""" test_registry = _AcceleratorRegistry() class StaticMethodAccelerator(TestAccelerator): pass # Test static method usage result = test_registry.register( "test_static", StaticMethodAccelerator, description="Test static method accelerator", param1="static_value", param2=100 ) # Verify registration worked assert "test_static" in test_registry assert test_registry["test_static"]["description"] == "Test static method accelerator" assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100} assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator assert test_registry["test_static"]["accelerator_name"] == "test_static" assert result == StaticMethodAccelerator # Should return the accelerator class # Test that we can instantiate the accelerator instance = test_registry.get("test_static") assert isinstance(instance, StaticMethodAccelerator) assert instance.param1 == "static_value" assert instance.param2 == 100 def test_registry_without_parameters(): """Test registration without init parameters.""" test_registry = _AcceleratorRegistry() class SimpleAccelerator(TestAccelerator): def __init__(self): super().__init__() test_registry.register("simple", SimpleAccelerator, description="Simple accelerator") assert "simple" in test_registry assert test_registry["simple"]["description"] == "Simple accelerator" assert test_registry["simple"]["init_params"] == {} instance = test_registry.get("simple") assert isinstance(instance, SimpleAccelerator)
@YgLK
Copy link
Contributor Author

YgLK commented Aug 9, 2025

@SkafteNicki thanks for help - it looks great. I updated the tests 79c5c42

@Borda Borda merged commit 60883a9 into Lightning-AI:master Aug 12, 2025
115 of 122 checks passed
Borda pushed a commit that referenced this pull request Aug 13, 2025
* fix: remove extra parameter in accelerator registry decorator * tests: add registry decorator tests * chlog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deependu <deependujha21@gmail.com> Co-authored-by: Jirka B <j.borovec+github@gmail.com> (cherry picked from commit 60883a9)
Borda pushed a commit that referenced this pull request Aug 13, 2025
* fix: remove extra parameter in accelerator registry decorator * tests: add registry decorator tests * chlog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deependu <deependujha21@gmail.com> Co-authored-by: Jirka B <j.borovec+github@gmail.com> (cherry picked from commit 60883a9)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

fabric lightning.fabric.Fabric

5 participants