Fix global_device_count(), local_device_count() for single process on CUDA #6022
Conversation
Looks great, thanks!
```cpp
std::optional<std::set<int>> allowed_devices;
if (global_world_size > 1) {
  allowed_devices =
      std::make_optional<std::set<int>>(std::set{local_process_rank});
```
nit: do you still need to make_optional here, or can you directly assign? e.g. allowed_devices = std::set{local_process_rank}
Thanks for the review!
A few tests are failing.

One failure is in a test doing an all-reduce.

I think we can disable this test because of the class it is testing.
```diff
-  auto allowed_devices =
-      std::make_optional<std::set<int>>(std::set{local_process_rank});
+  std::optional<std::set<int>> allowed_devices;
+  if (global_world_size > 1) {
```
Synced offline - this is great for single-host single-process development, but in cases where there is a single process per host, this would break in a multihost environment. Outside of SPMD, I'm not aware of a use case for a multihost environment with a single process per host (cc @JackCaoG)
Since we don't officially support SPMD on GPU at the moment, this looks fine to me for now. Once we decide on the right entrypoint for SPMD, we'll need to revisit this.
Force-pushed from 54bad35 to d0faac5.

hi @JackCaoG, this PR fixes global_device_count() and local_device_count() for single process on CUDA.
@vanbasten23 can you rebase and rerun the CI?
Force-pushed from 3292787 to fe03a80.
LGTM once you fix the formatting. Thanks!
test/test_operations.py (outdated):

```python
@unittest.skipIf(xr.device_type() == 'CUDA',
                 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.')
```
Not for this PR, but IMO we should just delete these tests. Do we support DataParallel anymore @JackCaoG?
Force-pushed from fe03a80 to b98ef93.
Force-pushed from cbeabb8 to 6c2b64f.

```cpp
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                            /*key_prefix=*/"gpu:");
std::optional<std::set<int>> allowed_devices;
bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
```
Conditioning on SPMD mode here could cause issues using xr.use_spmd() after the runtime has been initialized.
Is it correct to say that allowed_devices is only needed in the MP case? If so, can we invert the logic to check for MP using one of the env vars instead of checking for SPMD mode?
Is there ever a reason to call xr.use_spmd() after the runtime is initialized? In any case, I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices (which should be compatible with SPMD)
AFAIK there's not a strong use case at the moment, but for example our unit tests will check xr.global_runtime_device_count() before calling xr.use_spmd(). Keeping the runtime independent of SPMD mode was something we wanted to maintain, cc @yeounoh
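A minimal sketch of that ordering constraint, for illustration only: the test class and assertions are hypothetical, and the only point is that the device-count query (which initializes the runtime) runs before xr.use_spmd().

```python
import unittest

import torch_xla.runtime as xr


class SpmdSetupTest(unittest.TestCase):  # hypothetical test, not from the repo

  def test_count_devices_then_enable_spmd(self):
    # Querying the device count initializes the runtime client.
    n_devices = xr.global_runtime_device_count()
    self.assertGreaterEqual(n_devices, 1)

    # Enabling SPMD afterwards should still be valid, which is why runtime
    # initialization should not depend on SPMD mode.
    xr.use_spmd()


if __name__ == '__main__':
  unittest.main()
```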
That's a good point!

> Conditioning on SPMD mode here could cause issues using xr.use_spmd() after the runtime has been initialized.

This seems to be a downside of using xr.use_spmd() as opposed to the env flag XLA_USE_SPMD=1. The latter is less flexible but less error-prone: it guarantees we use SPMD mode from the beginning. The former may also affect other SPMD special cases, e.g. the user runs some PyTorch ops, then calls xr.use_spmd(), then continues to do something else.

> can we invert the logic to check for MP using one of the env vars instead of checking for SPMD mode?

> I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices

I'm thinking about the case where the user has 2 GPU machines and wants to use 1 GPU device on each machine to do multi-host training. In that case (multi-host, single process per host), each process has access to all local devices, and I think the user can still do multi-host training.
> I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices

Perhaps we also need to check GLOBAL_WORLD_SIZE:

```
if LOCAL_WORLD_SIZE == 1:
  if GLOBAL_WORLD_SIZE > 1:
    # multi-host-single-process
    initialize coordinator service
  else:
    # single-host-single-process
    do nothing
else:
  # multi-process, for both single-host and multi-host
  initialize coordinator service
  allowed_devices = {current_device}
```
```cpp
std::unique_ptr<XlaCoordinator> SetKeyValueCallback(
    int global_process_rank, int global_world_size,
    std::unique_ptr<XlaCoordinator> coordinator,
```
Why do we need the coordinator as input here?
We need it to get the DistributedRuntimeClient and later create the kv_store below.
It looks like it's being recreated on L60 - should we just make this function return the new value?
done
It looks like a bunch of tests are failing with the same error. The error doesn't exist on the current master branch (01/19, after the openxla pin update). Also, the error doesn't exist on the feature branch before the pin update: #6346. Probably something happened during the pin update.
Force-pushed from 960d609 to aee08df.
Looking good, thanks Xiongfei!
test/spmd/test_xla_sharding.py (outdated):

```python
# if self.n_devices>=4, mesh=(2, 2)
# if self.n_devices>=2, mesh=(2,1)
# if self.n_devices=1, mesh=(1,1)
```
Thanks for generalizing these tests!
Could we change these comments to e.g. # if self.n_devices==4, mesh=(2, 2)? Other device counts will have different meshes.
Yes! Done.
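A small sketch of the device-count-dependent meshes those comments describe; the helper is hypothetical and only mirrors the shapes listed above, assuming torch_xla is installed.

```python
import torch_xla.runtime as xr


def pick_mesh_shape(n_devices: int):  # hypothetical helper, not from the test
  # Mirror the comments above: (2, 2) with at least 4 devices,
  # (2, 1) with at least 2, and (1, 1) for a single device.
  if n_devices >= 4:
    return (2, 2)
  if n_devices >= 2:
    return (2, 1)
  return (1, 1)


if __name__ == '__main__':
  print(pick_mesh_shape(xr.global_runtime_device_count()))
```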
```cpp
if (local_world_size == 1) {
  if (global_world_size > 1) {
    coordinator = SetGpuClientKVCallBack(global_process_rank,
                                         global_world_size, kv_store);
  }
} else {
  allowed_devices = std::set{local_process_rank};
  coordinator = SetGpuClientKVCallBack(global_process_rank,
                                       global_world_size, kv_store);
}

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
if (global_world_size > 1) {
  // Use the distributed key-value store from DistributedRuntimeClient.
  coordinator = std::make_unique<XlaCoordinator>(
      global_process_rank, global_world_size, master_addr, port);
  std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
      coordinator->GetClient();
  kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                              /*key_prefix=*/"gpu:");
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
           << global_process_rank << ", num_nodes=" << global_world_size;
```
I think we could simplify the logic here some. We want to restrict allowed_devices if local_world_size > 1 and create the coordinator if global_world_size > 1. I'm assuming local_world_size > 1 => global_world_size > 1, would this be equivalent?
```cpp
if (local_world_size > 1) {
  allowed_devices = std::set{local_process_rank};
}
if (global_world_size > 1) {
  // We can keep the old initialization block here and remove `SetGpuClientKVCallBack`
}
```

Thanks for the review!
```diff
   TF_VLOG(INFO) << "OpSharding (ShardingType: " << sharding_type << "):\n"
-                << sharding.DebugString();
+                << sharding.DebugString()
+                << ", sharding.type()=" << sharding.type();
```
DebugString should include the type?
Actually, it doesn't. The DebugString is empty in that case.
This PR fixes global_device_count() and local_device_count() for single process on CUDA so that they return all GPU devices on the current host.
Before this PR, both APIs always returned 1, as reported in this issue.
global_runtime_device_count() is not changed here since it appears to be used only in the SPMD case and has been fixed in another PR.
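A quick illustration of the intended behavior, assuming a single process on one host with 4 GPUs; the script name, launch command, and device count are examples, not part of this PR.

```python
# Example launch (single process, one host with 4 GPUs):
#   PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python check_device_count.py
import torch_xla.runtime as xr

# With this fix, both counts reflect all GPUs visible to the single process;
# previously both always returned 1 in this setup.
print(xr.global_device_count())  # expected: 4 in the example setup
print(xr.local_device_count())   # expected: 4 in the example setup
```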
Test:
Note: here is the behavior of torch.cuda.device_count() in the multi-host case