@@ -144,12 +144,13 @@ def prepare(
144 144              "apply_router_weight_on_input is only implemented for topk=1")
145 145          a1 = a1 * topk_weights.to(a1.dtype)
146 146
147     -        if quant_config.per_act_token_quant:
    147 +        if quant_config.is_block_quantized:
    148 +            # Quant and Dispatch
148 149              a1q, a1q_scale = moe_kernel_quantize_input(
149 150                  a1,
150 151                  a1_scale,
151 152                  quant_dtype=quant_config.quant_dtype,
152     -                per_act_token_quant=True,
    153 +                per_act_token_quant=quant_config.per_act_token_quant,
153 154                  block_shape=quant_config.block_shape,
154 155              )
155 156              if a1q_scale is not None and a1q_scale.numel() == 1:
@@ -162,16 +163,18 @@ def prepare(
162 163                  rank_topk_weights=topk_weights,
163 164                  num_experts=num_experts)
164 165          else:
165     -            # DeepEP kernels only support dispatching per-token-quant
166     -            # quantization. dispatch in bfloat16.
    166 +            # Dispatch and Quant
    167 +            # DeepEP kernels only support dispatching block-quantized
    168 +            # activation scales.
    169 +            # Dispatch in bfloat16
167 170              (expert_x, _, expert_tokens_meta, expert_topk_ids,
168 171               expert_topk_weights) = self._do_dispatch(
169 172                  tokens=a1,
170 173                  token_scales=None,
171 174                  rank_topk_ids=topk_ids,
172 175                  rank_topk_weights=topk_weights,
173 176                  num_experts=num_experts)
174     -            # quantize now
    177 +            # Quantize after dispatch.
175 178              expert_x_scale = None
176 179              if expert_x.numel() != 0:
177 180                  expert_x, expert_x_scale = moe_kernel_quantize_input(
0 commit comments