
Conversation

@alanwaketan
Collaborator

Summary:
This pull request makes GMM a torch.autograd.Function so that we can use torch.autograd.backward instead of manual backpropagation.

Test Plan:
python test/test_gmm.py
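
For context, a minimal sketch of the pattern this PR adopts: wrapping a grouped matmul (GMM) in a torch.autograd.Function so that torch.autograd.backward drives the backward pass. The plain-PyTorch loop below is only an illustrative stand-in for the Pallas gmm/tgmm kernels, and the names here (`_ref_gmm`, `Gmm`) are placeholders, not the PR's actual code.

```python
import torch

def _ref_gmm(lhs, rhs, group_sizes):
    # Reference grouped matmul: lhs [m, k], rhs [num_groups, k, n],
    # group_sizes [num_groups] with sum(group_sizes) == m.
    outs, start = [], 0
    for g, size in enumerate(group_sizes.tolist()):
        outs.append(lhs[start:start + size] @ rhs[g])
        start += size
    return torch.cat(outs, dim=0)  # [m, n]

class Gmm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, lhs, rhs, group_sizes):
        ctx.save_for_backward(lhs, rhs, group_sizes)
        return _ref_gmm(lhs, rhs, group_sizes)

    @staticmethod
    def backward(ctx, grad_out):
        lhs, rhs, group_sizes = ctx.saved_tensors
        grad_lhs_parts, grad_rhs_parts = [], []
        start = 0
        for g, size in enumerate(group_sizes.tolist()):
            g_out = grad_out[start:start + size]                        # [size, n]
            grad_lhs_parts.append(g_out @ rhs[g].t())                   # [size, k]
            grad_rhs_parts.append(lhs[start:start + size].t() @ g_out)  # [k, n]
            start += size
        # group_sizes is integer metadata, so it gets no gradient.
        return torch.cat(grad_lhs_parts, dim=0), torch.stack(grad_rhs_parts), None

if __name__ == "__main__":
    lhs = torch.randn(6, 4, requires_grad=True)
    rhs = torch.randn(2, 4, 5, requires_grad=True)
    group_sizes = torch.tensor([2, 4])
    out = Gmm.apply(lhs, rhs, group_sizes)
    torch.autograd.backward(out, torch.ones_like(out))
    print(lhs.grad.shape, rhs.grad.shape)
```

With this wiring, callers simply do `Gmm.apply(...)` and let autograd compute gradients, rather than invoking the backward kernels by hand.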

@alanwaketan alanwaketan self-assigned this May 30, 2024
```python
# Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report())

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
```
Collaborator

do you need TPU version check here?

Collaborator Author

lol, good question.

Collaborator Author

v2 is pretty happy on the tree.

Collaborator

Interesting... I thought Pallas is not supported on v2.

Collaborator Author

gmm is just mm... The kernel is simple... The other kernels fuse softmax, etc.

@alanwaketan
Collaborator Author

Thanks, Jack, for approving.

@alanwaketan
Collaborator Author

Skip GPU tests to move fast.

@alanwaketan alanwaketan merged commit aeed61a into master May 30, 2024
@alanwaketan alanwaketan deleted the alanwaketan/tgmm4 branch May 30, 2024 05:20