Conversation

@wonjoo-wj
Collaborator

@wonjoo-wj wonjoo-wj commented May 14, 2024

Support megacore_mode in paged_attention

JAX reference for megacore_mode: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L318
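For reviewers unfamiliar with the option, here is a minimal sketch of the kind of argument check megacore_mode implies. It assumes the accepted values are None, "kv_head", and "batch", and that the split dimension must be even because a megacore chip has two TensorCores; both assumptions come from my reading of the JAX reference, not from this PR. The helper name `check_megacore_mode` and the constant `NUM_CORES` are hypothetical and are not part of torch_xla or JAX.

```python
from typing import Optional

# Assumption: a megacore TPU chip exposes two TensorCores to one program.
NUM_CORES = 2


def check_megacore_mode(
    megacore_mode: Optional[str], batch_size: int, num_kv_heads: int
) -> Optional[str]:
    """Hypothetical helper: validate the mode and return the split dimension.

    Returns None when megacore is disabled, otherwise the name of the
    dimension that would be divided across the two cores.
    """
    if megacore_mode is None:
        # No megacore parallelism requested; nothing to validate.
        return None
    if megacore_mode == "kv_head":
        if num_kv_heads % NUM_CORES != 0:
            raise ValueError("megacore_mode='kv_head' needs an even num_kv_heads")
        return "num_kv_heads"
    if megacore_mode == "batch":
        if batch_size % NUM_CORES != 0:
            raise ValueError("megacore_mode='batch' needs an even batch_size")
        return "batch_size"
    raise ValueError(f"unknown megacore_mode: {megacore_mode!r}")
```

With these assumptions, `check_megacore_mode("batch", batch_size=3, num_kv_heads=8)` raises, while `"kv_head"` with the same shapes is accepted.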

Test plan:

python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_megacore_modes 

+ TPU CI

@wonjoo-wj
Collaborator Author

The test passes locally on my v4-8:

root@t1v-n-4989e8c7-w-0:~/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_megacore_modes
.
----------------------------------------------------------------------
Ran 1 test in 3.283s

OK
root@t1v-n-4989e8c7-w-0:~/pytorch/xla#

I'll wait for TPU CI to verify the rest.

@wonjoo-wj wonjoo-wj requested review from JackCaoG and alanwaketan May 14, 2024 19:31
Collaborator

@alanwaketan alanwaketan left a comment

LGTM.

@wonjoo-wj
Collaborator Author

Thanks for the reviews. Merging now that all CI checks are green.

@wonjoo-wj wonjoo-wj merged commit cbb9e21 into master May 14, 2024
3 participants