Conversation

@vanbasten23
Collaborator

@vanbasten23 vanbasten23 commented Nov 4, 2024

This PR supports the case where query_len > context_len. In practice, vLLM may pad the query_len (but not the kv seq len), so the query_len may end up longer than the kv_len. This PR

  • adds a test for the case where query_len > kv_len
  • modifies the kernel to add an effective_q_lens parameter
  • modifies the ref impl to account for effective_q_lens (see the illustrative sketch after the test plans)

Test plans:

  • python pytorch/xla/test/test_pallas.py -v -k PallasTest.test_paged_attention_multi_queries_wrapper
  • python pytorch/xla/test/test_tpu_paged_attention_kernel.py 2>&1 | tee ~/out.txt
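
For readers unfamiliar with the effective_q_lens convention, below is a minimal, purely illustrative pure-JAX sketch of a single-sequence reference computation with a padded query. It is not the kernel's or the test's actual code: the function name, shapes, the assumption that effective_q_len <= kv_len, and the convention that the effective query tokens align with the end of the kv sequence are mine, and the grouped-query (q_kv_head_ratio) detail is ignored for brevity.

```python
import jax
import jax.numpy as jnp


def ref_attention_one_seq(q, k, v, effective_q_len):
  """Toy reference attention for one sequence with a padded query.

  q: [padded_query_len, num_heads, head_dim]; only the first effective_q_len
  rows are real tokens. k, v: [kv_len, num_heads, head_dim]. Assumes
  effective_q_len <= kv_len and that the effective query tokens are the last
  effective_q_len positions of the kv sequence.
  """
  padded_query_len, kv_len = q.shape[0], k.shape[0]
  logits = jnp.einsum("qhd,khd->hqk", q, k)
  # Causal mask: query row i may attend to kv positions up to
  # (kv_len - effective_q_len) + i. Rows with i >= effective_q_len are padding,
  # so their outputs are meaningless and must be ignored by the caller.
  q_pos = jax.lax.broadcasted_iota(jnp.int32, (padded_query_len, kv_len), 0)
  kv_pos = jax.lax.broadcasted_iota(jnp.int32, (padded_query_len, kv_len), 1)
  allowed = kv_pos <= (kv_len - effective_q_len) + q_pos
  logits = jnp.where(allowed[None, :, :], logits, -jnp.inf)
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("hqk,khd->qhd", weights, v)
```

The point of the sketch is only to show why the kernel needs effective_q_lens in addition to the kv lengths: once the query is padded, the causal-mask offset can no longer be derived from the query shape alone.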
@vanbasten23
Collaborator Author

Hi @WoosukKwon, could you take a look at the test and the modified ref impl when you get a chance?

@WoosukKwon WoosukKwon self-requested a review November 7, 2024 22:32
Collaborator

@WoosukKwon WoosukKwon left a comment


LGTM! Thanks for the fix!

@WoosukKwon
Collaborator

@vanbasten23 If possible, can you merge the PR today so that I can use this feature in the nightly wheel tomorrow?

@JackCaoG
Collaborator

JackCaoG commented Nov 8, 2024

@vanbasten23 can you fix the linter?

@vanbasten23
Collaborator Author

> @vanbasten23 If possible, can you merge the PR today so that I can use this feature in the nightly wheel tomorrow?

Sorry for the late reply as I was OOO last Friday. Let me try to merge it today.

@vanbasten23
Collaborator Author

Hm, after taking the effective_q_lens into account in the kernel, I found that the results from the ref impl and the kernel mismatch. Need to look into it.

@vanbasten23 vanbasten23 marked this pull request as ready for review November 12, 2024 20:08
@vanbasten23 vanbasten23 changed the title from "Support the case where query_len>context_len" to "Support the case where query_len>context_len in the multi-queries paged attention." Nov 12, 2024
@vanbasten23
Collaborator Author

Ok, I fixed the kernel and the test. cc @WoosukKwon If the test test_paged_attention_without_query_padding looks good to you and aligns with how vLLM would use the kernel, then I'll merge it once the CI finishes.

self.fail(f'Unsupported dtype: {dtype}')
for b in range(batch_size):
  # N.B. For the output ([batch_size, query_len, num_q_heads, head_dim]), along the
  # query_len dim all values after effective_q_len will be thrown away because we pad
  # the query seq len. The values after effective_q_len may differ between the kernel
  # and the ref impl because of the causal mask.
  effective_q_len = effective_q_lens[b]
Collaborator Author


@WoosukKwon, the kernel makes sure that the results up to effective_q_len are correct; it makes no guarantee about the results beyond effective_q_len.
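
As a hypothetical illustration of this guarantee, the check loop can compare only the first effective_q_len query positions per sequence and leave the padded tail unchecked. Variable names mirror the snippet above; the tolerances are arbitrary placeholders, not the test's real values.

```python
import numpy as np

for b in range(batch_size):
  effective_q_len = effective_q_lens[b]
  # Only the rows the kernel guarantees are compared; query positions
  # >= effective_q_len are padding and are intentionally left unchecked.
  np.testing.assert_allclose(
      actual_output[b, :effective_q_len],
      expected_output[b, :effective_q_len],
      rtol=1e-2,  # placeholder tolerance
      atol=2e-2,  # placeholder tolerance
  )
```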

Collaborator


Makes sense. Thanks for the clarification!

@vanbasten23 vanbasten23 requested a review from JackCaoG November 13, 2024 22:56
Collaborator

@WoosukKwon WoosukKwon left a comment


LGTM!

@vanbasten23 vanbasten23 merged commit 102cd48 into master Nov 14, 2024
11 of 12 checks passed
@vanbasten23
Collaborator Author

vanbasten23 commented Nov 19, 2024

In this PR, the test fails on TPU v4 (but succeeded on TPU v5e):

root@085d9c0e2005:/ansible# python pytorch/xla/test/test_tpu_paged_attention_kernel.py PagedAttentionKernelTest.test_paged_attention_with_query_padding128
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Running tests under Python 3.10.14: /usr/local/bin/python
[ RUN ] PagedAttentionKernelTest.test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
Running paged_attention with query_len=2048, kv_seq_lens=Array([1136, 7, 21], dtype=int32), effective_q_lens=Array([720, 6, 12], dtype=int32)
[ FAILED ] PagedAttentionKernelTest.test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
======================================================================
ERROR: test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128) (__main__.PagedAttentionKernelTest)
PagedAttentionKernelTest.test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
test_paged_attention_with_query_padding(dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/absl/testing/parameterized.py", line 319, in bound_param_test
    return test_method(self, **testcase_params)
  File "/ansible/pytorch/xla/test/test_tpu_paged_attention_kernel.py", line 265, in test_paged_attention_with_query_padding
    expected_output = _ref_jax_extended_paged_attention(
  File "/ansible/pytorch/xla/test/test_tpu_paged_attention_kernel.py", line 56, in _ref_jax_extended_paged_attention
    kv_len = lengths[i]
  File "/usr/local/lib/python3.10/site-packages/jax/_src/array.py", line 370, in __getitem__
    return lax_numpy._rewriting_take(self, idx)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 11081, in _rewriting_take
    if (result := _attempt_rewriting_take_via_slice(arr, idx, mode)) is not None:
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 11058, in _attempt_rewriting_take_via_slice
    arr = lax.slice(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 107, in slice
    return slice_p.bind(operand, start_indices=tuple(start_indices),
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 438, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/lax/slicing.py", line 1329, in _slice_impl
    return dispatch.apply_primitive(dynamic_slice_p, x, *start_indices,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 91, in apply_primitive
    outs = fun(*args)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 2803, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 442, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/core.py", line 955, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1737, in _pjit_call_impl
    return xc._xla.pjit(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1713, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pjit.py", line 1667, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/usr/local/lib/python3.10/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1287, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program: Fatal error
----------------------------------------------------------------------
Ran 1 test in 7.869s
FAILED (errors=1)
F1119 22:00:49.772371 541372 b5ee2_core_offload_helper.cc:71] Check failed: ::tensorflow::b5ee2_core_all_gather_offloading:: Shutdownb5ee2CoreForOffloading(&topology, GetDriverWrappers()) is OK (UNKNOWN: Fatal error
=== Source Location Trace: ===
platforms/asic_sw/driver/2a886c8/common/internal/host_queue.cc:316
learning/45eac/google/xla/b5ee2_core_all_gather_offloading.cc:82
) Failed to shutdown b5ee2Core from offloading.
*** Check failure stack trace: ***
    @     0x7fb712ef0424  (unknown)
    @     0x7fb712eeff58  (unknown)
    @     0x7fb713121fe9  (unknown)
    @     0x7fb709d12a65  (unknown)
    @     0x7fb709d0f8f1  (unknown)
    @     0x7fb709d0f57b  (unknown)
    @     0x7fb708460f4c  (unknown)
    @     0x7fb7084612ee  (unknown)
    @     0x7fb7083fac5a  (unknown)
    @     0x7fb95fbefeed  std::_Function_handler<>::_M_invoke()
    @     0x7fb95fbe2ed6  xla::PjRtCApiClient::~PjRtCApiClient()
    @     0x7fb95fbe2fee  xla::PjRtCApiClient::~PjRtCApiClient()
    @     0x7fb9658f2a50  xla::ifrt::PjRtClient::~PjRtClient()
    @     0x7fb95fa987f7  std::_Sp_counted_deleter<>::_M_dispose()
    @     0x7fb96508dc3e  xla::PyClient::~PyClient()
    @     0x7fb9658e2c44  nanobind::detail::inst_dealloc()
    @     0x7fba0b50889a  gc_collect_main
https://symbolize.stripped_domain/r/?trace=7fb712ef0424,7fb712eeff57,7fb713121fe8,7fb709d12a64,7fb709d0f8f0,7fb709d0f57a,7fb708460f4b,7fb7084612ed,7fb7083fac59,7fb95fbefeec,7fb95fbe2ed5,7fb95fbe2fed,7fb9658f2a4f,7fb95fa987f6,7fb96508dc3d,7fb9658e2c43,7fba0b508899&map=
Fatal Python error: Aborted
Current thread 0x00007fba0b05a740 (most recent call first):
  Garbage-collecting
<no Python frame>
Aborted (core dumped)
root@085d9c0e2005:/ansible#

Note that the failure comes from the reference impl (pure JAX, not related to the Pallas kernel).

@vanbasten23
Collaborator Author

vanbasten23 commented Nov 20, 2024

As Jack mentioned, it might be due to an OOM in the kernel, but the error is not surfaced until later because the kernel execution is async.
So after we add a jax.block_until_ready(actual_output) call, the real error surfaces:

root@085d9c0e2005:/ansible# python pytorch/xla/test/test_tpu_paged_attention_kernel.py PagedAttentionKernelTest.test_paged_attention_with_query_padding128
/usr/local/lib/python3.10/site-packages/jax/__init__.py:31: UserWarning: cloud_tpu_init failed: KeyError('LIBTPU_INIT_ARGS') This a JAX bug; please report an issue at https://github.com/jax-ml/jax/issues
  _warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report "
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Running tests under Python 3.10.14: /usr/local/bin/python
[ RUN ] PagedAttentionKernelTest.test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
INFO:2024-11-20 05:19:01,815:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1120 05:19:01.815768 140220782982976 xla_bridge.py:927] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
Running paged_attention with query_len=2048, kv_seq_lens=Array([1136, 7, 21], dtype=int32), effective_q_lens=Array([48, 7, 21], dtype=int32)
[ FAILED ] PagedAttentionKernelTest.test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
======================================================================
ERROR: test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128) (__main__.PagedAttentionKernelTest)
PagedAttentionKernelTest.test_paged_attention_with_query_padding128 (dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
test_paged_attention_with_query_padding(dtype=<class 'jax.numpy.float32'>, page_size=64, num_kv_heads=8, q_kv_head_ratio=4, head_dim=128, num_queries_per_compute_block=16, block_kv_size=128)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/absl/testing/parameterized.py", line 319, in bound_param_test
    return test_method(self, **testcase_params)
  File "/ansible/pytorch/xla/test/test_tpu_paged_attention_kernel.py", line 268, in test_paged_attention_with_query_padding
    jax.block_until_ready(actual_output)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 2763, in block_until_ready
    try_to_block(arrays[0])
  File "/usr/local/lib/python3.10/site-packages/jax/_src/api.py", line 2746, in try_to_block
    return x.block_until_ready()
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Program or fatal error occurred; computation may be invalid: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x1075 (from TensorCoreSequencer:1:0x1076): no debugging message found for this tag:pc.
HLO: main; HLO computation: main.15
=== Source Location Trace: ===
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:113
----------------------------------------------------------------------
Ran 1 test in 6.316s
FAILED (errors=1)
F1120 05:19:08.782719 839585 b5ee2_core_offload_helper.cc:71] Check failed: ::tensorflow::b5ee2_core_all_gather_offloading:: Shutdownb5ee2CoreForOffloading(&topology, GetDriverWrappers()) is OK (UNKNOWN: Fatal error
=== Source Location Trace: ===
platforms/asic_sw/driver/2a886c8/common/internal/host_queue.cc:316
learning/45eac/google/xla/b5ee2_core_all_gather_offloading.cc:82
) Failed to shutdown b5ee2Core from offloading.
*** Check failure stack trace: ***
    @     0x7f7c26e8f4c4  (unknown)
    @     0x7f7c26e8eff8  (unknown)
    @     0x7f7c270c4cc9  (unknown)
    @     0x7f7c1daafea5  (unknown)
    @     0x7f7c1daacd31  (unknown)
    @     0x7f7c1daac9bb  (unknown)
    @     0x7f7c1c18258c  (unknown)
    @     0x7f7c1c18292e  (unknown)
    @     0x7f7c1c11b91a  (unknown)
    @     0x7f870507953d  std::_Function_handler<>::_M_invoke()
    @     0x7f870506c556  xla::PjRtCApiClient::~PjRtCApiClient()
    @     0x7f870506c66e  xla::PjRtCApiClient::~PjRtCApiClient()
    @     0x7f870bcd0a50  xla::ifrt::PjRtClient::~PjRtClient()
    @     0x7f8704f1f447  std::_Sp_counted_deleter<>::_M_dispose()
    @     0x7f870af04dee  xla::PyClient::~PyClient()
    @     0x7f870bcc0a44  nanobind::detail::inst_dealloc()
    @     0x7f87b240989a  gc_collect_main
https://symbolize.stripped_domain/r/?trace=7f7c26e8f4c4,7f7c26e8eff7,7f7c270c4cc8,7f7c1daafea4,7f7c1daacd30,7f7c1daac9ba,7f7c1c18258b,7f7c1c18292d,7f7c1c11b919,7f870507953c,7f870506c555,7f870506c66d,7f870bcd0a4f,7f8704f1f446,7f870af04ded,7f870bcc0a43,7f87b2409899&map=
Fatal Python error: Aborted
Current thread 0x00007f87b1f5b740 (most recent call first):
  Garbage-collecting
<no Python frame>
Aborted (core dumped)
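
For reference, a minimal sketch of the debugging pattern used above: force synchronization right after the kernel call so a device-side failure is attributed to the kernel rather than to whatever later op happens to synchronize first. The wrapper name and the kernel entry point are illustrative, not the test's actual code.

```python
import jax


def run_kernel_and_sync(kernel_fn, *args):
  # kernel_fn is whatever launches the Pallas paged-attention kernel
  # (hypothetical here). JAX dispatch is asynchronous: the Python call returns
  # once the computation is enqueued, so an on-device error would otherwise
  # surface at a later, unrelated operation (e.g. inside the pure-JAX ref impl).
  out = kernel_fn(*args)
  # Block until the result is materialized so any runtime error is raised here.
  jax.block_until_ready(out)
  return out
```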
