Skip to content

Commit 40e2eee

Browse files
caozuoba and jeejeelee authored
[Kernel] Optimization of the mm_k operator. (vllm-project#28280)
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent b06b947 commit 40e2eee

File tree

1 file changed

+51
-18
lines changed

1 file changed

+51
-18
lines changed

vllm/lora/ops/triton_ops/kernel_utils.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def mm_k(
2323
CAST_TYPE: tl.constexpr,
2424
b_dtype: tl.constexpr,
2525
USE_GDC: tl.constexpr,
26+
base_k,
2627
):
2728
"""
2829
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -47,32 +48,62 @@ def mm_k(
4748
matrix dtype.
4849
b_dtype: datatype of the B matrix
4950
USE_GDC: Whether to use PDL. True indicates use.
51+
base_k: Base offset along K dimension for current SPLIT_K group
5052
"""
5153
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
52-
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
54+
55+
# Step size along K for each iteration
56+
STEP_K = BLOCK_K * SPLIT_K
57+
58+
# Total number of iterations (compile-time constant)
59+
num_iters = tl.cdiv(K, STEP_K)
60+
61+
for k in range(num_iters):
62+
# Current iteration's global K offset
63+
iter_k = k * STEP_K + base_k
64+
65+
# Check if this iteration is completely valid (no masking needed)
66+
block_end = iter_k + BLOCK_K
67+
5368
if EVEN_K:
54-
# pre-fetech lora weight
69+
# K is divisible by BLOCK_K, no masking ever needed
70+
# pre-fetch lora weight
5571
tiled_b = tl.load(b_ptr)
5672
if USE_GDC:
5773
tl.extra.cuda.gdc_wait()
5874
tiled_a = tl.load(a_ptr)
75+
if CAST_TYPE:
76+
tiled_a = tiled_a.to(b_dtype)
77+
accumulator += tl.dot(tiled_a, tiled_b)
5978
else:
60-
tiled_b = tl.load(
61-
b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
62-
)
63-
if USE_GDC:
64-
tl.extra.cuda.gdc_wait()
65-
tiled_a = tl.load(
66-
a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
67-
)
68-
if CAST_TYPE:
69-
tiled_a = tiled_a.to(b_dtype)
70-
accumulator += tl.dot(
71-
tiled_a,
72-
tiled_b,
73-
)
74-
a_ptr += BLOCK_K * SPLIT_K * ak_stride
75-
b_ptr += BLOCK_K * SPLIT_K * bk_stride
79+
# Check if we need element-wise masking
80+
if iter_k >= K:
81+
# Entire block out of range, skip
82+
pass
83+
elif block_end <= K:
84+
# Entire block in range, no masking needed (fast path)
85+
tiled_b = tl.load(b_ptr)
86+
if USE_GDC:
87+
tl.extra.cuda.gdc_wait()
88+
tiled_a = tl.load(a_ptr)
89+
if CAST_TYPE:
90+
tiled_a = tiled_a.to(b_dtype)
91+
accumulator += tl.dot(tiled_a, tiled_b)
92+
else:
93+
# Partial block, need masking (only last iteration)
94+
k_offsets = tl.arange(0, BLOCK_K)
95+
mask = iter_k + k_offsets < K
96+
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
97+
if USE_GDC:
98+
tl.extra.cuda.gdc_wait()
99+
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
100+
if CAST_TYPE:
101+
tiled_a = tiled_a.to(b_dtype)
102+
accumulator += tl.dot(tiled_a, tiled_b)
103+
104+
a_ptr += STEP_K * ak_stride
105+
b_ptr += STEP_K * bk_stride
106+
76107
return accumulator
77108

78109

@@ -178,6 +209,7 @@ def do_expand_kernel(
178209
CAST_TYPE,
179210
cur_lora_ptr.dtype.element_ty,
180211
USE_GDC,
212+
base_k=0,
181213
)
182214

183215
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
@@ -284,6 +316,7 @@ def do_shrink_kernel(
284316
False,
285317
cur_lora_ptr.dtype.element_ty,
286318
False, # USE_GDC is always False in shrink kernel
319+
base_k=pid_sk * BLOCK_K,
287320
)
288321
# GDC launch dependents hints the runtime system to launch dependent kernels.
289322
if USE_GDC:

0 commit comments

Comments (0)