Commit 04e4c73

fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 5b96d01 commit 04e4c73

File tree

3 files changed, +27 -10 lines changed

vllm/model_executor/layers/fused_moe/all2all_utils.py

Lines changed: 3 additions & 1 deletion

@@ -150,7 +150,9 @@ def maybe_make_prepare_finalize(
             hidden_dim_scale=hidden_dim_scale,
             in_dtype=in_dtype,
             out_dtype=in_dtype,
-            scale_dtype=torch.float32,
+            scale_dtype=torch.float32
+            if quant_config.quant_dtype is not None
+            else None,
             max_private_tokens=None,  # For tuning
         )
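
The effect of the change above, restated as a minimal standalone sketch (pick_scale_dtype is a hypothetical helper name, not part of vLLM): float32 scales are only requested when an activation quant dtype is configured, otherwise scale_dtype stays None, presumably so that no scale buffers are requested.

import torch

def pick_scale_dtype(quant_dtype: torch.dtype | str | None) -> torch.dtype | None:
    # Scales are only meaningful when activations are actually quantized.
    # (torch.float8_e4m3fn below requires a PyTorch build with float8 dtypes.)
    return torch.float32 if quant_dtype is not None else None

assert pick_scale_dtype(torch.float8_e4m3fn) is torch.float32
assert pick_scale_dtype(None) is None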

vllm/model_executor/layers/fused_moe/pplx_garden_prepare_finalize.py

Lines changed: 16 additions & 9 deletions

@@ -25,7 +25,7 @@ def pplx_garden_hidden_dim_scale(
     quant_dtype: torch.dtype | str | None,
     per_act_token_quant: bool,
     block_shape: list[int] | None,
-) -> int:
+) -> int | None:
     # For blocked per token: set to
     # ceil_div(hidden_dim, block_size) * sizeof(float32)
     # For per-token: set to 4 * sizeof(float32) (x4 for alignment)

@@ -37,16 +37,16 @@ def pplx_garden_hidden_dim_scale(
         if per_act_token_quant:
             # per-token (M x 1)
             assert block_shape is None
-            hidden_dim_scale = 1
+            hidden_dim_scale = 16
         elif block_shape is not None:
             # per-group (M x K_tiles)
             block_size = block_shape[1]
             hidden_dim_scale = cdiv(hidden_dim, block_size)
         else:
             # per-tensor (1 x 1)
-            hidden_dim_scale = 1
+            hidden_dim_scale = 16
     else:
-        hidden_dim_scale = 0
+        hidden_dim_scale = None  # 1?

     return hidden_dim_scale
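
Taken together, the two hunks above change the scale-width computation: per-token and per-tensor quantization now reserve 16 scale columns instead of 1 (apparently matching the 16-element alignment applied to expert_x_scale below), blocked quantization keeps ceil(hidden_dim / block_size) columns, and the unquantized case returns None instead of 0. A standalone restatement of the post-commit logic (the function name and the local cdiv are stand-ins for illustration, not the vLLM module itself):

import torch

def cdiv(a: int, b: int) -> int:
    # Ceiling division.
    return -(-a // b)

def hidden_dim_scale_width(
    hidden_dim: int,
    quant_dtype: torch.dtype | str | None,
    per_act_token_quant: bool,
    block_shape: list[int] | None,
) -> int | None:
    if quant_dtype is None:
        # Unquantized: no scale columns at all (previously 0).
        return None
    if per_act_token_quant:
        # Per-token scales (M x 1), padded to 16 columns for alignment.
        assert block_shape is None
        return 16
    if block_shape is not None:
        # Blocked per-group scales (M x K_tiles).
        return cdiv(hidden_dim, block_shape[1])
    # Per-tensor scale (1 x 1), also padded to 16 columns.
    return 16

assert hidden_dim_scale_width(7168, torch.float8_e4m3fn, False, [128, 128]) == 56
assert hidden_dim_scale_width(7168, None, False, None) is None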

@@ -190,7 +190,7 @@ def prepare_async(
         expert_x_scale_shape = (
             self.num_local_experts,
             expert_x.size(1),
-            round_up(final_dim, 4),  # round up for alignment
+            round_up(final_dim, 16),  # round up for alignment
         )

         expert_x_scale = torch.empty(

@@ -203,7 +203,11 @@ def prepare_async(
         # There's not much point setting this unless it is != indices.size(0)
         bound_m: torch.Tensor | None = None

-        logger.debug("PPLX_GARDEN dispatch send %s", expert_x.shape)
+        logger.debug(
+            "PPLX_GARDEN dispatch send %s, %s",
+            expert_x.shape,
+            expert_x_scale.shape if expert_x_scale is not None else None,
+        )

         self.a2a.dispatch(
             out_expert_num_tokens=expert_num_tokens,
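
The round_up(final_dim, 16) change above pads the last dimension of expert_x_scale to a multiple of 16 elements instead of 4. A local stand-in for the round_up helper used in the diff, with a few sample values:

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(1, 16) == 16    # per-token / per-tensor scale width
assert round_up(56, 16) == 64   # e.g. blocked scales with 56 K-tiles
assert round_up(64, 16) == 64   # already aligned
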
@@ -269,7 +273,8 @@ def _receiver(
             "PPLX_GARDEN receive X_SCALE %s",
             expert_x_scale.shape if expert_x_scale is not None else None,
         )
-        logger.debug("PPLX_GARDEN receive META %s", expert_tokens_meta)
+        logger.debug("PPLX_GARDEN receive num_tokens %s", expert_num_tokens.shape)
+        # logger.debug("PPLX_GARDEN receive META %s", expert_tokens_meta)

         return expert_x, expert_x_scale, expert_tokens_meta, None, None

@@ -332,11 +337,13 @@ def finalize_async(

         logger.debug("PPLX_GARDEN combine send")

+        hidden_dim = output.size(1)
+
         self.a2a.combine(
             out_tokens=output,
             indices=topk_ids_u32,
             weights=topk_weights,
-            expert_y=fused_expert_output,
+            expert_y=fused_expert_output.view(-1, hidden_dim),
             bound_m=bound_m,
             do_send=True,
             do_recv=False,

@@ -349,7 +356,7 @@ def finalize_async(
             out_tokens=output,
             indices=topk_ids_u32,
             weights=topk_weights,
-            expert_y=fused_expert_output,
+            expert_y=fused_expert_output.view(-1, hidden_dim),
             bound_m=bound_m,
             do_send=False,
             do_recv=True,
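
The new hidden_dim = output.size(1) plus the two .view(-1, hidden_dim) calls collapse the batched expert output into a 2-D (tokens, hidden_dim) tensor before it is handed to combine(). A small sketch of the reshape, where the 3-D (num_local_experts, max_tokens, hidden_dim) layout is an assumption for illustration:

import torch

num_local_experts, max_tokens, hidden_dim = 4, 8, 128
fused_expert_output = torch.randn(num_local_experts, max_tokens, hidden_dim)

# Flatten all leading dimensions into a single token axis; hidden_dim is kept.
expert_y = fused_expert_output.view(-1, hidden_dim)
assert expert_y.shape == (num_local_experts * max_tokens, hidden_dim)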

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 8 additions & 0 deletions

@@ -1040,6 +1040,9 @@ def select_gemm_impl(
             BatchedTritonOrDeepGemmExperts,
             TritonOrDeepGemmExperts,
         )
+        # from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+        #     NaiveBatchedExperts,
+        # )

         assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
             "Marlin and ROCm AITER are not supported with all2all yet."

@@ -1061,6 +1064,11 @@ def select_gemm_impl(
             self.weight_block_size,
             False,
         )
+        # return NaiveBatchedExperts(
+        #     max_num_tokens=max_num_tokens_per_rank,
+        #     num_dispatchers=prepare_finalize.num_dispatchers(),
+        #     quant_config=self.moe_quant_config,
+        # )
         return BatchedTritonOrDeepGemmExperts(
             max_num_tokens=max_num_tokens_per_rank,
             num_dispatchers=prepare_finalize.num_dispatchers(),
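
The commented-out import and return appear to keep NaiveBatchedExperts available as a quick fallback for BatchedTritonOrDeepGemmExperts, presumably for debugging. A hedged sketch of how such a toggle could be expressed without editing the source; the environment variable and the factory callables here are hypothetical, not existing vLLM options:

import os
from typing import Callable, TypeVar

T = TypeVar("T")

def select_experts_impl(make_naive: Callable[[], T], make_fused: Callable[[], T]) -> T:
    # Opt into the slow-but-simple reference implementation only when the
    # (hypothetical) debug flag is set; otherwise use the fused kernels.
    # make_naive / make_fused would be zero-arg factories closing over
    # max_num_tokens_per_rank, num_dispatchers, and the quant config.
    if os.environ.get("VLLM_USE_NAIVE_BATCHED_EXPERTS") == "1":
        return make_naive()
    return make_fused()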
