-
Couldn't load subscription status.
- Fork 560
Support the case where query_len>context_len in the multi-queries paged attention. #8356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| Hi @WoosukKwon , could you take a look at the test and the modified ref impl when you get a chance? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for the fix!
| @vanbasten23 If possible, can you merge the PR today so that I can use this features in the nightly wheel tomorrow? |
| @vanbasten23 can you fix the linter? |
Sorry for the late reply as I was OOO last Friday. Let me try to merge it today. |
| hm, after I take into account the effective_q_lens in the kernel, I found the results from the ref impl and the kernel mismatch. Need to look into it. |
| Ok, I fixed the kernel and the test. cc @WoosukKwon If the test |
| self.fail(f'Unsupported dtype: {dtype}') | ||
| for b in range(batch_size): | ||
| # N.B. For the output ([batch_size, query_len, num_q_heads, head_dim]) at query_len dim, all the value after the effective_q_len will be thrown away due to we padding the query seq len. The values after the effective_q_len may differ between the kernel and the ref impl because of the causal mask. | ||
| effective_q_len = effective_q_lens[b] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@WoosukKwon , the kernel makes sure that the results up to effective_q_len are correct and the kernel makes no guarantee on the results beyond effective_q_len
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. Thanks for the clarification!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
| In this PR, the test fails on TPU v4 (but succeeded TPU v5e): Notice the failure comes from the reference impl (pure jax, not Pallas kernel related) |
| As Jack mentioned, it might be due to OOM in the kernel but the error is not surfaced until later because the kernel execution is async. |
This PR support the case where query_len>context len. In reality, vLLM may pad the query_len (but not for the kv seq len). That means the query_len may be longer than the kv_len. This PR
Test plans: