@@ -23,6 +23,7 @@ def mm_k(
     CAST_TYPE: tl.constexpr,
     b_dtype: tl.constexpr,
     USE_GDC: tl.constexpr,
+    base_k,
 ):
     """
     Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -47,32 +48,62 @@ def mm_k(
             matrix dtype.
         b_dtype: datatype of the B matrix
         USE_GDC: Whether to use PDL. True indicates use.
+        base_k: Base offset along the K dimension for the current SPLIT_K group
     """
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
+
+    # Step size along K for each iteration
+    STEP_K = BLOCK_K * SPLIT_K
+
+    # Total number of iterations (compile-time constant)
+    num_iters = tl.cdiv(K, STEP_K)
+
+    for k in range(num_iters):
+        # Current iteration's global K offset
+        iter_k = k * STEP_K + base_k
+
+        # End of this block along K; decides which path below is taken
+        block_end = iter_k + BLOCK_K
+
         if EVEN_K:
-            # pre-fetech lora weight
+            # K is divisible by BLOCK_K, no masking ever needed
+            # pre-fetch lora weight
             tiled_b = tl.load(b_ptr)
             if USE_GDC:
                 tl.extra.cuda.gdc_wait()
             tiled_a = tl.load(a_ptr)
+            if CAST_TYPE:
+                tiled_a = tiled_a.to(b_dtype)
+            accumulator += tl.dot(tiled_a, tiled_b)
         else:
-            tiled_b = tl.load(
-                b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-            if USE_GDC:
-                tl.extra.cuda.gdc_wait()
-            tiled_a = tl.load(
-                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-        if CAST_TYPE:
-            tiled_a = tiled_a.to(b_dtype)
-        accumulator += tl.dot(
-            tiled_a,
-            tiled_b,
-        )
-        a_ptr += BLOCK_K * SPLIT_K * ak_stride
-        b_ptr += BLOCK_K * SPLIT_K * bk_stride
+            # Check if we need element-wise masking
+            if iter_k >= K:
+                # Entire block out of range, skip
+                pass
+            elif block_end <= K:
+                # Entire block in range, no masking needed (fast path)
+                tiled_b = tl.load(b_ptr)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+            else:
+                # Partial block, needs masking (only the last iteration)
+                k_offsets = tl.arange(0, BLOCK_K)
+                mask = iter_k + k_offsets < K
+                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+
+        a_ptr += STEP_K * ak_stride
+        b_ptr += STEP_K * bk_stride
+
     return accumulator

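The three-way branch above can be sanity-checked outside Triton. Below is a minimal pure-Python sketch (sizes are hypothetical, not from this PR) that classifies each iteration of each split-K program the same way the kernel does; with these sizes `EVEN_K` is false, so all three paths are exercised:

```python
# Mirror of the kernel's per-iteration classification; sizes are hypothetical.
BLOCK_K, SPLIT_K, K = 16, 4, 200
STEP_K = BLOCK_K * SPLIT_K

for pid_sk in range(SPLIT_K):          # one split-K program per pid_sk
    base_k = pid_sk * BLOCK_K
    for k in range((K + STEP_K - 1) // STEP_K):  # tl.cdiv(K, STEP_K)
        iter_k = k * STEP_K + base_k   # global K offset of this block
        block_end = iter_k + BLOCK_K
        if iter_k >= K:
            path = "skip"              # whole block past the end of K
        elif block_end <= K:
            path = "full"              # unmasked loads (fast path)
        else:
            path = "masked"            # trailing partial block only
        print(f"pid_sk={pid_sk} k={k} iter_k={iter_k:3d} -> {path}")
```

For these sizes, each program masks or skips at most one iteration; every other iteration takes the unmasked fast path, which is the point of the change.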
@@ -178,6 +209,7 @@ def do_expand_kernel(
         CAST_TYPE,
         cur_lora_ptr.dtype.element_ty,
         USE_GDC,
+        base_k=0,
     )

     tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
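With `base_k=0`, the expand path's new mask is arithmetically identical to the expression it replaces. A quick sketch of that identity, assuming `offset_k` in the removed code was `tl.arange(0, BLOCK_K)` on this path (sizes are hypothetical):

```python
# Identity check: old mask `offset_k < K - k * STEP_K` vs. new
# `iter_k + k_offsets < K` with base_k = 0; sizes are hypothetical.
BLOCK_K, SPLIT_K, K, base_k = 16, 1, 40, 0
STEP_K = BLOCK_K * SPLIT_K

for k in range((K + STEP_K - 1) // STEP_K):    # tl.cdiv(K, STEP_K)
    for j in range(BLOCK_K):                   # j stands in for tl.arange(0, BLOCK_K)
        old = j < K - k * STEP_K               # removed expression
        new = (k * STEP_K + base_k) + j < K    # iter_k + k_offsets < K
        assert old == new
```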
@@ -284,6 +316,7 @@ def do_shrink_kernel(
         False,
         cur_lora_ptr.dtype.element_ty,
         False,  # USE_GDC is always False in shrink kernel
+        base_k=pid_sk * BLOCK_K,
     )
     # GDC launch dependents hints the runtime system to launch dependent kernels.
     if USE_GDC:
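Passing `pid_sk * BLOCK_K` staggers the split-K programs so that together they sweep every K index exactly once. A minimal coverage sketch, again with hypothetical sizes:

```python
from collections import Counter

# Coverage check for split-K: each program starts BLOCK_K further into K.
BLOCK_K, SPLIT_K, K = 16, 4, 200
STEP_K = BLOCK_K * SPLIT_K
counts = Counter()

for pid_sk in range(SPLIT_K):
    base_k = pid_sk * BLOCK_K                    # matches the call above
    for k in range((K + STEP_K - 1) // STEP_K):  # tl.cdiv(K, STEP_K)
        iter_k = k * STEP_K + base_k
        counts.update(i for i in range(iter_k, iter_k + BLOCK_K) if i < K)

assert all(counts[i] == 1 for i in range(K))     # every K index reduced once
```

Each program accumulates a partial sum over its slice of K; combining those partials happens outside mm_k and is untouched by this PR.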