-
Couldn't load subscription status.
- Fork 560
Introduce virtual device #4091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce virtual device #4091
Conversation
a7274ef to 9976bb2 Compare 9976bb2 to 2b5dd6e Compare | Let's make sure that we cover the explict sharded cases, where we want to avoid the initial unpartitioned data transfer. We will have to double-check, but |
| Notes after chat with Yeounoh:
|
|
This is the only entry for the transfer data to device. |
6a326ec to 32ff407 Compare | Remaining implementation details before I can start testing:
|
f64c1a5 to 01fa14d Compare df56ba8 to ede9395 Compare 17cf0a0 to 5b7ac6f Compare There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you @steventk-g 👍
5b7ac6f to 1a43d3a Compare
Changes in this PR
XLA_USE_SPMD_xla_mark_sharding. The re-downloading path is preserved so thatXLA_USE_SPMD=0still works as well.xla:0fromxm.xla_device(). At this point, users should expect all tensors to be treated as if they are on the virtual device when SPMD is enabled.