
Conversation

@zejunchen-zejun
Contributor

@zejunchen-zejun zejunchen-zejun commented Oct 24, 2025

View the MoE weights as torch.float4_e2m1fn_x2 for the aiter FP4 fused MoE kernel.
Without this PR, aiter cannot find a suitable kernel for the weights, whose dtype is uint8.

With this PR, DeepSeek FP4 runs successfully.
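
For reference, a minimal standalone sketch of the change described above (the tensor here is a stand-in; the actual code applies the view to layer.w13_weight and layer.w2_weight, as shown in the review below):

import torch

# Stand-in for a packed FP4 MoE weight as loaded from the checkpoint:
# two e2m1 values per byte, stored as uint8.
packed_weight = torch.randint(0, 256, (16, 32), dtype=torch.uint8)

if hasattr(torch, "float4_e2m1fn_x2"):
    # Zero-copy reinterpretation: only the dtype that aiter sees changes,
    # not the underlying bytes, so the FP4 fused MoE kernel can be selected.
    weight_for_kernel = packed_weight.view(torch.float4_e2m1fn_x2)
else:
    # Older PyTorch builds without the packed FP4 dtype: keep the raw uint8 tensor.
    weight_for_kernel = packed_weight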

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a dtype mismatch issue for the aiter fused MoE kernel on ROCm when using FP4 models. The change involves viewing the uint8 weight tensor as torch.float4_e2m1fn_x2 before passing it to the kernel. My review identifies a performance improvement opportunity by moving the hasattr check out of the apply method's hot path and into the class initializer. This would prevent redundant checks on every forward pass.

Comment on lines 662 to 667
if hasattr(torch, "float4_e2m1fn_x2"):
    w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
    w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
else:
    w13_weight = layer.w13_weight
    w2_weight = layer.w2_weight

Severity: high

The hasattr(torch, "float4_e2m1fn_x2") check is performed on every forward pass within the apply method. This is inefficient as the result of this check is constant for a given PyTorch environment. This check should be performed only once during initialization to avoid repeated overhead in the critical path of inference.

I recommend moving this check to the __init__ method of the QuarkOCP_MX_MoEMethod class and caching the result in an instance variable. For example:

In __init__:

self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

Then, here in apply, you can simplify the logic:

w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
if self.fp4_dtype:
    w13_weight = w13_weight.view(self.fp4_dtype)
    w2_weight = w2_weight.view(self.fp4_dtype)

This would be more performant and also makes the code in apply cleaner by removing the if/else block. Since modifying __init__ is outside the current diff, I am not providing a direct code suggestion, but this change is highly recommended for performance.
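
Pulled together, a self-contained sketch of the caching pattern being suggested (class and attribute names follow the discussion; the rest, including the return at the end, is illustrative rather than the actual vLLM implementation):

import torch


class QuarkOCP_MX_MoEMethod:
    def __init__(self):
        # Resolve the packed FP4 dtype once; None on PyTorch builds that lack it.
        self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

    def apply(self, layer):
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight
        if self.fp4_dtype is not None:
            # Zero-copy view of the packed uint8 weights as FP4.
            w13_weight = w13_weight.view(self.fp4_dtype)
            w2_weight = w2_weight.view(self.fp4_dtype)
        # ... here the weights would be handed to the aiter fused MoE kernel ...
        return w13_weight, w2_weight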


@zejunchen-zejun
Contributor Author

Hi, @maleksan85 @HaiShaw

Could you help review the code changes here?

Thank you.

Contributor

@BowenBao BowenBao left a comment


Looks good with suggestions.

Comment on lines 662 to 667
if hasattr(torch, "float4_e2m1fn_x2"):
    w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
    w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
else:
    w13_weight = layer.w13_weight
    w2_weight = layer.w2_weight

Per the AI suggestion, this view can be moved to process_weights_after_loading to avoid being invoked on every forward pass.
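
A minimal sketch of what that could look like (assuming a vLLM-style process_weights_after_loading(layer) hook on the MoE method; the Parameter re-wrapping is illustrative and not necessarily the exact code that was pushed):

import torch


def process_weights_after_loading(layer: torch.nn.Module) -> None:
    # Perform the FP4 view once, right after the checkpoint weights are loaded,
    # so apply() no longer branches on the dtype every forward pass.
    fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
    if fp4_dtype is None:
        return  # PyTorch build without the packed FP4 dtype: keep uint8 weights
    layer.w13_weight = torch.nn.Parameter(
        layer.w13_weight.data.view(fp4_dtype), requires_grad=False
    )
    layer.w2_weight = torch.nn.Parameter(
        layer.w2_weight.data.view(fp4_dtype), requires_grad=False
    )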

@zejunchen-zejun
Contributor Author

Hi, @BowenBao
Thank you for the review. Let me modify the code here.

@zejunchen-zejun
Contributor Author

Hi, @HaiShaw

Could you help review this PR? It fixes the FP4 fused MoE functionality issue.

Thank you!

@gshtras gshtras added the ready label (ONLY add when PR is ready to merge/full CI is needed) Oct 30, 2025
@BowenBao
Contributor

Just hold off on merging; I believe @zejunchen-zejun has yet to push the changes.

@zejunchen-zejun zejunchen-zejun force-pushed the zejun/fix_fp4_fused_moe_func_issue_for_rocm branch 3 times, most recently from 1c21289 to 5ed92fd Compare November 3, 2025 06:08
@zejunchen-zejun
Contributor Author

Hi, @BowenBao @gshtras
Thank you for the review. We have updated the code according to the review comments. Could you take another look?
With this PR, the DeepSeek FP4 functionality works correctly.
Thank you

when running aiter fused moe for fp4 model
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
@zejunchen-zejun zejunchen-zejun force-pushed the zejun/fix_fp4_fused_moe_func_issue_for_rocm branch from 5ed92fd to 15df282 Compare November 3, 2025 07:47
@BowenBao
Contributor

BowenBao commented Nov 3, 2025

Thanks @zejunchen-zejun , LGTM

@zejunchen-zejun
Contributor Author

Hi, @HaiShaw @SageMoore @BowenBao @gshtras
Could you help merge this PR? Thank you!

@zejunchen-zejun
Contributor Author

zejunchen-zejun commented Nov 10, 2025

Hi, @HaiShaw @SageMoore @BowenBao @gshtras @LucasWilkinson
Could you help merge this PR? Thank you!

@gshtras gshtras merged commit b06b947 into vllm-project:main Nov 10, 2025
52 checks passed

Labels

ready: ONLY add when PR is ready to merge/full CI is needed
rocm: Related to AMD ROCm

4 participants