[NVIDIA] Add SM100 Flashinfer MoE per tensor scale fp8 backend #21458
Conversation
| 👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀 |
Code Review
This pull request introduces a new FlashInfer backend for per-tensor scaled FP8 Mixture of Experts (MoE), which shows promising performance improvements on SM100 architectures. The changes include adding a new custom operator, refactoring some utility functions into a shared module, and updating the quantization layers to use this new backend.
The code is generally well-structured, and the refactoring of utility functions into flashinfer_utils.py is a good step towards better code organization.
However, there are a couple of areas that could be improved for better maintainability and potentially better performance:
- There is significant code duplication in the logic that invokes the new MoE kernel from both Fp8MoEMethod and ModelOptFp8MoEMethod. This should be refactored into a shared helper function.
- The tile_tokens_dim parameter for the new kernel is hardcoded, which might not be optimal for all workloads and differs from the dynamic approach used in the existing block-scale kernel (see the sketch after this list).

Addressing these points will enhance the quality and robustness of the new backend.
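For illustration only, a hedged sketch of what a dynamic choice could look like; the helper name and the [8, 64] clamp range here are assumptions, not the actual block-scale kernel code:

```python
# Hypothetical sketch: derive tile_tokens_dim from the workload instead of hardcoding it.
# Size the tile to the expected tokens per expert, rounded up to a power of two and
# clamped to an assumed [8, 64] range.
def suggested_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    tokens_per_expert = max(1, (num_tokens * top_k) // num_experts)
    tile = 1 << (tokens_per_expert - 1).bit_length()  # next power of two
    return min(max(tile, 8), 64)
```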
There appears to be significant code duplication here. The logic inside this if self.flashinfer_moe_enabled: block is nearly identical to the logic in vllm/model_executor/layers/quantization/fp8.py (lines 993-1016).
Duplicating this code block makes future maintenance harder, as changes would need to be applied in two places.
To improve maintainability, I suggest refactoring this shared logic into a common helper function. This function could be placed in a utility module, perhaps vllm/model_executor/layers/quantization/utils/flashinfer_utils.py, and called from both Fp8MoEMethod.apply and ModelOptFp8MoEMethod.apply.
For example, you could create a helper like this:
```python
# In a shared utility file
def apply_flashinfer_per_tensor_scale_fp8(
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    e_score_correction_bias: Optional[torch.Tensor],
    top_k: int,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
    global_num_experts: int,
    apply_router_weight_on_input: bool,
) -> torch.Tensor:
    return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
        routing_logits=router_logits,
        routing_bias=e_score_correction_bias,
        hidden_states=x,
        input_scale=layer.w13_input_scale,
        gemm1_weights=layer.w13_weight,
        gemm1_weights_scale=layer.w13_weight_scale,
        gemm2_weights=layer.w2_weight,
        gemm2_weights_scale=layer.w2_weight_scale,
        activation_scale=layer.w2_input_scale,
        num_experts=global_num_experts,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        use_routing_scales_on_input=apply_router_weight_on_input,
    )
```

This would centralize the logic and make the code cleaner and easier to maintain.
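For context, a hedged sketch of how both call sites might then use this helper; the argument names follow the sketch above and may not match the actual apply() signatures exactly:

```python
# Hypothetical call site shared by Fp8MoEMethod.apply and ModelOptFp8MoEMethod.apply;
# self.flashinfer_moe_enabled refers to the flag mentioned in the review comment above.
if self.flashinfer_moe_enabled:
    return apply_flashinfer_per_tensor_scale_fp8(
        layer=layer,
        x=x,
        router_logits=router_logits,
        e_score_correction_bias=e_score_correction_bias,
        top_k=top_k,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        global_num_experts=global_num_experts,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
```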
I would like this utility to be implemented to help avoid drift between the two code paths.
I am a little worried about this line breaking CUDA graph capture because we are creating a new tensor on the fly. Should we create this zero bias in the caller instead? Or maybe ask FlashInfer to support routing_bias=None so that we don't need to pass in a fake bias.
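For illustration, one hedged way to avoid the per-call allocation would be to create the zero bias once, e.g. during weight processing; the attribute name zero_routing_bias, the helper names, and the dtype are assumptions:

```python
import torch

# Hypothetical sketch: allocate the fake routing bias once, outside the captured region,
# so no new tensor is created during CUDA graph capture.
def prepare_zero_routing_bias(layer: torch.nn.Module, global_num_experts: int,
                              device: torch.device) -> None:
    layer.zero_routing_bias = torch.zeros(
        global_num_experts, dtype=torch.bfloat16, device=device)

# At call time, reuse the pre-allocated tensor instead of building one on the fly.
def pick_routing_bias(layer: torch.nn.Module, e_score_correction_bias):
    return (e_score_correction_bias if e_score_correction_bias is not None
            else layer.zero_routing_bias)
```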
not a blocking issue for now. We will fix this later if we really see it becoming an issue.
Fair point, I think asking FlashInfer to support routing_bias=None is probably better.
FlashInfer has fixed this in 0.2.9rc2. Do you think this is a blocker? If not, I prefer that we merge this PR first and then file another PR after we have upgraded to FlashInfer v0.2.9rc2.
However, if you think this is a blocker, we can wait for the FlashInfer v0.2.9rc2 upgrade, which should happen very soon.
Should we use RoutingMethodType.Llama4 instead of a hard-coded "3"?
not a blocking issue, just code style
It is better, but the issue is that if a different version of flashinfer is installed (or flashinfer isn't installed at all) we'll get an import error. I thought about doing this conversion inside the function, after we know that the correct version of flashinfer is installed, wdyt?
or we can define our class to mimic FlashInfer's class?
```python
from enum import IntEnum

has_flashinfer = False
try:
    import flashinfer
    from flashinfer.fused_moe import RoutingMethodType
    has_flashinfer = True
except ImportError:
    pass

class FlashInferRoutingMethodType(IntEnum):
    # Default: Softmax -> TopK
    Default = RoutingMethodType.Default if has_flashinfer else 0
    # Renormalize: TopK -> Softmax
    Renormalize = RoutingMethodType.Renormalize if has_flashinfer else 1
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups
    DeepSeekV3 = RoutingMethodType.DeepSeekV3 if has_flashinfer else 2
    # Llama4: Top1 -> Sigmoid
    Llama4 = RoutingMethodType.Llama4 if has_flashinfer else 3
    # Qwen3: Softmax -> TopK -> Renormalize
    RenormalizeNaive = RoutingMethodType.RenormalizeNaive if has_flashinfer else 4
    # Unspecified
    Unspecified = RoutingMethodType.Unspecified if has_flashinfer else 5
```
This is not critical and can be handled in later PRs
I think we can make this class lazily imported and remove the default arg, so we only need to import it once in the function.
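A hedged sketch of that lazy-import approach, reusing the import path from the snippet above; the helper name and the integer fallbacks are assumptions:

```python
def get_flashinfer_routing_method_type(name: str) -> int:
    # Import RoutingMethodType only when the FlashInfer path is actually taken,
    # so that vLLM does not require flashinfer at module-import time.
    try:
        from flashinfer.fused_moe import RoutingMethodType
        return int(getattr(RoutingMethodType, name))
    except ImportError:
        # Fallback integer codes mirroring the enum ordering shown above.
        return {"Default": 0, "Renormalize": 1, "DeepSeekV3": 2,
                "Llama4": 3, "RenormalizeNaive": 4, "Unspecified": 5}[name]
```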
| Depends on #21485 |
Force-pushed from c3e365c to 872160e. Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Regarding num_expert_group = num_expert_group if num_expert_group is not None else 1: this should fall back to 0, not 1, when num_expert_group is None.
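A minimal sketch of the suggested change, keeping the quoted variable name:

```python
# Fall back to 0 (no expert grouping) rather than 1 when num_expert_group is None.
num_expert_group = num_expert_group if num_expert_group is not None else 0
```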
Force-pushed from 872160e to fdf635b.

```python
    local_num_experts: int,
    use_routing_scales_on_input: bool,
    routed_scaling_factor: float = 1.0,
    routing_method_type: int = 3  # Llama4-styled routing method
```
I'm not a fan of defaulting this parameter if it is going to dictate model support. For instance, in the current usage of this function the parameter isn't set, but there is no check that the model actually needs Llama 4 routing, i.e. it would be silently incorrect for a Mixtral with the same quant.
@amirkl94 Maybe let's remove the default value for routing_method_type and make this arg a required argument?
And from llama4.py we should pass this into fused_moe.py?
llama4 already does this by defining its own custom routing function and passing that into FusedMoE
vllm/vllm/model_executor/models/llama4.py, line 79 in e18f085:

```python
custom_routing_function=Llama4MoE.custom_routing_function,
```
I suppose you could just check if custom_routing_function == Llama4MoE.custom_routing_function.
I think I can't check custom_routing_function == Llama4MoE.custom_routing_function, unless you meant in llama4.py?
Should I just make this parameter optional and pass it only from llama4 and if it's not passed I'll default to the non-flashinfer implementation?
I think what @mgoin meant is that in modelopt.py: https://github.com/vllm-project/vllm/blob/185bdd608d24418fae365238b9eb500f8c778241/vllm/model_executor/layers/quantization/modelopt.py#L458
the layer object is just an instance of FusedMoE, so you can dispatch routing_method using:
```python
if layer.routing_method == Llama4MoE.custom_routing_function:
    routing_method = 3
```
@mgoin is this what you meant?
Yes this was what I meant. Obviously not optimal, but should be okay
@mgoin Currently, FlashInfer's per-tensor FP8 MoE only supports Llama4 routing, so I told @amirkl94 to assert that layer.routing_method == Llama4MoE.custom_routing_function holds; if it does not, an exception is raised.
This is done so that in the future, if anyone wants to use FlashInfer per-tensor FP8 MoE for another model, it will fail loudly and tell the user why that is not supported. My philosophy is: a loud failure is better than silent corruption.
Could you check if the current implementation is acceptable to you? Thanks!
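A hedged sketch of the kind of guard being described; the attribute name follows the thread above and may differ from the actual FusedMoE layer attribute:

```python
# Hypothetical guard: fail loudly when a non-Llama4 routing method reaches the
# FlashInfer per-tensor FP8 MoE path, which currently supports only Llama4 routing.
if layer.routing_method != Llama4MoE.custom_routing_function:
    raise NotImplementedError(
        "FlashInfer per-tensor scale FP8 MoE currently supports only "
        "Llama4-style routing (RoutingMethodType.Llama4).")
```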
Force-pushed from 185bdd6 to 6582abc.
| The pipeline failure doesn't seem to be caused by this PR. |
Signed-off-by: mgoin <mgoin64@gmail.com>
| I fixed some issues with the PR and validated accuracy + performance. I see about a 10% throughput improvement on gsm8k on 1xB200. Will do a final review now. |
| The failure is: Doesn't seem to be related to this PR |
| @mgoin The CI errors seem to be unrelated to my PR, as I saw they're happening on other branches as well: https://github.com/vllm-project/vllm/pull/21747/commits. |
| Yes, this is what I've found too. I've requested force merge, thank you. |
Purpose
This PR introduces a new backend for per-tensor scaled FP8 MoE from FlashInfer. This backend gives a performance improvement as described below.
Accuracy tests
Ran manual lm_eval on gsm8k, using the following command:

Results:
Perf tests
Tested on a 1xB200 GPU, using the latency benchmark:
Results: