
Conversation

@alanwaketan
Collaborator

Summary:
This pull request introduces a helper for gmm_backward. I'm still debating whether we need to make gmm an autograd.Function, given that we will do manual back-propagation in Mixtral.

Test Plan:
python test/test_gmm.py
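For context, here is a minimal sketch of what promoting gmm to an autograd.Function could look like. The module path and the helper signatures gmm(lhs, rhs, group_sizes) and gmm_backward(grad, lhs, rhs, group_sizes) -> (grad_lhs, grad_rhs) are assumptions for illustration, not the final API:

```python
# Hypothetical sketch only; the import path and helper signatures are assumed.
import torch
from torch_xla.experimental.custom_kernel import gmm, gmm_backward


class Gmm(torch.autograd.Function):

  @staticmethod
  def forward(ctx, lhs, rhs, group_sizes):
    # Save the inputs needed to recompute gradients in backward.
    ctx.save_for_backward(lhs, rhs, group_sizes)
    return gmm(lhs, rhs, group_sizes)

  @staticmethod
  def backward(ctx, grad_output):
    lhs, rhs, group_sizes = ctx.saved_tensors
    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes)
    # group_sizes is integer metadata, so it receives no gradient.
    return grad_lhs, grad_rhs, None
```

With such a wrapper, Gmm.apply(lhs, rhs, group_sizes) would flow gradients through autograd automatically; calling gmm_backward directly, as planned for Mixtral, skips the wrapper entirely.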

@alanwaketan alanwaketan self-assigned this May 30, 2024

```python
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_backward(self):
  self._init_test_cases()
```
Collaborator

nit, you might need a met.clear_all() here.
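For reference, a sketch of where that could go, assuming the test module imports torch_xla.debug.metrics as met like the other kernel tests:

```python
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm_backward(self):
  met.clear_all()  # reset metric counters so earlier tests don't leak into later checks
  self._init_test_cases()
```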

Collaborator Author

Let me fix it in the next PR. Don't want to waste CI cycles.

@JackCaoG (Collaborator) left a comment

What do you mean by "manual backprop in mixtral"?

@alanwaketan
Collaborator Author

Skip GPU tests to move fast.

@alanwaketan
Collaborator Author

> What do you mean by "manual backprop in mixtral"?

We just need to override the MoE backward pass to accommodate gmm backward and manual sharding. You will know what that means once the code is ready. Thanks for approving this change.
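Roughly, a hypothetical illustration (none of these tensor names are from this PR): in a hand-written MoE backward pass, the helper is called directly instead of going through autograd:

```python
# Hypothetical: grad_out, hidden, expert_w, group_sizes are placeholder names
# for the incoming gradient, expert inputs, expert weights, and group sizes.
grad_hidden, grad_expert_w = gmm_backward(grad_out, hidden, expert_w, group_sizes)
```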

@alanwaketan alanwaketan merged commit c96c95a into master May 30, 2024
@alanwaketan alanwaketan deleted the alanwaktan/tgmm3 branch May 30, 2024 03:48
