Commit 8409ce0

Add default parameters for compatibility (#3354)
* Add default parameters for compatibility
* Change interface in docs and _meta_registrations
1 parent 08610f7 commit 8409ce0
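
The compatibility mechanism is simply that k_scale and v_scale now default to 1.0 and the optional alibi_slopes argument moves ahead of them, so call sites written before the fp8 KV-cache scales existed keep working. Below is a minimal, illustrative sketch of that calling convention; the wrapper function is a stand-in mirroring the reshape_and_cache argument list shown in the _meta_registrations.py change further down, not part of the repository.

import torch

# Stand-in mirroring the updated reshape_and_cache interface
# (key, value, key_cache, value_cache, slot_mapping, k_scale=1.0, v_scale=1.0);
# it only checks the calling convention, it does not touch a real KV cache.
def reshape_and_cache_like(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    k_scale: float = 1.0,  # new default: non-fp8 callers may omit it
    v_scale: float = 1.0,  # new default: non-fp8 callers may omit it
) -> None:
    assert key.shape == value.shape and slot_mapping.dtype == torch.long

key = value = torch.randn(4, 8, 64)
key_cache = value_cache = torch.empty(16, 8, 16, 64)
slots = torch.arange(4, dtype=torch.long)

reshape_and_cache_like(key, value, key_cache, value_cache, slots)            # pre-fp8 call site, still valid
reshape_and_cache_like(key, value, key_cache, value_cache, slots, 0.5, 0.5)  # fp8 caches pass explicit scales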

File tree: 9 files changed (+76, -67 lines changed)


csrc/cpu/aten/PagedAttention.cpp

Lines changed: 8 additions & 8 deletions
@@ -24,9 +24,9 @@ void single_query_cached_kv_attention_forward_cpu(
     at::Tensor& context_lens, // [num_seqs]
     int64_t block_size,
     int64_t max_context_len,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes) {
+    const double v_scale) {
   return single_query_cached_kv_attention_kernel_stub(
       kCPU,
       out,
@@ -39,9 +39,9 @@ void single_query_cached_kv_attention_forward_cpu(
       context_lens,
       block_size,
       max_context_len,
+      alibi_slopes,
       k_scale,
-      v_scale,
-      alibi_slopes);
+      v_scale);
 }
 
 void reshape_and_cache_cpu(
@@ -68,9 +68,9 @@ void flash_attn_varlen_cpu(
     const double softmax_scale,
     bool is_causal,
     at::Tensor& block_table,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes) {
+    const double v_scale) {
   return flash_attn_var_len_kernel_stub(
       kCPU,
       out,
@@ -84,9 +84,9 @@ void flash_attn_varlen_cpu(
       softmax_scale,
       is_causal,
       block_table,
+      alibi_slopes,
       k_scale,
-      v_scale,
-      alibi_slopes);
+      v_scale);
 }
 
 } // namespace cpu

csrc/cpu/aten/PagedAttention.h

Lines changed: 8 additions & 8 deletions
@@ -19,9 +19,9 @@ void single_query_cached_kv_attention(
     at::Tensor& context_lens, // [num_seqs]
     int64_t block_size,
     int64_t max_context_len,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes);
+    const double v_scale);
 }
 
 void reshape_and_cache(
@@ -45,9 +45,9 @@ void flash_attn_varlen(
     const double softmax_scale,
     bool is_causal,
     at::Tensor& block_table,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes);
+    const double v_scale);
 
 using single_query_cached_kv_attention_fn = void (*)(
     at::Tensor& out, // [num_seqs, num_heads, head_size]
@@ -60,9 +60,9 @@ using single_query_cached_kv_attention_fn = void (*)(
     at::Tensor& context_lens, // [num_seqs]
     int64_t block_size,
     int64_t max_context_len,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes);
+    const double v_scale);
 
 using reshape_and_cache_fn = void (*)(
     at::Tensor& key,
@@ -85,9 +85,9 @@ using flash_attn_var_len_fn = void (*)(
     const double softmax_scale,
     bool is_causal,
     at::Tensor& block_table,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes);
+    const double v_scale);
 
 IPEX_DECLARE_DISPATCH(
     single_query_cached_kv_attention_fn,

csrc/cpu/aten/kernels/PagedAttentionKrnl.cpp

Lines changed: 30 additions & 31 deletions
@@ -409,10 +409,10 @@ inline void _mul_reduce_max_fusion_kernel(
  * @param block_size The block size which means the number of token in every
  * block.
  * @param max_context_len Maximum context length.
- * @param k_scale Scaling factor for key cache of data type fp8.
- * @param v_scale Scaling factor for value cache of data type fp8.
  * @param alibi_slopes Optional tensor of alibi slopes with the shape of
  * (num_heads).
+ * @param k_scale Scaling factor for key cache of data type fp8.
+ * @param v_scale Scaling factor for value cache of data type fp8.
  */
 template <typename scalar_t, typename cache_t>
 void single_query_cached_kv_attention_kernel(
@@ -425,9 +425,9 @@ void single_query_cached_kv_attention_kernel(
     at::Tensor& context_lens,
     int64_t block_size,
     int64_t max_context_len,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes) {
+    const double v_scale) {
   auto out_ptr = out.data_ptr<scalar_t>();
   auto query_ptr = query.data_ptr<scalar_t>();
   auto key_cache_ptr = key_cache.data_ptr<cache_t>();
@@ -807,9 +807,9 @@ void flash_attn_varlen_kernel(
     const double softmax_scale, // scale for softmax
     bool is_causal, // whether the attention is causal
     at::Tensor& block_table,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes) {
+    const double v_scale) {
   auto kv_block_strideN = key_cache.stride(0);
   auto kv_block_strideH = key_cache.stride(1);
   auto kv_block_strideP = key_cache.stride(2);
@@ -1027,9 +1027,9 @@ void single_query_cached_kv_attention_kernel_impl(
     at::Tensor& context_lens, // [num_seqs]
     int64_t block_size,
     int64_t max_context_len,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes) {
+    const double v_scale) {
   RECORD_FUNCTION(
       "ipex::single_query_cached_kv_attention_kernel_impl",
       c10::ArrayRef<c10::IValue>({}));
@@ -1046,9 +1046,9 @@ void single_query_cached_kv_attention_kernel_impl(
         context_lens,
         block_size,
         max_context_len,
+        alibi_slopes,
         k_scale,
-        v_scale,
-        alibi_slopes);
+        v_scale);
   } else if (out.scalar_type() == at::ScalarType::Float) {
     single_query_cached_kv_attention_kernel<float, float>(
         out,
@@ -1060,9 +1060,9 @@ void single_query_cached_kv_attention_kernel_impl(
         context_lens,
         block_size,
         max_context_len,
+        alibi_slopes,
         k_scale,
-        v_scale,
-        alibi_slopes);
+        v_scale);
   } else if (out.scalar_type() == at::ScalarType::BFloat16) {
     single_query_cached_kv_attention_kernel<at::BFloat16, at::BFloat16>(
         out,
@@ -1074,9 +1074,9 @@ void single_query_cached_kv_attention_kernel_impl(
         context_lens,
         block_size,
         max_context_len,
+        alibi_slopes,
        k_scale,
-        v_scale,
-        alibi_slopes);
+        v_scale);
   } else if (out.scalar_type() == at::ScalarType::Half) {
     single_query_cached_kv_attention_kernel<at::Half, at::Half>(
         out,
@@ -1088,9 +1088,9 @@ void single_query_cached_kv_attention_kernel_impl(
         context_lens,
         block_size,
         max_context_len,
+        alibi_slopes,
         k_scale,
-        v_scale,
-        alibi_slopes);
+        v_scale);
   } else {
     TORCH_CHECK(
         false, "Unsupported data type for single_query_cached_kv_attention");
@@ -1152,9 +1152,9 @@ void flash_attn_varlen_cpu_kernel_impl(
     const double softmax_scale,
     bool is_causal,
     at::Tensor& block_table,
+    const c10::optional<at::Tensor>& alibi_slopes,
     const double k_scale,
-    const double v_scale,
-    const c10::optional<at::Tensor>& alibi_slopes) {
+    const double v_scale) {
   TORCH_CHECK(
       key.scalar_type() == value.scalar_type(),
       "key and value should have the same data type");
@@ -1173,7 +1173,6 @@ void flash_attn_varlen_cpu_kernel_impl(
   if (query.scalar_type() == at::ScalarType::Float) {
     if (max_seqlen_q >= 768) {
       flash_attn_varlen_kernel<float, float, 128>(
-
           out,
           query,
           key,
@@ -1185,9 +1184,9 @@ void flash_attn_varlen_cpu_kernel_impl(
           softmax_scale,
           is_causal,
           block_table,
+          alibi_slopes,
           k_scale,
-          v_scale,
-          alibi_slopes);
+          v_scale);
     } else if (max_seqlen_q >= 192) {
       flash_attn_varlen_kernel<float, float, 64>(
           out,
@@ -1201,9 +1200,9 @@ void flash_attn_varlen_cpu_kernel_impl(
           softmax_scale,
           is_causal,
           block_table,
+          alibi_slopes,
           k_scale,
-          v_scale,
-          alibi_slopes);
+          v_scale);
     } else {
       flash_attn_varlen_kernel<float, float, 32>(
           out,
@@ -1217,9 +1216,9 @@ void flash_attn_varlen_cpu_kernel_impl(
           softmax_scale,
           is_causal,
           block_table,
+          alibi_slopes,
           k_scale,
-          v_scale,
-          alibi_slopes);
+          v_scale);
     }
 
   } else if (query.scalar_type() == at::ScalarType::BFloat16) {
@@ -1236,9 +1235,9 @@ void flash_attn_varlen_cpu_kernel_impl(
           softmax_scale,
           is_causal,
           block_table,
+          alibi_slopes,
           k_scale,
-          v_scale,
-          alibi_slopes);
+          v_scale);
     } else if (max_seqlen_q >= 192) {
       flash_attn_varlen_kernel<at::BFloat16, at::BFloat16, 64>(
           out,
@@ -1252,9 +1251,9 @@ void flash_attn_varlen_cpu_kernel_impl(
           softmax_scale,
           is_causal,
           block_table,
+          alibi_slopes,
           k_scale,
-          v_scale,
-          alibi_slopes);
+          v_scale);
     } else {
       flash_attn_varlen_kernel<at::BFloat16, at::BFloat16, 32>(
           out,
@@ -1268,9 +1267,9 @@ void flash_attn_varlen_cpu_kernel_impl(
           softmax_scale,
           is_causal,
           block_table,
+          alibi_slopes,
           k_scale,
-          v_scale,
-          alibi_slopes);
+          v_scale);
     }
 
   } else {
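
Not changed by this commit, but visible in the surrounding context above: flash_attn_varlen_cpu_kernel_impl picks the kernel's query-block template parameter from max_seqlen_q. A minimal Python mirror of that selection rule, for orientation only (the real dispatch is the C++ template instantiation shown in the diff):

# Mirrors the q-block selection in flash_attn_varlen_cpu_kernel_impl:
# long prompts use a 128-row query block, medium ones 64, short ones 32.
def select_q_block_size(max_seqlen_q: int) -> int:
    if max_seqlen_q >= 768:
        return 128
    if max_seqlen_q >= 192:
        return 64
    return 32

assert select_q_block_size(2048) == 128
assert select_q_block_size(256) == 64
assert select_q_block_size(64) == 32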

intel_extension_for_pytorch/_meta_registrations.py

Lines changed: 5 additions & 1 deletion
@@ -130,7 +130,9 @@ def is_channels_last_3d(ten):
 
 
 @register_meta("reshape_and_cache")
-def meta_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping):
+def meta_reshape_and_cache(
+    key, value, key_cache, value_cache, slot_mapping, k_scale, v_scale
+):
     return None
 
 
@@ -147,6 +149,8 @@ def meta_single_query_cached_kv_attention(
     block_size,
     max_context_len,
     alibi_slopes,
+    k_scale,
+    v_scale,
 ):
     return None
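
The meta ("fake") implementations must accept exactly the arguments the operator schema declares, which is why k_scale and v_scale are added here even though the meta functions only return None. A generic torch.library sketch of the same idea; this is not the IPEX register_meta helper, and the "demo" namespace and op name are made up for illustration:

import torch

# Hypothetical namespace and schema purely for illustration.
lib = torch.library.Library("demo", "DEF")
lib.define(
    "reshape_and_cache(Tensor key, Tensor value, Tensor key_cache, "
    "Tensor value_cache, Tensor slot_mapping, float k_scale=1.0, "
    "float v_scale=1.0) -> ()"
)

def meta_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                           k_scale=1.0, v_scale=1.0):
    # Meta kernels only propagate shapes/dtypes; if this signature drifted
    # from the schema, tracing or torch.compile would fail at call time.
    return None

lib.impl("reshape_and_cache", meta_reshape_and_cache, "Meta")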

intel_extension_for_pytorch/llm/modules/mha_fusion.py

Lines changed: 11 additions & 11 deletions
@@ -472,9 +472,9 @@ class PagedAttention:
                 context_lens,
                 block_size,
                 max_context_len,
+                alibi_slopes,
                 k_scale,
                 v_scale,
-                alibi_slopes
             )
 
     This operator is used to be calculated the scale-dot-product based on the paged attention.
@@ -518,9 +518,9 @@ class PagedAttention:
                 scale,
                 is_cusal,
                 block_tables,
+                alibi_slopes,
                 key_cache,
                 val_cache,
-                alibi_slopes
             )
 
     Args:
@@ -539,9 +539,9 @@ class PagedAttention:
         is_cusal (bool): Whether to apply causal attention masking. Default is True. False is not supported yet.
         block_tables:(torch.Tensor): The mapping table used to mapping the logical sequence
             to the physical sequence. The shape should be [batch_size, max_num_blocks_per_seq].
+        alibi_slopes (torch.Tensor, optinal): which is the alibi slope with the shape of (num_heads).
         k_scale (float): The scale used by the fp8 key cache.
         v_scale (float): The scale used by the fp8 value cache.
-        alibi_slopes (torch.Tensor, optinal): which is the alibi slope with the shape of (num_heads).
 
     """
 
@@ -555,8 +555,8 @@ def reshape_and_cache(
         key_cache: torch.Tensor,
         value_cache: torch.Tensor,
         slot_mapping: torch.Tensor,
-        k_scale: float,
-        v_scale: float,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         return cls.runtime_ops.get_module_from_device(
             key.device.type, IPEXCustomOpType.PAGED_ATTENTION, False
@@ -577,9 +577,9 @@ def single_query_cached_kv_attention(
         context_lens: torch.Tensor,
         block_size: int,
        max_context_len: int,
-        k_scale: float,
-        v_scale: float,
         alibi_slopes: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         return cls.runtime_ops.get_module_from_device(
            output.device.type, IPEXCustomOpType.PAGED_ATTENTION, False
@@ -594,9 +594,9 @@ def single_query_cached_kv_attention(
             context_lens,
             block_size,
             max_context_len,
+            alibi_slopes,
             k_scale,
             v_scale,
-            alibi_slopes,
         )
 
     @classmethod
@@ -613,9 +613,9 @@ def flash_attn_varlen_func(
         scale,
         is_cusal: bool,
         block_tables: torch.Tensor,
-        k_scale: float,
-        v_scale: float,
         alibi_slopes: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         return cls.runtime_ops.get_module_from_device(
             output.device.type, IPEXCustomOpType.PAGED_ATTENTION, False
@@ -631,9 +631,9 @@ def flash_attn_varlen_func(
             scale,
             is_cusal,
             block_tables,
+            alibi_slopes,
             k_scale,
             v_scale,
-            alibi_slopes,
        )
 
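
Taken together, the Python frontend now expects alibi_slopes before the (now optional) scales. A hedged usage sketch follows; the import path is assumed from the ipex.llm.modules documentation, and the leading arguments of single_query_cached_kv_attention follow the documented interface since only the tail of the argument list is visible in this diff:

import torch
from typing import Optional
from intel_extension_for_pytorch.llm.modules import PagedAttention  # import path assumed from the docs

def decode_step(
    out: torch.Tensor,             # [num_seqs, num_heads, head_size]
    query: torch.Tensor,           # [num_seqs, num_heads, head_size]
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    head_mapping: torch.Tensor,
    scale: float,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_size: int,
    max_context_len: int,
    alibi_slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # alibi_slopes now sits before k_scale/v_scale, and both scales default to
    # 1.0, so non-fp8 callers can simply stop at alibi_slopes.
    PagedAttention.single_query_cached_kv_attention(
        out, query, key_cache, value_cache, head_mapping, scale,
        block_tables, context_lens, block_size, max_context_len,
        alibi_slopes,
    )
    return out

Callers that previously passed k_scale and v_scale positionally before alibi_slopes need to move them after it, or switch to keyword arguments.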