
Conversation

@vanbasten23 (Collaborator) commented Dec 4, 2023

This PR fixes

  • global_device_count() so that it returns all GPU devices across all processes/hosts. The fix also covers the multi-host case.
  • local_device_count() for a single process on CUDA so that it returns all GPU devices on the current host.

Before this PR, both APIs always returned 1, as reported in this issue.
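For example (assuming one process per host), a single host with 4 GPUs would now report 4 from both APIs, while a 2-host setup with 4 GPUs per host would be expected to report global_device_count() == 8 and local_device_count() == 4.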

global_runtime_device_count() is not changed here, since it appears to be used only in the SPMD case and has already been fixed in another PR.

Test:

  • PJRT_DEVICE=CUDA python pytorch/xla/test/pjrt/test_runtime.py
  • PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node 2 pytorch/xla/test/pjrt/test_torchrun.py

Note: here is the behavior of torch.cuda.device_count() in the multi-host case.

@jonb377 (Collaborator) left a comment:

Looks great, thanks!

std::optional<std::set<int>> allowed_devices;
if (global_world_size > 1) {
  allowed_devices =
      std::make_optional<std::set<int>>(std::set{local_process_rank});
Collaborator:

nit: do you still need to make_optional here, or can you directly assign? e.g. allowed_devices = std::set{local_process_rank}
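For context, a minimal standalone C++ sketch (not from this PR) showing that std::optional accepts direct assignment from the value type, so std::make_optional is not strictly needed:

#include <optional>
#include <set>

int main() {
  int local_process_rank = 0;
  std::optional<std::set<int>> allowed_devices;
  // Direct assignment engages std::optional's converting assignment operator;
  // CTAD deduces std::set<int> from the braced initializer.
  allowed_devices = std::set{local_process_rank};
  return allowed_devices->count(local_process_rank) == 1 ? 0 : 1;
}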

@vanbasten23 (Collaborator, Author): Thanks for the review!

@vanbasten23 vanbasten23 requested a review from yeounoh December 5, 2023 00:19
@vanbasten23 (Collaborator, Author) commented Dec 6, 2023

A few tests are failing:

  1. //test/cpp:test_replication https://btx-internal.corp.google.com/invocations/a14fc207-5237-47cb-81c7-fbbba96b800a

     It fails at torch_xla::runtime::GetComputationClient()->Compile(std::move(instances)); due to a timeout.

     The test does an all-reduce (xla::CrossReplicaSum) on a single process with 4 devices; should we disable this test on GPU?

  2. test (main.TestParallelTensorMNIST)
     No link. Fails due to E external/xla/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:

     I think we can disable this test because it tests the class DataParallel, which enables execution of a model network in replicated mode using threads. But on GPU we use one process per device instead of threads.

  3. //test/cpp:test_xla_sharding https://source.cloud.google.com/results/invocations/fd070252-3aa9-4196-8141-a750c75a1526
     XLAShardingTest.CreateTensorsData
auto allowed_devices =
    std::make_optional<std::set<int>>(std::set{local_process_rank});
std::optional<std::set<int>> allowed_devices;
if (global_world_size > 1) {
Collaborator:

Synced offline - this is great for single-host single-process development, but in cases where there is a single process per host, this would break in a multihost environment. Outside of SPMD, I'm not aware of a use case for a multihost environment with a single process per host (cc @JackCaoG)

Since we don't officially support SPMD on GPU at the moment, this looks fine to me for now. Once we decide on the right entrypoint for SPMD, we'll need to revisit this.

@vanbasten23 vanbasten23 force-pushed the fix_global_runtime_device_count branch from 54bad35 to d0faac5 Compare January 4, 2024 19:27
@vanbasten23 vanbasten23 changed the title from "Fix xr.global_runtime_device_count for single process on CUDA" to "Fix global_device_count(), local_device_count for single process on CUDA" Jan 9, 2024
@vanbasten23 (Collaborator, Author) commented Jan 9, 2024

Hi @JackCaoG, this PR fixes local_device_count() and global_device_count() for the single-process case so that they return the total number of devices on the host. Currently the test test/test_operations.py TestWaitDeviceOps.test_wait_device_ops fails with OOM (I summarized it here). I wonder if you have some pointers on the test failure.

@JackCaoG (Collaborator):

@vanbasten23 can you rebase and rerun the CI?

@will-cromar (Collaborator) left a comment:

LGTM once you fix the formatting. Thanks!



@unittest.skipIf(xr.device_type() == 'CUDA',
                 'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.')
Collaborator:

Not for this PR, but IMO we should just delete these tests. Do we support DataParallel anymore @JackCaoG?

@vanbasten23 vanbasten23 force-pushed the fix_global_runtime_device_count branch from fe03a80 to b98ef93 Compare January 16, 2024 20:14
@vanbasten23 (this comment was marked as outdated)

@vanbasten23 vanbasten23 changed the title from "Fix global_device_count(), local_device_count for single process on CUDA" to "Fix global_device_count(), local_device_count() for single process on CUDA" Jan 17, 2024
@vanbasten23 vanbasten23 force-pushed the fix_global_runtime_device_count branch from cbeabb8 to 6c2b64f Compare January 19, 2024 22:08
kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                            /*key_prefix=*/"gpu:");
std::optional<std::set<int>> allowed_devices;
bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false);
Collaborator:

Conditioning on SPMD mode here could cause issues using xr.use_spmd() after the runtime has been initialized.

Is it correct to say that allowed_devices is only needed in the MP case? If so, can we invert the logic to check for MP using one of the env vars instead of checking for SPMD mode?

Collaborator:

Is there ever a reason to call xr.use_spmd() after the runtime is initialized? In any case, I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices (which should be compatible with SPMD)

Collaborator:

AFAIK there's not a strong use case at the moment, but for example our unit tests will check xr.global_runtime_device_count() before calling xr.use_spmd(). Keeping the runtime independent of SPMD mode was something we wanted to maintain, cc @yeounoh

@vanbasten23 (Collaborator, Author) Jan 20, 2024:

That's a good point!

Conditioning on SPMD mode here could cause issues using xr.use_spmd() after the runtime has been initialized.

This seems to be a downside of using xr.use_spmd() as opposed to the env flag XLA_USE_SPMD=1. The latter is less flexible but less error-prone: it guarantees we use SPMD mode from the beginning. With the former, other SPMD special cases may also be affected: a user runs some PyTorch ops, then calls xr.use_spmd(), then continues to do something else.

can we invert the logic to check for MP using one of the env vars instead of checking for SPMD mode?
I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices

I'm thinking about the case where the user has 2 GPU machines and wants to use 1 GPU device on each machine to do multi-host training. In that case (multi-host, single process per host), each process has access to all devices, and I guess the user can still do multi-host training.

@vanbasten23 (Collaborator, Author) Jan 20, 2024:

I think we can also assume that if LOCAL_WORLD_SIZE=1, then we can use all of the devices

Perhaps we also need to check GLOBAL_WORLD_SIZE:

if LOCAL_WORLD_SIZE == 1:
    if GLOBAL_WORLD_SIZE > 1:
        # multi-host, single process per host
        initialize coordinator service
    else:
        # single-host, single-process
        do nothing
else:
    # multi-process, covering both single-host and multi-host
    initialize coordinator service
    allowed_devices = {current_device}
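For illustration, a rough standalone C++ sketch of that branching; this is not the PR's code, and the env var names (LOCAL_WORLD_SIZE, WORLD_SIZE, LOCAL_RANK) and the plain getenv-based helper are assumptions made for the sketch:

#include <cstdlib>
#include <optional>
#include <set>
#include <string>

namespace {

// Read an integer environment variable, falling back to a default.
int GetEnvInt(const char* name, int defval) {
  const char* value = std::getenv(name);
  return value != nullptr ? std::stoi(value) : defval;
}

}  // namespace

void ConfigureGpuClient() {
  int local_world_size = GetEnvInt("LOCAL_WORLD_SIZE", 1);
  int global_world_size = GetEnvInt("WORLD_SIZE", 1);
  int local_process_rank = GetEnvInt("LOCAL_RANK", 0);

  std::optional<std::set<int>> allowed_devices;
  bool need_coordinator = false;
  if (local_world_size == 1) {
    // Single process on this host: it may use every local device.
    // A coordinator is still required when there are multiple hosts.
    need_coordinator = global_world_size > 1;
  } else {
    // Multiple processes per host: restrict each process to one device
    // and always set up the coordinator.
    allowed_devices = std::set<int>{local_process_rank};
    need_coordinator = true;
  }
  // ... create the XlaCoordinator / key-value store when need_coordinator
  // is true, then pass allowed_devices to the GPU client options.
}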

std::unique_ptr<XlaCoordinator> SetKeyValueCallback(
    int global_process_rank, int global_world_size,
    std::unique_ptr<XlaCoordinator> coordinator,
Collaborator:

Why do we need the coordinator as input here?

Collaborator (Author):

We need it to get the DistributedRuntimeClient and later create the kv_store below.

Collaborator:

It looks like it's being recreated on L60 - should we just make this function return the new value?
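For reference, a hedged sketch of the suggested refactor: a hypothetical helper (named CreateCoordinatorAndKvStore here, with master_addr and port assumed to be passed in) creates and returns the coordinator instead of taking one as input; the XlaCoordinator and key-value store calls mirror the initialization block quoted later in this thread, and the signature in the merged code may differ:

std::unique_ptr<XlaCoordinator> CreateCoordinatorAndKvStore(
    int global_process_rank, int global_world_size,
    const std::string& master_addr, const std::string& port,
    std::shared_ptr<xla::KeyValueStoreInterface>* kv_store) {
  // Create the coordinator here and hand it back to the caller instead of
  // taking a coordinator as an input parameter.
  auto coordinator = std::make_unique<XlaCoordinator>(
      global_process_rank, global_world_size, master_addr, port);
  std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
      coordinator->GetClient();
  *kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                               /*key_prefix=*/"gpu:");
  return coordinator;
}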

Collaborator (Author):

Done.

will-cromar added a commit that referenced this pull request Jan 19, 2024
@vanbasten23 (Collaborator, Author) commented Jan 22, 2024

It looks like a bunch of tests are failing with error

+ PJRT_DEVICE=CUDA
+ run_coverage /tmp/pytorch/xla/test/test_autocast.py
+ '[' 0 '!=' 0 ']'
+ python3 /tmp/pytorch/xla/test/test_autocast.py
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705733143.143549 996 se_gpu_pjrt_client.cc:751] Using BFC allocator.
/opt/conda/lib/python3.8/site-packages/torch_xla-2.2.0+gitf16e0c7-py3.8-linux-x86_64.egg/torch_xla/core/xla_model.py:101: UserWarning: `devkind` argument is deprecated and will be removed in a future release.
  warnings.warn("`devkind` argument is deprecated and will be removed in a "
test_autocast_banned (__main__.TestAutocastCuda) ... ok
test_autocast_linalg_fp16 (__main__.TestAutocastCuda) ... 2024-01-20 06:45:53.672768: E external/xla/xla/service/rendezvous.cc:33] This thread has been waiting for 10 seconds and may be stuck:
2024-01-20 06:46:23.673040: E external/xla/xla/service/rendezvous.cc:43] Termination timeout of 30 seconds exceeded. Exiting to ensure a consistent program state.
Error: The operation was canceled.

One of the examples is PJRT_DEVICE=CUDA python pytorch/xla/test/test_autocast.py

The error doesn't exist on the current master branch (01/19, after the openxla pin update). Also, the error doesn't exist on the feature branch before the pin update: #6346. Probably something happened during the pin update.

@vanbasten23 vanbasten23 force-pushed the fix_global_runtime_device_count branch from 960d609 to aee08df Compare February 2, 2024 14:58
@jonb377 (Collaborator) left a comment:

Looking good, thanks Xiongfei!

Comment on lines 211 to 213
# if self.n_devices>=4, mesh=(2, 2)
# if self.n_devices>=2, mesh=(2,1)
# if self.n_devices=1, mesh=(1,1)
Collaborator:

Thanks for generalizing these tests!

Could we change these comments to e.g. # if self.n_devices==4, mesh=(2, 2)? Other device counts will have different meshes.

Collaborator (Author):

Yes! Done.

Comment on lines 168 to 156
if (local_world_size == 1) {
  if (global_world_size > 1) {
    coordinator = SetGpuClientKVCallBack(global_process_rank,
                                         global_world_size, kv_store);
  }
} else {
  allowed_devices = std::set{local_process_rank};
  coordinator = SetGpuClientKVCallBack(global_process_rank,
                                       global_world_size, kv_store);
}

std::shared_ptr<xla::KeyValueStoreInterface> kv_store;
if (global_world_size > 1) {
  // Use the distributed key-value store from DistributedRuntimeClient.
  coordinator = std::make_unique<XlaCoordinator>(
      global_process_rank, global_world_size, master_addr, port);
  std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
      coordinator->GetClient();
  kv_store = xla::GetDistributedKeyValueStore(distributed_client,
                                              /*key_prefix=*/"gpu:");
}
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
           << global_process_rank << ", num_nodes=" << global_world_size;

Collaborator:

I think we could simplify the logic here somewhat. We want to restrict allowed_devices if local_world_size > 1 and create the coordinator if global_world_size > 1. I'm assuming local_world_size > 1 implies global_world_size > 1; would this be equivalent?

if (local_world_size > 1) {
  allowed_devices = std::set{local_process_rank};
}
if (global_world_size > 1) {
  // We can keep the old initialization block here and remove `SetGpuClientKVCallBack`
}
@vanbasten23 vanbasten23 merged commit 8fc8d57 into master Feb 3, 2024
@vanbasten23 (Collaborator, Author): Thanks for the review!

TF_VLOG(INFO) << "OpSharding (ShardingType: " << sharding_type << "):\n"
              << sharding.DebugString();
              << sharding.DebugString()
              << ", sharding.type()=" << sharding.type();
Contributor:

DebugString should include the type?

Collaborator (Author):

Actually, it doesn't. DebugString() is empty in that case.
