Skip to content

Commit e54b559

Browse files
Jokeren authored and pytorchmergebot committed
[inductor] More fixes on the keys of constants and signature dictionaries (pytorch#135406)
The previous PR (pytorch#135170) forgot to change two other places that also create the `constants` and `signature` dictionaries. Pull Request resolved: pytorch#135406. Approved by: https://github.com/jansel
1 parent eea5e6f commit e54b559

File tree

10 files changed

+31
-21
lines changed

10 files changed

+31
-21
lines changed

docs/source/torch.compiler_get_started.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ the following:
5757

5858
.. code-block:: python
5959
60-
@pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
60+
@pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
6161
@triton.jit
6262
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
6363
xnumel = 10000

test/inductor/test_cuda_repro.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,11 @@ def decorator(fn):
433433
triton.Config({"XBLOCK": 2}),
434434
],
435435
meta={
436-
"signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
436+
"signature": {
437+
"in_out_ptr0": "*fp32",
438+
"in_ptr0": "*fp32",
439+
"xnumel": "i32",
440+
},
437441
"device": DeviceProperties.create(torch.device("cuda")),
438442
"configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
439443
"constants": {},

test/inductor/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
reduction_hint=ReductionHint.INNER,
1515
filename=__file__,
1616
triton_meta={
17-
'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'},
17+
'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'xnumel': 'i32', 'rnumel': 'i32'},
1818
'device': 0,
1919
'device_type': 'GPU_TYPE',
2020
'constants': {},

test/inductor/test_triton_heuristics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
102102
tl.store(out_ptr0 + (x0), tmp1, xmask)
103103

104104
triton_meta = {
105-
"signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
105+
"signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32"},
106106
"device": DeviceProperties.create(torch.device("cuda")),
107107
"constants": {},
108108
"configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())],

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2648,7 +2648,7 @@ def codegen_kernel(self, name=None):
26482648
mutated_args = sorted(mutated_args)
26492649

26502650
triton_meta_signature = signature_to_meta(
2651-
signature, size_dtype=self.index_dtype
2651+
signature, size_dtype=self.index_dtype, argdefs=argdefs
26522652
)
26532653
triton_meta = {
26542654
"signature": triton_meta_signature,
@@ -2676,7 +2676,7 @@ def codegen_kernel(self, name=None):
26762676
for tree in self.active_range_trees():
26772677
sizearg = SizeArg(f"{tree.prefix}numel", tree.numel)
26782678
signature.append(sizearg)
2679-
triton_meta_signature[len(argdefs)] = signature_of(
2679+
triton_meta_signature[sizearg.name] = signature_of(
26802680
sizearg, size_dtype=self.index_dtype
26812681
)
26822682
argdefs.append(f"{tree.prefix}numel")
@@ -2694,7 +2694,7 @@ def codegen_kernel(self, name=None):
26942694
# https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
26952695
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
26962696
for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
2697-
triton_meta["constants"][arg_num] = 1 # type: ignore[index]
2697+
triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index]
26982698

26992699
self.triton_meta = triton_meta
27002700

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -660,18 +660,19 @@ def jit_line(
660660
heuristics: str,
661661
size_hints: List[int],
662662
selected_kernel: TritonKernel,
663+
signature: List[Any],
664+
argdefs: List[str],
663665
pointwise_with_reduce: bool = False,
664-
signature: Optional[List[Any]] = None,
665666
) -> str:
666667
can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
667668
size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
668-
if signature is None:
669-
_, _, signature, _ = self.args.python_argdefs()
670669
for i, sub in enumerate(self.sub_kernels):
671670
self.min_x_blocks_sub_kernel(sub, i)
672671
self.select_dispatch_strategy()
673672
triton_meta = {
674-
"signature": signature_to_meta(signature, size_dtype=size_dtype),
673+
"signature": signature_to_meta(
674+
signature, size_dtype=size_dtype, argdefs=argdefs
675+
),
675676
"device": DeviceProperties.create(
676677
V.graph.scheduler.get_current_device_or_throw()
677678
),
@@ -850,6 +851,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str:
850851
selected_kernel,
851852
pointwise_with_reduce=pointwise_with_reduction,
852853
signature=signature,
854+
argdefs=argdefs,
853855
)
854856
)
855857
code.writeline(

torch/_inductor/codegen/triton_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,13 @@ def signature_to_meta(
6868
signature: List[KernelArgType],
6969
*,
7070
size_dtype: str,
71+
argdefs: List[str],
7172
indices: Optional[List[int]] = None,
72-
) -> Dict[int, str]:
73+
) -> Dict[str, str]:
7374
if indices is None:
7475
indices = list(range(len(signature)))
7576
return {
76-
i: signature_of(arg, size_dtype=size_dtype)
77+
argdefs[i]: signature_of(arg, size_dtype=size_dtype)
7778
for i, arg in zip(indices, signature)
7879
}
7980

torch/_inductor/codegen/wrapper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,15 +1275,15 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
12751275
from .common import KernelArgType, SizeArg, TensorArg
12761276

12771277
signature: List[KernelArgType] = []
1278-
constants: Dict[int, Any] = {}
1278+
constants: Dict[str, Any] = {}
12791279
non_constant_indices = []
1280-
equal_to_1_arg_idx: List[int] = []
1280+
equal_to_1_args: List[str] = []
12811281
for idx, key in enumerate(kernel.arg_names):
12821282
if key not in kwargs:
12831283
continue
12841284
arg = kwargs[key]
12851285
if idx in kernel.constexprs:
1286-
constants[idx] = arg
1286+
constants[key] = arg
12871287
else:
12881288
non_constant_indices.append(idx)
12891289
if isinstance(arg, ir.Buffer):
@@ -1313,13 +1313,14 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
13131313
) and V.graph.sizevars.statically_known_equals(
13141314
arg, 1 # type: ignore[arg-type]
13151315
):
1316-
equal_to_1_arg_idx.append(idx)
1316+
equal_to_1_args.append(key)
13171317
index_dtype = "tl.int32"
13181318
triton_meta = {
13191319
"signature": signature_to_meta(
13201320
signature,
13211321
size_dtype=index_dtype,
13221322
indices=non_constant_indices,
1323+
argdefs=kernel.arg_names,
13231324
),
13241325
"device": DeviceProperties.create(
13251326
V.graph.scheduler.get_current_device_or_throw()
@@ -1333,7 +1334,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
13331334
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
13341335
"constants": {
13351336
**constants,
1336-
**dict.fromkeys(equal_to_1_arg_idx, 1),
1337+
**dict.fromkeys(equal_to_1_args, 1),
13371338
},
13381339
"configs": [
13391340
config_of(

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool):
359359
if k == "waves_per_eu":
360360
compile_meta["waves_per_eu"] = v
361361
continue
362-
compile_meta["constants"][self.fn.arg_names.index(k)] = v
362+
compile_meta["constants"][k] = v
363363
compile_meta["num_warps"] = cfg.num_warps
364364
compile_meta["num_stages"] = cfg.num_stages
365365
compile_meta["debug"] = self.inductor_meta.get(

torch/_inductor/select_algorithm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,15 @@ def jit_lines(self):
214214

215215
argdefs, _, signature, _ = self.args.python_argdefs()
216216
triton_meta = {
217-
"signature": signature_to_meta(signature, size_dtype=self.index_dtype),
217+
"signature": signature_to_meta(
218+
signature, size_dtype=self.index_dtype, argdefs=argdefs
219+
),
218220
"device": DeviceProperties.create(self.output_node.get_device()),
219221
"constants": {},
220222
}
221223
triton_meta["configs"] = [config_of(signature)]
222224
for arg_num in triton_meta["configs"][0].equal_to_1: # type: ignore[index]
223-
triton_meta["constants"][arg_num] = 1 # type: ignore[index]
225+
triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index]
224226
matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
225227
if matrix_instr_nonkdim != 0:
226228
triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim

0 commit comments

Comments (0)