Skip to content

Conversation

@zpcore
Copy link
Member

@zpcore zpcore commented Feb 21, 2025

When we execute the following two code snippets regarding flash attention kernel in custom_kernel.py, they suppose to produce the same result.
a.

l, m = (v[..., 0] for v in aux[-2:]) 

b.

l, m = aux[-2:] l = torch.ops.aten.slice(l, -1, 0, 1) m = torch.ops.aten.slice(m, -1, 0, 1) 

Both will be lowered through

at::Tensor XLANativeFunctions::as_strided_copy(
, the difference is that input argument stride and size will be one element fewer in code a compared with code b. With such argument difference, code a will be fallback into aten::take and this can trigger the following error when we call with SPMD:

F0223 07:18:45.157172 842998 hlo_sharding.cc:1024] Check failed: !IsManual() 

I plan to check in test_as_stride_use_slice.py in this PR.
Note 1. Failing test test_scan_layer_aot is not enabled until #8742 is resolved.
2. Failing test test_scan_weight_layer_aot is not enabled unti #8753 is resolved

@zpcore zpcore changed the title slice lower Lower as_strided_copy use fast path with slice Feb 23, 2025
@zpcore zpcore marked this pull request as ready for review February 23, 2025 08:17
@zpcore zpcore requested a review from tengyifei February 23, 2025 08:19
@zpcore zpcore requested a review from bhavya01 February 24, 2025 17:37
if (stride_mul != stride[j]) {
if (skip_dim == -1) {
skip_dim = i;
K = stride[j] / stride_mul;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that stride[j] can be evenly divided by stride_mul and exit if the remainder is not 0?

Copy link
Member Author

@zpcore zpcore Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't need to be 'evenly divided' for stride[j] as long as all indexes before j of stride matches with the cumulative product of tensor dim.

@tengyifei tengyifei merged commit 1ab8216 into master Feb 27, 2025
23 checks passed
@zpcore zpcore deleted the piz/as_stride branch February 27, 2025 05:10
pgmoka pushed a commit that referenced this pull request Mar 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

3 participants