[Rocm][fused_moe][fp4] view weight to torch.float4_e2m1fn_x2 when running aiter fused moe for fp4 model #27474
Conversation
Code Review
This pull request correctly addresses a dtype mismatch issue for the aiter fused MoE kernel on ROCm when using FP4 models. The change involves viewing the uint8 weight tensor as torch.float4_e2m1fn_x2 before passing it to the kernel. My review identifies a performance improvement opportunity by moving the hasattr check out of the apply method's hot path and into the class initializer. This would prevent redundant checks on every forward pass.
if hasattr(torch, "float4_e2m1fn_x2"):
    w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
    w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
else:
    w13_weight = layer.w13_weight
    w2_weight = layer.w2_weight
The hasattr(torch, "float4_e2m1fn_x2") check is performed on every forward pass within the apply method. This is inefficient as the result of this check is constant for a given PyTorch environment. This check should be performed only once during initialization to avoid repeated overhead in the critical path of inference.
I recommend moving this check to the __init__ method of the QuarkOCP_MX_MoEMethod class and caching the result in an instance variable. For example:
In __init__:
self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
Then, here in apply, you can simplify the logic:
w13_weight = layer.w13_weight
w2_weight = layer.w2_weight
if self.fp4_dtype:
    w13_weight = w13_weight.view(self.fp4_dtype)
    w2_weight = w2_weight.view(self.fp4_dtype)
This would be more performant and also makes the code in apply cleaner by removing the if/else block. Since modifying __init__ is outside the current diff, I am not providing a direct code suggestion, but this change is highly recommended for performance.
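Putting the two fragments together, a minimal self-contained sketch of the suggested caching pattern (the class body is heavily simplified here and is not the actual vLLM implementation):

```python
import torch


class QuarkOCP_MX_MoEMethod:  # simplified stand-in, not the real vLLM class body
    def __init__(self):
        # Resolve the packed-FP4 dtype once; None on PyTorch builds without it.
        self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

    def apply(self, layer):
        w13_weight = layer.w13_weight
        w2_weight = layer.w2_weight
        if self.fp4_dtype is not None:
            # Reinterpret the packed uint8 storage as FP4x2 without copying.
            w13_weight = w13_weight.view(self.fp4_dtype)
            w2_weight = w2_weight.view(self.fp4_dtype)
        # ... hand w13_weight / w2_weight to the aiter fused MoE kernel ...
```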
Per the AI suggestion, this view can be moved to process_weights_after_loading to avoid being invoked on every forward pass.
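A rough sketch of that alternative, assuming vLLM's usual process_weights_after_loading(self, layer) hook; the class here is a simplified stand-in rather than the PR's actual code:

```python
import torch


class QuarkOCP_MX_MoEMethod:  # simplified stand-in, sketch only
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Resolve the packed-FP4 dtype once, right after weight loading.
        fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
        if fp4_dtype is not None:
            # View the packed uint8 weights as FP4x2 a single time here, so
            # apply() no longer repeats the view on every forward pass.
            layer.w13_weight = torch.nn.Parameter(
                layer.w13_weight.data.view(fp4_dtype), requires_grad=False
            )
            layer.w2_weight = torch.nn.Parameter(
                layer.w2_weight.data.view(fp4_dtype), requires_grad=False
            )
```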
Hi, @maleksan85 @HaiShaw Could you help review the code changes here? Thank you.
BowenBao left a comment
Looks good with suggestions.
Hi, @BowenBao
Hi, @HaiShaw Could you help review this PR? It fixes the FP4 fused MoE functionality issue. Thank you!
Just hold on merging; I believe @zejunchen-zejun has not yet pushed the changes.
1c21289 to 5ed92fd (force-pushed): view weight to torch.float4_e2m1fn_x2 when running aiter fused moe for fp4 model. Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
5ed92fd to 15df282 (force-pushed)
Thanks @zejunchen-zejun, LGTM
Hi, @HaiShaw @SageMoore @BowenBao @gshtras
Hi, @HaiShaw @SageMoore @BowenBao @gshtras @LucasWilkinson

This PR views the MoE weights as torch.float4_e2m1fn_x2 for the aiter FP4 fused MoE kernel.

Without this PR, aiter cannot find a suitable kernel for the weights, whose dtype is uint8.
With this PR, the DeepSeek FP4 model runs successfully.
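For reference, a minimal standalone sketch of the reinterpretation this PR relies on, assuming a PyTorch build that exposes torch.float4_e2m1fn_x2 (the tensor shape here is arbitrary):

```python
import torch

# Hypothetical packed-FP4 weight as loaded from an FP4 checkpoint:
# two e2m1 values per byte, stored as uint8.
packed = torch.zeros((128, 64), dtype=torch.uint8)

if hasattr(torch, "float4_e2m1fn_x2"):
    # view() reinterprets the same 1-byte-per-element storage without copying,
    # so kernel dispatch sees an FP4 tensor instead of a uint8 one.
    fp4_view = packed.view(torch.float4_e2m1fn_x2)
    print(fp4_view.dtype, fp4_view.shape)
else:
    print("This PyTorch build does not expose torch.float4_e2m1fn_x2")
```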