[TPU] Support PyTorch/XLA FSDP via SPMD #28949
Conversation
Can HF folks point me to how to add a test case here, and how to update the documentation?
LGTM overall! We might want to add a small test; it can be done in a follow-up PR.
Pinging @muellerzr for a second look!
src/transformers/trainer.py Outdated
I'm not a huge fan of super-short names, but it seems common in the Trainer!
Tests should be added in the …
As @ArthurZucker hinted at, we no longer handle things like this in the Trainer directly. I would rather see this code live in Accelerate, which we can then bring into the Trainer automatically since it relies on Accelerate for preparation, especially as this deals with the dataloaders. Would that be possible, please? :)
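For context, a simplified sketch of the relationship being described here, i.e. the Trainer delegating dataloader preparation to Accelerate (this is an illustration of the delegation pattern, not the actual Trainer internals):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Sketch only: the Trainer owns the Accelerator instance and calls prepare() on
# the dataloaders it builds, so logic added to Accelerate's preparation step
# (e.g. XLA/SPMD-aware dataloader wrapping) flows into the Trainer for free.
accelerator = Accelerator()
dataset = TensorDataset(torch.randn(8, 4))
train_dataloader = accelerator.prepare(DataLoader(dataset, batch_size=2))
```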
src/transformers/trainer.py Outdated
Could we make this easier by importing FSDPv2 as FSDP instead?
May I ask what the benefits of doing so are?
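For context, here is a minimal sketch of what the suggested aliasing might look like (the `torch_xla` import paths and the `is_fsdp_xla_v2_enabled` flag are assumptions for illustration, not the exact code in this PR):

```python
# Hypothetical sketch: alias whichever XLA FSDP implementation is in use to a
# single name, so downstream Trainer code does not need to branch on the class.
is_fsdp_xla_v2_enabled = True  # would come from the fsdp_config in practice

if is_fsdp_xla_v2_enabled:
    from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
        SpmdFullyShardedDataParallel as FSDP,
    )
else:
    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
```

The upside would be fewer `if`/`else` branches at each call site; the downside, as the later discussion notes, is that the two classes do not take identical arguments (for example `shard_output`), so some branching remains either way.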
src/transformers/trainer.py Outdated
And then leave the check down here for what to do.
shard_output is not used by FSDPv1. Shouldn't we guard that with the flag too?
Can you elaborate a bit more? I can move the …
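To illustrate the point about guarding, here is a hedged sketch of keeping the FSDPv2-only `shard_output` argument behind the flag (the flag name, the sharding spec, and the surrounding variables such as `model` and `auto_wrap_policy` are assumptions taken from this discussion, not the final Trainer code):

```python
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

def shard_output(output, mesh):
    # FSDPv2 needs a callback telling it how to shard the module's forward
    # output along the "fsdp" axis of the SPMD mesh; FSDPv1 has no such argument.
    real_output = output.logits if hasattr(output, "logits") else output
    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))

# model, auto_wrap_policy, and is_fsdp_xla_v2_enabled come from the surrounding
# Trainer setup in practice.
if is_fsdp_xla_v2_enabled:
    model = FSDPv2(model, shard_output=shard_output, auto_wrap_policy=auto_wrap_policy)
else:
    # FSDPv1 path: no shard_output, so the callback above is never passed.
    model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
```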
Speaking of adding tests, what should I test? I mean, do you have TPU CI?
src/transformers/trainer.py Outdated
I fixed a bug here. cc @ArthurZucker @jonb377
The test failures don't seem to be related. I tried rebasing as well.
Thanks @ArthurZucker and @muellerzr for approving the change.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
It's all green. Can HF folks help with landing the PR? Appreciate it.
I can merge :) Thanks for adding this support @alanwaketan!
What does this PR do?
Summary:
This is the first attempt to enable FSDP via SPMD (FSDPv2) for PyTorch/XLA models.
More information about FSDPv2 can be found here:
Besides the initial implementation of FSDPv2 in r2.2, this change also requires the following changes in PyTorch/XLA:
Therefore, it will only be compatible with the nightly builds.
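As a rough illustration of what FSDPv2 involves on the PyTorch/XLA side, here is a sketch assuming a recent nightly; the mesh axis names follow the common SPMD convention and the toy module is a stand-in, so this is not this PR's exact code:

```python
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

# SPMD mode has to be enabled before any XLA tensors are created.
xr.use_spmd()

# Build a 1D "fsdp" mesh spanning all devices; FSDPv2 shards parameters along it.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "tensor"))
xs.set_global_mesh(mesh)

model = FSDPv2(torch.nn.Linear(16, 16))  # toy module standing in for a real model
```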
Example use cases:
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true }Before submitting
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker @younesbelkada