Skip to content

Commit 00a61b3

Browse files
committed
[mlir][ODS] Add new RangedTypesMatchWith operation predicate
This is a variant of TypesMatchWith that provides support for variadic arguments. This is necessary because ranges generally can't use the default operator== comparators for checking equality. Differential Revision: https://reviews.llvm.org/D94574
1 parent a3904cc commit 00a61b3

File tree

4 files changed

+38
-8
lines changed

4 files changed

+38
-8
lines changed

mlir/include/mlir/IR/OpBase.td

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,16 +2191,28 @@ class AllTypesMatch<list<string> names> :
21912191
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
21922192

21932193
// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
2194+
// An optional comparator function may be provided that changes the above form
2195+
// into: `comparator(transform(lhs.getType()), rhs.getType())`.
21942196
class TypesMatchWith<string summary, string lhsArg, string rhsArg,
2195-
string transform> :
2196-
PredOpTrait<summary, CPred<
2197-
!subst("$_self", "$" # lhsArg # ".getType()", transform)
2198-
# " == $" # rhsArg # ".getType()">> {
2197+
string transform, string comparator = "std::equal_to<>()">
2198+
: PredOpTrait<summary, CPred<
2199+
comparator # "(" #
2200+
!subst("$_self", "$" # lhsArg # ".getType()", transform) #
2201+
", $" # rhsArg # ".getType())">> {
21992202
string lhs = lhsArg;
22002203
string rhs = rhsArg;
22012204
string transformer = transform;
22022205
}
22032206

2207+
// Special variant of `TypesMatchWith` that provides a comparator suitable for
2208+
// ranged arguments.
2209+
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
2210+
string transform>
2211+
: TypesMatchWith<summary, lhsArg, rhsArg, transform,
2212+
"[](auto &&lhs, auto &&rhs) { "
2213+
"return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());"
2214+
" }">;
2215+
22042216
// Type Constraint operand `idx`'s Element type is `type`.
22052217
class TCopVTEtIs<int idx, Type type> : And<[
22062218
CPred<"$_op.getNumOperands() > " # idx>,

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,6 +1733,15 @@ def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
17331733
let assemblyFormat = "attr-dict $value `:` type($value)";
17341734
}
17351735

1736+
def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [
1737+
RangedTypesMatchWith<"result type matches operand", "value", "result",
1738+
"llvm::make_range($_self.begin(), $_self.end())">
1739+
]> {
1740+
let arguments = (ins Variadic<AnyType>:$value);
1741+
let results = (outs Variadic<AnyType>:$result);
1742+
let assemblyFormat = "attr-dict $value `:` type($value)";
1743+
}
1744+
17361745
def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
17371746
TypesMatchWith<"result type matches constant", "value", "result", "$_self">
17381747
]> {

mlir/test/mlir-tblgen/op-format.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,5 +308,8 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
308308
// CHECK: test.format_types_match_var %[[I64]] : i64
309309
%ignored_res3 = test.format_types_match_var %i64 : i64
310310

311+
// CHECK: test.format_types_match_variadic %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64
312+
%ignored_res4:3 = test.format_types_match_variadic %i64, %i64, %i64 : i64, i64, i64
313+
311314
// CHECK: test.format_types_match_attr 1 : i64
312-
%ignored_res4 = test.format_types_match_attr 1 : i64
315+
%ignored_res5 = test.format_types_match_attr 1 : i64

mlir/tools/mlir-tblgen/OpFormatGen.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,10 +1287,16 @@ void OperationFormat::genParserTypeResolution(Operator &op,
12871287
if (Optional<int> val = resolver.getBuilderIdx()) {
12881288
body << "odsBuildableType" << *val;
12891289
} else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1290-
if (Optional<StringRef> tform = resolver.getVarTransformer())
1291-
body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
1292-
else
1290+
if (Optional<StringRef> tform = resolver.getVarTransformer()) {
1291+
FmtContext fmtContext;
1292+
if (var->isVariadic())
1293+
fmtContext.withSelf(var->name + "Types");
1294+
else
1295+
fmtContext.withSelf(var->name + "Types[0]");
1296+
body << tgfmt(*tform, &fmtContext);
1297+
} else {
12931298
body << var->name << "Types";
1299+
}
12941300
} else if (const NamedAttribute *attr = resolver.getAttribute()) {
12951301
if (Optional<StringRef> tform = resolver.getVarTransformer())
12961302
body << tgfmt(*tform,

0 commit comments

Comments
 (0)