Skip to content

Commit de2b783

Browse files
authored
[ROCm] Add env to enable/disable aiter triton gemm (vllm-project#28321)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent e5e9067 commit de2b783

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

vllm/envs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
114114
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
115115
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
116+
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True
116117
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
117118
VLLM_ROCM_FP8_PADDING: bool = True
118119
VLLM_ROCM_MOE_PADDING: bool = True
@@ -944,6 +945,11 @@ def get_vllm_port() -> int | None:
944945
os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
945946
in ("true", "1")
946947
),
948+
# Whether to use aiter triton kernels for gemm ops.
949+
# By default is enabled.
950+
"VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: (
951+
os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1")
952+
),
947953
# use rocm skinny gemms
948954
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
949955
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
@@ -1586,6 +1592,7 @@ def compute_hash() -> str:
15861592
"VLLM_ROCM_USE_TRITON_ROPE",
15871593
"VLLM_ROCM_USE_AITER_FP8BMM",
15881594
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
1595+
"VLLM_ROCM_USE_AITER_TRITON_GEMM",
15891596
"VLLM_ROCM_USE_SKINNY_GEMM",
15901597
"VLLM_ROCM_FP8_PADDING",
15911598
"VLLM_ROCM_MOE_PADDING",

vllm/model_executor/layers/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def default_unquantized_gemm(
106106
def use_aiter_triton_gemm(n, m, k, dtype):
107107
if (
108108
envs.VLLM_ROCM_USE_AITER == 0
109+
or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
109110
# MI300's - fp8nuz=True
110111
or current_platform.is_fp8_fnuz()
111112
or dtype not in [torch.float16, torch.bfloat16]

0 commit comments

Comments (0)