
Conversation

@alanwaketan (Collaborator) commented Apr 10, 2024

Summary:
This pull request adds support for the manual sharding type in SPMD via a new private API, _mark_manual_sharding. I don't expect users to need to call this function explicitly.

Besides adding support for the sharding annotation, we also need to define the behavior of the data shards. For data, the current behavior is to error out.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test__mark_manual_sharding
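For context, here is a minimal sketch of how the new API is exercised, pieced together from the unit-test snippets quoted later in this conversation. The import aliases (xm, xr, xs) and the xr.use_spmd() call are assumptions based on the usual SPMD test setup, not something spelled out in this PR description.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs  # assumed alias; matches the xs.* calls quoted below

xr.use_spmd()  # assumed: SPMD mode is enabled before any sharding annotation

x = torch.randn(3, 2)
xx = x.to(xm.xla_device())          # plain device data
xt = xs._mark_manual_sharding(xx)   # annotate the tensor with the manual sharding type

# The annotation should show up in the lowered HLO, as asserted by the new test.
hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
print('parameter(0), sharding={manual}' in hlo)  # expected: True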

@alanwaketan requested review from @jonb377 and @yeounoh on Apr 10, 2024
@alanwaketan self-assigned this on Apr 10, 2024
@jonb377 (Collaborator) left a comment:

Interesting! I've been wondering what manual sharding is for.

// The returned tensors will be in 1:1 correspondence with the `devices`
// vector, so the `i`th result will belong on the `i`th device.
// the `tile_assignment`; MANUAL sharding results in shards where only the
// first device holds the full data; the returned tensor shards vector is
Collaborator:

Regarding "only the first device holds the full data": is this by definition of manual sharding?

@yeounoh (Contributor) commented Apr 11, 2024:

This is not by definition, but rather our implementation choice. A more proper example would be a list of tensors (DTensor), where each tensor is an individual full shard.

Collaborator (Author):

@yeounoh Will that be replicated then?

Contributor:

Per our offline discussion, we'll abstain from manual sharding on input data.

result.reserve(cpu_shards.size() / shards_per_tensor);
for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
std::vector<at::Tensor> cpu_shards =
XlaDataToTensors(WrapXlaData(shard_handles), element_types);
Collaborator:

Calling XlaDataToTensors on each tensor individually will slow down d2h transfers for async checkpointing, since PjRt won't be able to fully utilize transfer parallelization.

Do we expect manually-sharded tensors to contain actual device data generally, or will they usually be IR? If just IR, maybe we can add an assertion to prevent access here.

Contributor:

I'd rather keep it functional for both cases -- shouldn't it be asynchronous anyway, not blocking the actual training run?

Collaborator (Author):

This is interesting. I wasn't aware of this performance optimization...

} else if ((sharding.type() == xla::OpSharding::MANUAL)) {
  // Just put the full tensor on the first device.
  shards[0] = tensor;
  shards.resize(1);
@jonb377 (Collaborator) commented Apr 10, 2024:

How does this work for a computation, since we need to feed each device some input data?

E.g., based on your unit test, what happens if we run:

x = torch.randn(3, 2)
xx = x.to(xm.xla_device())  # xx is device data
xt = xs._mark_manual_sharding(xx)
ones = torch.ones(3, 2).to(xm.xla_device())  # ones is replicated to all devices
print(xt + ones)  # What will happen here?
Contributor:

XLA should assume that xt is sharded manually, so it is expected to be replicated as well. The purpose of MANUAL is to support custom kernels and to prevent XLA from overriding the manual sharding.
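To make this distinction concrete, below is a rough sketch contrasting a regular tiled annotation with the new manual one. The Mesh/mark_sharding usage and the import aliases are assumptions drawn from the existing SPMD API rather than from this thread; only the sharding={manual} HLO string is confirmed by the PR's test.

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('x', 'y'))  # assumed mesh setup

t1 = torch.randn(8, 4).to(xm.xla_device())
t2 = torch.randn(8, 4).to(xm.xla_device())

# Regular SPMD annotation: the partitioner is free to shard/replicate ops on t1.
xs.mark_sharding(t1, mesh, ('x', 'y'))

# Manual annotation: the op is marked sharding={manual}, so the SPMD partitioner
# leaves it alone and a custom kernel can manage the per-device data itself.
t2_manual = xs._mark_manual_sharding(t2)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([t2_manual.global_tensor])
print('sharding={manual}' in hlo)  # expected: True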

Collaborator (Author):

Good question. I would expect it to behave as a single device. Let me double-check as well.

xt = xs._mark_manual_sharding(xx)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt.global_tensor])
self.assertIn('parameter(0), sharding={manual}', hlo)
Contributor:

Great!

@yeounoh (Contributor) left a comment:

LGTM. I'll leave the correctness review of distributed checkpointing with manual sharding to @jonb377 and his unit tests.


self.assertEqual(id(mesh), id(expected_mesh))

def test__mark_manual_sharding(self):
Contributor:

Nit: even though it's testing the _-prefixed API, let's keep it as test_mark_manual_sharding.

@alanwaketan (Collaborator, Author) commented:

All tests passed. I'm going to merge it. Let me know if I need to follow up on anything.

@alanwaketan merged commit e5513ff into master on Apr 12, 2024
lausannel pushed a commit to AlibabaPAI/xla that referenced this pull request on Aug 6, 2024
baoleai pushed a commit to AlibabaPAI/xla that referenced this pull request on Aug 6, 2024