Fix global_device_count(), local_device_count() for single process on CUDA #6022
Merged
33 commits (all by vanbasten23):

36c9d3d fixed it
e1164be added the test
61d4cda remove prints
a025c23 fix up
cb37524 replace gpu with cuda
93f1dc5 fix up
b6c58ca trigger another ci
d724f64 fix the test_runtime.py test so it runs for GPU as well.
1dd4261 remove pdb
b39fbb2 skip a test temporarily.
9a7b423 Fix the test XLAShardingTest.CreateTensorsData
b90b342 skip another multithreading test for gpu
ee243d0 remove prints
3fa038f remove more prints
9823269 remove global_runtime_device_count test case
b10570b fix linter
1832477 fix the broken tests
d590e30 fix a build issue
4d74168 fix linter
b154538 fix build after pin update and a failing spmd test
6bfdf19 fix linter
c5e717e fix comments
0291311 Incorporate the fix cl/601517680
3c774fe add comment to the patch and fix linter
862f7da add a single-process gpu test.
5d64007 add print in spmd tests. Local test works but fails in the CI
d59457f fix test and linter
e8e1697 clean up prints
db3d0c8 clean up more prints
10e459b fix linter
aee08df fix BasicShardingTest.test_2d_tensor_3d_mesh on cpu
b3991a7 fix comments
953d9b4 remove unwanted function
New file (openXLA patch):

@@ -0,0 +1,36 @@
// This patch is for https://github.com/openxla/xla/commit/ec0177de1748b4ebb0ecbd6f26043fdb1eb47d24.
// It can be removed in the next openXLA pin update after 01/26/2024.
diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
index 0f1818be2..c181f3025 100644
--- a/xla/service/gpu/gpu_executable.cc
+++ b/xla/service/gpu/gpu_executable.cc
@@ -382,9 +382,13 @@ absl::Status ExecuteThunks(const std::string& module_name,
     }
   }

-  // Maybe join a round of rendezvous after thunk initialization.
-  TF_RETURN_IF_ERROR(
-      MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
+  // Maybe join a round of rendezvous after thunk initialization. We do this
+  // only in presence of collective cliques which means that we have collective
+  // operations in the XLA operations that tend to cause deadlocks.
+  if (!collective_cliques.empty()) {
+    TF_RETURN_IF_ERROR(
+        MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
+  }

   // Prepare parameters for thunks execution.
   Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h
index 51a566b8f..94bab421f 100644
--- a/xla/service/gpu/thunk.h
+++ b/xla/service/gpu/thunk.h
@@ -175,6 +175,8 @@ class Thunk {
   absl::StatusOr<NcclComm::Lock> GetComm(const NcclCliqueKey& clique_key,
                                          int32_t rank) const;

+  bool empty() const { return cliques_map_.empty(); }
+
  private:
   CliquesMap cliques_map_;
 };
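The guard the patch adds can be illustrated with a small Python sketch. This is not the real XLA API; `threading.Barrier` merely stands in for XLA's cross-participant rendezvous, and the function name mirrors the patched `ExecuteThunks` logic: a participant joins the rendezvous only when collective cliques are present, so a single process running a collective-free program returns immediately instead of blocking forever.

```python
import threading


def maybe_rendezvous_after_initialization(barrier: threading.Barrier,
                                          collective_cliques: dict) -> bool:
    """Join the rendezvous only when collective cliques are present.

    With no collectives there is nothing to synchronize across
    participants, and waiting would hang a lone process.
    Returns True if the rendezvous was joined.
    """
    if not collective_cliques:
        return False  # no collectives: skip the barrier entirely
    barrier.wait()
    return True


# A collective-free single process returns at once instead of deadlocking.
skipped = maybe_rendezvous_after_initialization(threading.Barrier(2), {})

# With collectives present, both participants meet at the barrier.
barrier = threading.Barrier(2)
results = []
t = threading.Thread(
    target=lambda: results.append(
        maybe_rendezvous_after_initialization(barrier, {"clique0": object()})))
t.start()
maybe_rendezvous_after_initialization(barrier, {"clique0": object()})
t.join()
```

The design point is the same as in the patch: the rendezvous exists to prevent collective-operation deadlocks, so it is only needed (and only safe to require) when collective cliques exist.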
New file (single-process GPU test):

@@ -0,0 +1,49 @@
import concurrent.futures
import itertools
import os
import queue
import requests
import unittest
import subprocess

import numpy as np
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import runtime as xr
from torch_xla._internal import pjrt
from absl.testing import absltest, parameterized


@unittest.skipIf(xr.device_type() != "CUDA",
                 f"GPU tests should only run on GPU devices.")
class TestExperimentalSingleProcPjrtGpu(parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    command = 'nvidia-smi --list-gpus | wc -l'
    result = subprocess.run(
        command,
        capture_output=True,
        shell=True,
        check=True,
        text=True,
    )
    cls.num_cuda_devices = int(result.stdout)

  def test_num_local_devices(self):
    self.assertLen(xm.get_xla_supported_devices(),
                   xr.addressable_device_count())
    self.assertEqual(self.num_cuda_devices, xr.addressable_device_count())

  def test_num_global_devices(self):
    self.assertLen(torch_xla._XLAC._xla_get_all_devices(),
                   xr.global_device_count())
    self.assertEqual(self.num_cuda_devices, xr.global_device_count())


if __name__ == '__main__':
  absltest.main()
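The test's `setUpClass` counts physical GPUs by shelling out to `nvidia-smi --list-gpus | wc -l`, then asserts that both the addressable and the global device counts match it in single-process mode. A pure-Python equivalent of that counting step is sketched below; the sample listing is made up for illustration, since `nvidia-smi` output depends on the machine.

```python
def count_gpus(listing: str) -> int:
    # Equivalent of `nvidia-smi --list-gpus | wc -l`:
    # each GPU is reported on its own non-empty line.
    return sum(1 for line in listing.splitlines() if line.strip())


# Hypothetical two-GPU machine's `nvidia-smi --list-gpus` output.
sample = ("GPU 0: NVIDIA A100-SXM4-40GB (UUID: GPU-aaaa)\n"
          "GPU 1: NVIDIA A100-SXM4-40GB (UUID: GPU-bbbb)\n")
num_cuda_devices = count_gpus(sample)
```

On such a machine the test would then expect `xr.addressable_device_count()` and `xr.global_device_count()` to both return 2, which is exactly what this PR fixes for the single-process CUDA case.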
Can you add a comment with the corresponding git commit hash in openXLA?
done