@@ -1275,15 +1275,15 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
12751275 from .common import KernelArgType , SizeArg , TensorArg
12761276
12771277 signature : List [KernelArgType ] = []
1278- constants : Dict [int , Any ] = {}
1278+ constants : Dict [str , Any ] = {}
12791279 non_constant_indices = []
1280- equal_to_1_arg_idx : List [int ] = []
1280+ equal_to_1_args : List [str ] = []
12811281 for idx , key in enumerate (kernel .arg_names ):
12821282 if key not in kwargs :
12831283 continue
12841284 arg = kwargs [key ]
12851285 if idx in kernel .constexprs :
1286- constants [idx ] = arg
1286+ constants [key ] = arg
12871287 else :
12881288 non_constant_indices .append (idx )
12891289 if isinstance (arg , ir .Buffer ):
@@ -1313,13 +1313,14 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
13131313 ) and V .graph .sizevars .statically_known_equals (
13141314 arg , 1 # type: ignore[arg-type]
13151315 ):
1316- equal_to_1_arg_idx .append (idx )
1316+ equal_to_1_args .append (key )
13171317 index_dtype = "tl.int32"
13181318 triton_meta = {
13191319 "signature" : signature_to_meta (
13201320 signature ,
13211321 size_dtype = index_dtype ,
13221322 indices = non_constant_indices ,
1323+ argdefs = kernel .arg_names ,
13231324 ),
13241325 "device" : DeviceProperties .create (
13251326 V .graph .scheduler .get_current_device_or_throw ()
@@ -1333,7 +1334,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
13331334 # https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
13341335 "constants" : {
13351336 ** constants ,
1336- ** dict .fromkeys (equal_to_1_arg_idx , 1 ),
1337+ ** dict .fromkeys (equal_to_1_args , 1 ),
13371338 },
13381339 "configs" : [
13391340 config_of (
0 commit comments