Conversation

@charlifu
Contributor

This PR adds a few fusion passes for AITER to fuse layernorm + fp8 block quant and silu + fp8 block quant.

Signed-off-by: charlifu <charlifu@amd.com>
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces new fusion passes for ROCm AITer, specifically for layernorm + fp8 block quant and silu + fp8 block quant. This is achieved by adding a new pattern AiterSiluMulFp8BlockQuantPattern and registering a new custom operator. Additionally, the changes in fp8_utils.py extend AITer support to non-MI300 series GPUs by providing a Triton-based fallback, which is a great enhancement.

My main feedback is on a performance concern in fp8_utils.py where an import is performed inside a performance-critical function. I've suggested a refactoring to move the import to the module level to avoid repeated overhead.

Comment on lines 64 to 69
```python
    # MI300's fp8nuz should be enough to detect if we call ck vs triton
    if current_platform.is_fp8_fnuz():
        from aiter import gemm_a8w8_blockscale
    else:
        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
```
Contributor

high

Importing inside a function that is on a hot path, like this custom op implementation, can introduce performance overhead. It's best practice to move imports to the module level to ensure they are only executed once.

I'd recommend defining a module-level variable that holds the correct gemm_a8w8_blockscale function based on the platform, and then using that variable within this function. This avoids repeated import lookups.

For example, you could add the following logic at the module level (e.g., near the top of the file):

```python
_gemm_a8w8_blockscale = None
if current_platform.is_rocm():
    try:
        # MI300's fp8nuz should be enough to detect if we call ck vs triton
        if current_platform.is_fp8_fnuz():
            from aiter import gemm_a8w8_blockscale
        else:
            from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
        _gemm_a8w8_blockscale = gemm_a8w8_blockscale
    except ImportError:
        # aiter is not installed, which is fine.
        # The error will be raised when the op is actually used.
        pass
```

And then this function's body can be simplified as suggested.

Suggested change
```diff
-    # MI300's fp8nuz should be enough to detect if we call ck vs triton
-    if current_platform.is_fp8_fnuz():
-        from aiter import gemm_a8w8_blockscale
-    else:
-        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
-    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    if _gemm_a8w8_blockscale is None:
+        raise ImportError(
+            "Aiter backend for gemm_a8w8_blockscale not available. "
+            "Please install aiter.")
+    return _gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
```
Collaborator

Let's do this dispatch outside, yeah.

@charlifu
Contributor Author

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
@ProExpertProg
Collaborator

I'm currently overhauling custom op matching in #24604. We also recently added a torch implementation of group quant; could you compare its performance with AITER's? Could you also compare the perf of the fused AITER kernel against the fused torch.compile kernel for rmsnorm+quant? Happy to help out with instructions, but overall:
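For the rmsnorm+quant comparison, a minimal harness along these lines would do (a sketch in plain PyTorch; the shapes, eps, and fp8 dtype are illustrative, and MI300 would use the fnuz variant):

```python
import torch
from torch.utils import benchmark

def rmsnorm_quant(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Reference semantics: RMSNorm followed by per-tensor fp8 quantization.
    rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    y = x * rms * weight
    scale = y.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
    return (y / scale).to(torch.float8_e4m3fn), scale

x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda")
w = torch.ones(8192, dtype=torch.bfloat16, device="cuda")
compiled = torch.compile(rmsnorm_quant)
compiled(x, w)  # warm-up: trigger compilation before timing

for label, fn in (("eager", rmsnorm_quant), ("torch.compile", compiled)):
    timer = benchmark.Timer(stmt="fn(x, w)",
                            globals={"fn": fn, "x": x, "w": w})
    print(label, timer.timeit(100))
```

The same Timer loop can wrap the AITER kernel call in place of `fn` to get the third data point.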

```diff
     SiluMulFp8StaticQuantPattern,
-    SiluMulNvfp4QuantPattern)
+    SiluMulNvfp4QuantPattern,
+    AiterSiluMulFp8BlockQuantPattern)
```
Collaborator

This symbol definition is conditional on is_rocm_aiter_linear_enabled(); any run will fail here if it is not enabled.

Contributor

Should be fixed now: cd059b9

```python
    return x_fp8, out_bs


direct_register_custom_op(
    op_name="rocm_aiter_act_mul_and_fp8_group_quant",
```
Collaborator

@tjtanaa Sep 28, 2025

Can you check whether the latest AITER allows you to skip direct_register_custom_op? I remember most ops should now work without calling direct_register_custom_op on the vLLM side, since the registration is already done in the AITER repository. Moreover, removing the direct_register_custom_op wrappers reduces CPU overhead, which can be costly.

Please take a look at the benchmarking results in ROCm#717 (the second and third cases), which show that removing direct_register_custom_op on the vLLM side improves perf.
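For context, the wrapper pattern being discussed looks roughly like this (a sketch: the AITER function signature, output dtypes, and group size are assumptions; only the op name and the aiter.ops.triton.fused_fp8_quant module are named in this thread):

```python
import torch
from vllm.utils import direct_register_custom_op

def _act_mul_fp8_group_quant_impl(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Forward to the AITER kernel; call signature assumed for illustration.
    from aiter.ops.triton.fused_fp8_quant import act_mul_and_fp8_group_quant
    return act_mul_and_fp8_group_quant(x)

def _act_mul_fp8_group_quant_fake(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake impl: gives torch.compile the output shapes/dtypes without running
    # the kernel. silu_and_mul halves the last dim; 128 is an assumed group size.
    d = x.shape[-1] // 2
    out = torch.empty(*x.shape[:-1], d, dtype=torch.float8_e4m3fnuz, device=x.device)
    scales = torch.empty(*x.shape[:-1], d // 128, dtype=torch.float32, device=x.device)
    return out, scales

# This per-op registration on the vLLM side is the overhead in question; if a
# future AITER release registers the op itself, vLLM could invoke it through
# torch.ops directly and drop this wrapper.
direct_register_custom_op(
    op_name="rocm_aiter_act_mul_and_fp8_group_quant",
    op_func=_act_mul_fp8_group_quant_impl,
    mutates_args=[],
    fake_impl=_act_mul_fp8_group_quant_fake,
)
```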

Contributor

Hey, thanks for the feedback. Is there a version of aiter that has aiter.ops.triton.fused_fp8_quant and also ships the op registration you mentioned? I wasn't able to figure out how to call act_mul_and_fp8_group_quant without calling direct_register_custom_op first. I'd be happy to investigate further if you can point me in the right direction; otherwise, we can always come back and remove these direct_register_custom_op calls later if needed.

Collaborator

We can come back to this in a later PR, as the 355_wip aiter commit does not have that feature.

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
@mergify

mergify bot commented Oct 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @charlifu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 7, 2025
Signed-off-by: charlifu <charlifu@amd.com>
@mergify mergify bot removed the needs-rebase label Oct 8, 2025
Collaborator

@ProExpertProg left a comment

Will take a look sometime next week, just placing a temp hold while #24604 gets merged

charlifu and others added 8 commits October 21, 2025 21:15
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Collaborator

@ProExpertProg left a comment

I think we should add the AITER ops to MatcherRMSNorm/MatcherQuantFP8/MatcherSiluMul/... instead of creating separate patterns for the AITER ops, so that we don't need to duplicate these for every pass (think allreduce-rms-quant fusion, all of the rope fusions, etc.)
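The idea, sketched with invented names (not vLLM's actual matcher API): the matcher owns the list of equivalent op spellings, and each fusion pass instantiates its pattern once per spelling, so a new backend is wired up in exactly one place.

```python
from collections.abc import Callable
from dataclasses import dataclass, field

@dataclass
class MatcherSiluMul:
    # Interchangeable implementations of silu_and_mul: the native/default op
    # plus backend-specific custom ops such as the AITER one.
    impls: list[Callable] = field(default_factory=list)

    def register(self, impl: Callable) -> None:
        self.impls.append(impl)

def build_silu_mul_quant_patterns(matcher: MatcherSiluMul) -> list[tuple[str, Callable]]:
    # A fusion pass asks the matcher for every spelling, so registering the
    # AITER op once also covers allreduce-rms-quant, the rope fusions, etc.
    return [("silu_mul+quant", impl) for impl in matcher.impls]

matcher = MatcherSiluMul()
matcher.register(lambda x: x)  # stand-in: native silu_and_mul
matcher.register(lambda x: x)  # stand-in: AITER silu_and_mul custom op
assert len(build_silu_mul_quant_patterns(matcher)) == 2
```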

Comment on lines 62 to 63
```python
def empty_bf16(self, *args, **kws):
    return torch.empty(*args, dtype=torch.bfloat16, device=self.device, **kws)
```
Collaborator

I think this one should just use empty, as that uses the model dtype.

Contributor Author

Done

```python
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
    from .matcher_utils import MatcherAiterFusedMulAdd

    class AiterMulAddFusionPattern:
```
Collaborator

Could you add a comment describing what kind of fusion is done in this pass?

Contributor Author

pass removed.

```python
    return

def pattern(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
    mul_add = self.fused_mul_add_matcher.forward_native(x, a, b)
```
Collaborator

This is abusing the matcher abstraction: it's meant to be a reusable matcher for a simple op, not to represent fused/unfused impls.

Contributor Author

pass removed.

```python
    ):
        return self.fused_mul_add_matcher.forward_custom(x, a, b)
    else:
        return self.fused_mul_add_matcher.forward_native(x, a, b)
```
Collaborator

Why is this conditional inside the replacement? The tracing will make this always pick a single branch depending on the graph inputs. I think it would be clearer to add an extra_check parameter to the register_replacement function
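Roughly what moving the condition into extra_check looks like with inductor's pattern matcher (a sketch; the mul+add pattern and the fused_mul_add_enabled flag are illustrative stand-ins for this PR's actual pattern and condition):

```python
import torch
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                             fwd_only, register_replacement)

matcher_pass = PatternMatcherPass()
fused_mul_add_enabled = True  # stand-in for the runtime condition

def pattern(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
    return x * a + b

def replacement(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
    # Unconditional: once the pattern matches, the fused op is always emitted.
    return torch.addcmul(b, x, a)

def extra_check(match: Match) -> bool:
    # The branch lives here, evaluated at match time, instead of being
    # traced (and frozen) into the replacement graph.
    return fused_mul_add_enabled

register_replacement(pattern, replacement,
                     [torch.empty(8), torch.empty(8), torch.empty(8)],
                     fwd_only, matcher_pass, extra_check=extra_check)
```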

Contributor Author

pass removed.


```python
    return SiluAndMul.forward_native(x)


if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
```
Collaborator

All of this can live with the fusion pass

Contributor Author

pass removed

```python
if self.pass_config.enable_fusion:
    self.passes += [RMSNormQuantFusionPass(config)]
    self.passes += [ActivationQuantFusionPass(config)]
    self.passes += [MulAddFusionPass(config)]
```
Collaborator

New flag and new file, please.

Contributor Author

pass removed.

```python
)


class MulAddFusionPass(VllmPatternMatcherPass):
```
Collaborator

New file, please.

Contributor Author

pass removed

Comment on lines 141 to 143
```python
at1 = auto_functionalized(
    SILU_MUL_OP, result=result_silu_mul, input=input
)
```
Collaborator

Please use the MatcherSiluMul

@charlifu
Contributor Author

charlifu commented Nov 6, 2025

We found that mul+add fusion is not helping performance, so we are removing this pass.

DeepSeek-R1:

| | no fusion | no mul+add | all fusion |
| --- | --- | --- | --- |
| bs1, in 64, out 512 | 7.8252 | 7.4551 | 7.5100 |
| bs4, in 64, out 512 | 8.3053 | 8.037 | 8.19201 |
| bs8, in 64, out 512 | 8.156 | 7.9659 | 8.2399 |
charlifu and others added 2 commits November 6, 2025 16:47
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
@mergify

mergify bot commented Nov 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @charlifu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 6, 2025
Signed-off-by: charlifu <charlifu@amd.com>
@mergify mergify bot removed the needs-rebase label Nov 7, 2025
Signed-off-by: charlifu <charlifu@amd.com>
@mergify

mergify bot commented Nov 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @charlifu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 10, 2025
Signed-off-by: charlifu <charlifu@amd.com>

Labels

needs-rebase, rocm (Related to AMD ROCm)

6 participants