
Conversation

@zpcore (Member) commented on Jan 30, 2025

Resolves #8633
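
A minimal sketch of the pattern in the title: registering an opaque kernel's forward and backward as a PyTorch custom op (via `torch.library.custom_op`) so Dynamo/AOTAutograd can trace it as a single graph node instead of graph-breaking on its internals. Everything here is illustrative, not the PR's code: the op name, the plain softmax-attention bodies, and the hand-written gradients all stand in for the real Pallas flash-attention kernels.

```python
import torch

# Illustrative eager body; the real op would dispatch to the Pallas
# flash-attention kernel.
@torch.library.custom_op("sketch::flash_attention", mutates_args=())
def flash_attention(q: torch.Tensor, k: torch.Tensor,
                    v: torch.Tensor) -> torch.Tensor:
    scale = q.shape[-1] ** -0.5
    p = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return p @ v

@flash_attention.register_fake
def _(q, k, v):
    # Shape/dtype-only implementation so tracing never runs the kernel.
    return torch.empty_like(q)

def _setup_context(ctx, inputs, output):
    q, k, v = inputs
    ctx.save_for_backward(q, k, v)

def _backward(ctx, grad_out):
    # Illustrative plain-attention gradients; the real op would call the
    # flash-attention backward kernel instead.
    q, k, v = ctx.saved_tensors
    scale = q.shape[-1] ** -0.5
    p = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    grad_v = p.transpose(-2, -1) @ grad_out
    grad_p = grad_out @ v.transpose(-2, -1)
    grad_s = (p * (grad_p - (grad_p * p).sum(-1, keepdim=True))) * scale
    return grad_s @ k, grad_s.transpose(-2, -1) @ q, grad_v

flash_attention.register_autograd(_backward, setup_context=_setup_context)
```

With the backward registered at the op level, AOTAutograd can partition the joint forward/backward graph without looking inside the kernel.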

@zpcore force-pushed the piz/autograde_trace branch from f82f373 to a8c8f47 on January 31, 2025 at 10:31
@zpcore changed the title from "backward with spmd issue" to "Dynamo/AOTAutograd traceable flash attention" on January 31, 2025
@zpcore requested a review from @tengyifei on January 31, 2025 at 10:38
@tengyifei (Collaborator) left a comment


This is really great

@zpcore merged commit 9ae017e into master on Feb 1, 2025
12 checks passed
@zpcore deleted the piz/autograde_trace branch on February 1, 2025 at 04:22
tengyifei added a commit to AI-Hypercomputer/torchprime that referenced this pull request Mar 17, 2025
We replace the `for` loop in both Llama and Mixtral with an equivalent `HomogenousSequential` layer, which can either run a for loop or use `torch_xla`'s scan operator (see the sketch below). This is a clean-ish way to turn scan on/off without cluttering the modeling code. I also adjusted Mixtral slightly so that we can run `scan` even in Mixtral with its static MoE implementation. Scanning over GMM, on the other hand, won't work until GMM forward/backward is wrapped in a custom op similar to pytorch/xla#8654. Test: added a unit test. The next PR will change the trainer to apply scan.
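
A hedged sketch of the container described above. This is not the torchprime source: the class body is illustrative, and the scan entry point is assumed to be `torch_xla.experimental.scan_layers.scan_layers`.

```python
import torch
import torch.nn as nn

class HomogenousSequential(nn.Module):
    """Stack of structurally identical layers, runnable as a loop or a scan."""

    def __init__(self, *layers: nn.Module):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, use_scan: bool = False) -> torch.Tensor:
        if use_scan:
            # Trace one layer body and reuse it for every layer, instead of
            # tracing each layer separately in an unrolled loop.
            from torch_xla.experimental.scan_layers import scan_layers
            return scan_layers(self.layers, x)
        # Plain for loop: the behavior models used before this change.
        for layer in self.layers:
            x = layer(x)
        return x
```

Toggling `use_scan` turns scan on or off without touching the modeling code itself.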
tengyifei added a commit to AI-Hypercomputer/torchprime that referenced this pull request Mar 18, 2025
* Make models amenable to scan (same description as the commit above)
* Address comments
