[SPMD] Mesh to support custom device order. #4162
Conversation
| ), "PyTorch/XLA SPMD requires PJRT_DEVICE={CPU, TPU}, GPU is currently not supported." | ||
| ) | ||
| @unittest.skipIf(not using_pjrt() or xm.get_xla_supported_devices("GPU"), | ||
| f"Requires PJRT_DEVICE set to `TPU` or `CPU`.") |
@will-cromar I think PJRT-GPU single core is ready now?
It's blocked on our SPMD side; once we support TPU, the transition to GPU should be easier -- maybe sometime next year, once we are done with the basic/core SPMD features?
```python
Args:
    device_ids (Union[np.ndarray, List]): A raveled list of devices (IDs) in a custom order. The list is reshaped
        to a `mesh_shape` array, filling the elements using C-like index order. For example,
```
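To make the reshape concrete, here is a minimal sketch of the described behavior, assuming NumPy's default C (row-major) order; the device IDs and mesh shape are illustrative values, not from the PR:

```python
import numpy as np

# Illustrative: 4 devices supplied in a custom order, reshaped to a
# (2, 2) logical mesh using C-like (row-major) index order.
device_ids = [0, 2, 1, 3]
mesh_shape = (2, 2)
logical_mesh = np.array(device_ids).reshape(mesh_shape)
print(logical_mesh)
# [[0 2]
#  [1 3]]
```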
where is the example lol?
oh ok it is below, you might want to change the wording here.
Done
Force-pushed from 970536f to 716865c.
Thanks!
LGTM, thanks!
```python
mesh_shape (Tuple[Union[int, None]]): An int tuple describing the logical topology
    of the device mesh, and each element describes the number of devices in
    the corresponding axis.
```
Looks like mesh_shape can be removed here
Good catch :)
```python
def test_custom_tile_assignment(self):
    xt = torch.randn(10, 20).to(device=xm.xla_device())
    mesh_shape = (1, self.n_devices)
```
I see the tests have all devices mapped to a single axis - is there anything stopping us from using e.g. mesh_shape = (2, self.n_devices / 2)?
No, but for unit testing a flat mesh is easier to work with, since we don't know how many devices we will have (e.g., for CPU, we will have 1).
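As an aside, a hypothetical sketch of why the flat shape is convenient for tests that must run under any device count (the values are illustrative):

```python
import numpy as np

# A flat (1, n) mesh is valid for any device count, including n == 1 on CPU.
for n_devices in (1, 4, 8):
    flat_mesh = np.arange(n_devices).reshape((1, n_devices))
    # A 2D shape such as (2, n_devices // 2) only factors when n_devices is even.
    if n_devices % 2 == 0:
        two_d_mesh = np.arange(n_devices).reshape((2, n_devices // 2))
```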
```python
def __init__(self,
             device_ids: Union[np.ndarray, List],
             mesh_shape: Tuple[int, ...],
             axis_names: Tuple[str, ...] = None):
```
Just curious - how will axis_names be used long-term? Is it just for annotating the mesh?
Good question. Mesh axis annotation is useful since it makes the annotation logic more readable. We can also build a partitioning rule based on the axis name, instead of int indices.
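A hypothetical sketch of that idea (not the PR's API): map axis names to mesh dimensions so partition rules can be written by name and resolved to int indices only at the boundary.

```python
# Assumed names and spec format, for illustration only.
axis_names = ('data', 'model')
name_to_axis = {name: i for i, name in enumerate(axis_names)}

def resolve(spec):
    # Translate a name-based spec into the equivalent int-indexed spec.
    return tuple(name_to_axis[s] if isinstance(s, str) else s for s in spec)

print(resolve(('data', 'model')))  # (0, 1)
```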
Force-pushed from 716865c to 15a30e4.
Force-pushed from 3d6bc93 to eea8e9c, then from eea8e9c to 234871f.
lgtm!
This implements the `Mesh` class from #3871, to support a custom device order in the logical XLA device mesh topology.
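Based on the `__init__` signature shown above, usage might look like the following sketch. The module path and the `mark_sharding` call are assumptions drawn from the experimental SPMD API of this era, not verified against this PR:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # assumed module path

# Reverse the default order to demonstrate a custom device assignment.
num_devices = len(xm.get_xla_supported_devices())
device_ids = list(reversed(range(num_devices)))
mesh = xs.Mesh(device_ids, (1, num_devices), ('x', 'y'))

# Shard dim 1 of the tensor across the custom-ordered devices.
t = torch.randn(10, 20).to(device=xm.xla_device())
xs.mark_sharding(t, mesh, (0, 1))
```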