
Commit b039bfd

Varun Sundar Rabindranath authored

[Bugfix] Fix persistent_masked_m_silu_mul_quant tests (vllm-project#28366)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>

1 parent d0e186c · commit b039bfd

File tree

3 files changed: +16 −7 lines

csrc/quantization/activation_kernels.cu

Lines changed: 10 additions & 5 deletions
@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(
 
   // This kernel currently only supports H % 128 == 0 and assumes a
   // fixed GROUP_SIZE of 128.
+  static constexpr int GROUP_SIZE = 128;
+
   TORCH_CHECK(input.dtype() == torch::kBFloat16);
   TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
               y_q.dtype() == torch::kFloat8_e4m3fnuz);
   TORCH_CHECK(y_s.dtype() == torch::kFloat32);
-  TORCH_CHECK(input.size(-1) % 256 == 0);
+  TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
 
   using Idx_t = int64_t;
 
@@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant(
 
   Idx_t stride_counts_e = tokens_per_expert.stride(0);
 
-  static constexpr int GROUP_SIZE = 128;
-
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
 #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
@@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant(
 
   static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
 
+  int const NUM_GROUPS = H / GROUP_SIZE;
   if (!use_ue8m0) {
-    if (H >= 4096) {
+    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
+      /* 8 warps config */
       static constexpr int NUM_STAGES = 4;
       static constexpr int THREAD_COUNT = 256;
       KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
     } else {
+      /* 1 warp config */
       static constexpr int THREAD_COUNT = 32;
       KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
     }
   } else {
-    if (H >= 4096) {
+    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
+      /* 8 warps config */
       static constexpr int NUM_STAGES = 4;
       static constexpr int THREAD_COUNT = 256;
       KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
     } else {
+      /* 1 warp config */
       static constexpr int THREAD_COUNT = 32;
       KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
     }
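
The change above hoists GROUP_SIZE ahead of the input checks so the last-dimension check can be expressed as a multiple of GROUP_SIZE * 2, and it only picks the 8-warp launch config when the per-row group count is also divisible by 8, falling back to the 1-warp config otherwise. A minimal Python sketch of that dispatch condition (not the vLLM source; select_config is a hypothetical helper) for reasoning about which path a given hidden size H takes:

GROUP_SIZE = 128  # fixed quantization group size assumed by the kernel

def select_config(H: int) -> str:
    """Which launch config the dispatch above picks for hidden size H."""
    # input.size(-1) is 2*H, so the TORCH_CHECK amounts to H % GROUP_SIZE == 0.
    assert (2 * H) % (GROUP_SIZE * 2) == 0, "2*H must be a multiple of 256"
    num_groups = H // GROUP_SIZE
    # The 8-warp / 256-thread config now also requires the group count to be a
    # multiple of 8; every other valid shape takes the 1-warp / 32-thread config.
    if H >= 4096 and num_groups % 8 == 0:
        return "8 warps, 256 threads, 4 stages"
    return "1 warp, 32 threads, 2 stages"

# H = 128 * 33 = 4224 is >= 4096 and a multiple of 128, but has 33 groups,
# so it now takes the 1-warp config rather than the 8-warp one.
print(select_config(128 * 33))  # -> 1 warp, 32 threads, 2 stages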

tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py

Lines changed: 4 additions & 1 deletion
@@ -25,6 +25,7 @@
     (8, 16, 128 * 2, fp8_dtype),
     (8, 16, 128 * 3, fp8_dtype),
     (8, 64, 7168, fp8_dtype),
+    (8, 128, 128 * 33, fp8_dtype),
     (8, 128, 7168, fp8_dtype),
     (8, 512, 7168, fp8_dtype),
     (8, 1024, 7168, fp8_dtype),
@@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
     )
 
     # Run the SiLU V2 kernel
+    # TODO (varun): use_e8m0 is set to false as the reference impl does
+    # not handle that case.
     y_q, y_s = persistent_masked_m_silu_mul_quant(
-        y, tokens_per_expert, group_size=group_size
+        y, tokens_per_expert, group_size=group_size, use_ue8m0=False
     )
 
     torch.cuda.synchronize()
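
The new (8, 128, 128 * 33, fp8_dtype) case gives H = 4224: a multiple of the 128-wide group size, but with 33 groups, so it exercises the fallback path gated by the NUM_GROUPS % 8 check above; use_ue8m0=False is forced because the test's reference implementation does not model UE8M0 scales. For orientation, a per-group reference of the operation under test could look like the sketch below (a sketch only, not the test's actual reference implementation; it assumes torch.float8_e4m3fn and ignores the tokens_per_expert masking the real kernel applies):

import torch

def ref_silu_mul_quant(y: torch.Tensor, group_size: int = 128):
    """Sketch: quantize silu(y[..., :H]) * y[..., H:] with one fp32 scale per group."""
    H = y.shape[-1] // 2
    x = torch.nn.functional.silu(y[..., :H].float()) * y[..., H:].float()
    groups = x.unflatten(-1, (H // group_size, group_size))     # (..., G, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max              # 448.0
    scales = groups.abs().amax(dim=-1, keepdim=True) / fp8_max  # one scale per group
    scales = scales.clamp(min=1e-12)                            # avoid divide-by-zero
    q = (groups / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.flatten(-2), scales.squeeze(-1)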

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 2 additions & 1 deletion
@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     num_parallel_tokens=16,
     group_size: int = 128,
+    use_ue8m0: bool | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
     y has shape (E, T, 2*H). The first half of the last dimension is
@@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
         device=y.device,
     )
 
-    use_ue8m0 = is_deep_gemm_e8m0_used()
+    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
 
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index
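
With the optional use_ue8m0 argument added above, callers can override the is_deep_gemm_e8m0_used() default, which is exactly what the updated test does. A hedged usage sketch follows (shapes and the bfloat16 dtype for y come from the docstring and the kernel's TORCH_CHECKs; the int32 dtype chosen for tokens_per_expert is an assumption here):

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    persistent_masked_m_silu_mul_quant,
)

E, T, H = 8, 128, 128 * 33
# y is (E, T, 2*H) bfloat16; the first half of the last dim goes through SiLU.
y = torch.randn(E, T, 2 * H, dtype=torch.bfloat16, device="cuda")
# Number of valid tokens per expert; int32 on the same device is assumed.
tokens_per_expert = torch.full((E,), T, dtype=torch.int32, device="cuda")

# use_ue8m0=False forces the non-UE8M0 scale path regardless of
# is_deep_gemm_e8m0_used(); leaving it as None keeps the previous behaviour.
y_q, y_s = persistent_masked_m_silu_mul_quant(
    y, tokens_per_expert, group_size=128, use_ue8m0=False
)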
