- Notifications
You must be signed in to change notification settings - Fork 3.6k
fix: remove extra parameter in accelerator registry decorator #20975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: remove extra parameter in accelerator registry decorator #20975
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@ ## master #20975 +/- ## ======================================= Coverage 87% 87% ======================================= Files 268 268 Lines 23324 23324 ======================================= + Hits 20316 20317 +1 + Misses 3008 3007 -1 |
| @Borda @YgLK here are a new version of https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/tests_fabric/accelerators/test_registry.py that includes relevant testing. In particular the # Copyright The Lightning AI team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any import pytest import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator from lightning.fabric.accelerators.registry import _AcceleratorRegistry from lightning.fabric.utilities.exceptions import MisconfigurationException class TestAccelerator(Accelerator): """Helper accelerator class for testing.""" def __init__(self, param1=None, param2=None): self.param1 = param1 self.param2 = param2 super().__init__() def setup_device(self, device: torch.device) -> None: pass def teardown(self) -> None: pass @staticmethod def parse_devices(devices): return devices @staticmethod def get_parallel_devices(devices): return ["foo"] * devices @staticmethod def auto_device_count(): return 3 @staticmethod def is_available(): return True def test_accelerator_registry_with_new_accelerator(): accelerator_name = "custom_accelerator" accelerator_description = "Custom Accelerator" class CustomAccelerator(Accelerator): def __init__(self, param1, param2): self.param1 = param1 self.param2 = param2 super().__init__() def setup_device(self, device: torch.device) -> None: pass def get_device_stats(self, device: torch.device) -> dict[str, Any]: pass def teardown(self) -> None: pass @staticmethod def parse_devices(devices): return devices @staticmethod def get_parallel_devices(devices): return ["foo"] * devices @staticmethod def auto_device_count(): return 3 @staticmethod def is_available(): return True ACCELERATOR_REGISTRY.register( accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123 ) assert accelerator_name in ACCELERATOR_REGISTRY assert ACCELERATOR_REGISTRY[accelerator_name]["description"] == accelerator_description assert ACCELERATOR_REGISTRY[accelerator_name]["init_params"] == {"param1": "abc", "param2": 123} assert ACCELERATOR_REGISTRY[accelerator_name]["accelerator_name"] == accelerator_name assert isinstance(ACCELERATOR_REGISTRY.get(accelerator_name), CustomAccelerator) ACCELERATOR_REGISTRY.remove(accelerator_name) assert accelerator_name not in ACCELERATOR_REGISTRY def test_available_accelerators_in_registry(): assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"} def test_registry_as_decorator(): """Test that the registry can be used as a decorator.""" test_registry = _AcceleratorRegistry() # Test decorator usage @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42) class DecoratorAccelerator(TestAccelerator): pass # Verify registration worked assert "test_decorator" in test_registry assert test_registry["test_decorator"]["description"] == "Test decorator accelerator" assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42} assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator" # Test that we can instantiate the accelerator instance = test_registry.get("test_decorator") assert isinstance(instance, DecoratorAccelerator) assert instance.param1 == "value1" assert instance.param2 == 42 def test_registry_as_static_method(): """Test that the registry can be used as a static method call.""" test_registry = _AcceleratorRegistry() class StaticMethodAccelerator(TestAccelerator): pass # Test static method usage result = test_registry.register( "test_static", StaticMethodAccelerator, description="Test static method accelerator", param1="static_value", param2=100 ) # Verify registration worked assert "test_static" in test_registry assert test_registry["test_static"]["description"] == "Test static method accelerator" assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100} assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator assert test_registry["test_static"]["accelerator_name"] == "test_static" assert result == StaticMethodAccelerator # Should return the accelerator class # Test that we can instantiate the accelerator instance = test_registry.get("test_static") assert isinstance(instance, StaticMethodAccelerator) assert instance.param1 == "static_value" assert instance.param2 == 100 def test_registry_without_parameters(): """Test registration without init parameters.""" test_registry = _AcceleratorRegistry() class SimpleAccelerator(TestAccelerator): def __init__(self): super().__init__() test_registry.register("simple", SimpleAccelerator, description="Simple accelerator") assert "simple" in test_registry assert test_registry["simple"]["description"] == "Simple accelerator" assert test_registry["simple"]["init_params"] == {} instance = test_registry.get("simple") assert isinstance(instance, SimpleAccelerator) |
| @SkafteNicki thanks for help - it looks great. I updated the tests 79c5c42 |
* fix: remove extra parameter in accelerator registry decorator * tests: add registry decorator tests * chlog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deependu <deependujha21@gmail.com> Co-authored-by: Jirka B <j.borovec+github@gmail.com> (cherry picked from commit 60883a9)
* fix: remove extra parameter in accelerator registry decorator * tests: add registry decorator tests * chlog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deependu <deependujha21@gmail.com> Co-authored-by: Jirka B <j.borovec+github@gmail.com> (cherry picked from commit 60883a9)
What does this PR do?
Fixes #20973
📚 Documentation preview 📚: https://pytorch-lightning--20975.org.readthedocs.build/en/20975/