Skip to content

Commit 8c05b5c

Browse files
authored
[mlir][Affine] Cancel delinearize_index ops fully reversed by apply (#163440)
If an `affine.apply` uses every result of an `affine.delinearize_index` operaration in an expresession of the form x_0 * S_0 + x_1 * S_1 + ... + x_n * S_n + ..., where S_i is the "stride" of the i-th delinerization result (the value it got divided by), then, that chain of additions contains the inverse of the affine.delinearize_index. We don't want to compose affine.delinearize_index into affine.apply in general, since this leads to "simplifications" (mainly the `x % y => x - (x / y) * y` rewrite) thate are bad for code generation and algetbraic reasoning. However, if we do see an exact inverse, we should cancel it out.
1 parent d58b5a6 commit 8c05b5c

File tree

2 files changed

+270
-0
lines changed

2 files changed

+270
-0
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,141 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
11251125
return success(*map != initialMap);
11261126
}
11271127

1128+
/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
1129+
/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
1130+
/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
1131+
/// into `replacementsMap`. If no entries were added to `replacementsMap`,
1132+
/// nothing was found.
1133+
static void shortenAddChainsContainingAll(
1134+
AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
1135+
AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
1136+
auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
1137+
if (!binOp)
1138+
return;
1139+
AffineExpr lhs = binOp.getLHS();
1140+
AffineExpr rhs = binOp.getRHS();
1141+
if (binOp.getKind() != AffineExprKind::Add) {
1142+
shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
1143+
shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
1144+
return;
1145+
}
1146+
SmallVector<AffineExpr> toPreserve;
1147+
llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
1148+
AffineExpr thisTerm = rhs;
1149+
AffineExpr nextTerm = lhs;
1150+
1151+
while (thisTerm) {
1152+
if (!ourTracker.erase(thisTerm)) {
1153+
toPreserve.push_back(thisTerm);
1154+
shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
1155+
replacementsMap);
1156+
}
1157+
auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
1158+
if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
1159+
thisTerm = nextTerm;
1160+
nextTerm = AffineExpr();
1161+
} else {
1162+
thisTerm = nextBinOp.getRHS();
1163+
nextTerm = nextBinOp.getLHS();
1164+
}
1165+
}
1166+
if (!ourTracker.empty())
1167+
return;
1168+
// We reverse the terms to be preserved here in order to preserve
1169+
// associativity between them.
1170+
AffineExpr newExpr = newVal;
1171+
for (AffineExpr preserved : llvm::reverse(toPreserve))
1172+
newExpr = newExpr + preserved;
1173+
replacementsMap.insert({e, newExpr});
1174+
}
1175+
1176+
/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
1177+
/// ...` (not necessarily in order) where the set of the `x_i` is the set of
1178+
/// outputs of an `affine.delinearize_index` whos inverse is that expression,
1179+
/// replace that expression with the input of that delinearize_index op.
1180+
///
1181+
/// `unitDimInput` is the input that was detected as the potential start to this
1182+
/// replacement chain - if it isn't the rightmost result of the delinearization,
1183+
/// this method fails. (This is intended to ensure we don't have redundant scans
1184+
/// over the same expression).
1185+
///
1186+
/// While this currently only handles delinearizations with a constant basis,
1187+
/// that isn't a fundamental limitation.
1188+
///
1189+
/// This is a utility function for `replaceDimOrSym` below.
1190+
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(
1191+
AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
1192+
SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
1193+
if (!delinOp.getDynamicBasis().empty())
1194+
return failure();
1195+
if (resultToReplace != delinOp.getMultiIndex().back())
1196+
return failure();
1197+
1198+
MLIRContext *ctx = delinOp.getContext();
1199+
SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
1200+
for (auto [pos, dim] : llvm::enumerate(dims)) {
1201+
auto asResult = dyn_cast_if_present<OpResult>(dim);
1202+
if (!asResult)
1203+
continue;
1204+
if (asResult.getOwner() == delinOp.getOperation())
1205+
resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
1206+
}
1207+
for (auto [pos, sym] : llvm::enumerate(syms)) {
1208+
auto asResult = dyn_cast_if_present<OpResult>(sym);
1209+
if (!asResult)
1210+
continue;
1211+
if (asResult.getOwner() == delinOp.getOperation())
1212+
resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
1213+
}
1214+
if (llvm::is_contained(resToExpr, AffineExpr()))
1215+
return failure();
1216+
1217+
bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
1218+
int64_t stride = 1;
1219+
llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
1220+
// This isn't zip_equal since sometimes the delinearize basis is missing a
1221+
// size for the first result.
1222+
for (auto [binding, size] : llvm::zip(
1223+
llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
1224+
expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
1225+
stride *= size;
1226+
}
1227+
if (resToExpr.size() != delinOp.getStaticBasis().size())
1228+
expectedExprs.insert(resToExpr[0] * stride);
1229+
1230+
DenseMap<AffineExpr, AffineExpr> replacements;
1231+
AffineExpr delinInExpr = isDimReplacement
1232+
? getAffineDimExpr(dims.size(), ctx)
1233+
: getAffineSymbolExpr(syms.size(), ctx);
1234+
1235+
for (AffineExpr e : map->getResults())
1236+
shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
1237+
if (replacements.empty())
1238+
return failure();
1239+
1240+
AffineMap origMap = *map;
1241+
if (isDimReplacement)
1242+
dims.push_back(delinOp.getLinearIndex());
1243+
else
1244+
syms.push_back(delinOp.getLinearIndex());
1245+
*map = origMap.replace(replacements, dims.size(), syms.size());
1246+
1247+
// Blank out dead dimensions and symbols
1248+
for (AffineExpr e : resToExpr) {
1249+
if (auto d = dyn_cast<AffineDimExpr>(e)) {
1250+
unsigned pos = d.getPosition();
1251+
if (!map->isFunctionOfDim(pos))
1252+
dims[pos] = nullptr;
1253+
}
1254+
if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
1255+
unsigned pos = s.getPosition();
1256+
if (!map->isFunctionOfSymbol(pos))
1257+
syms[pos] = nullptr;
1258+
}
1259+
}
1260+
return success();
1261+
}
1262+
11281263
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
11291264
/// defining AffineApplyOp expression and operands.
11301265
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1157,6 +1292,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
11571292
syms);
11581293
}
11591294

1295+
if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
1296+
return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
1297+
syms);
1298+
}
1299+
11601300
auto affineApply = v.getDefiningOp<AffineApplyOp>();
11611301
if (!affineApply)
11621302
return failure();

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
22352235

22362236
// -----
22372237

2238+
// CHECK-LABEL: func @delin_apply_cancel_exact
2239+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
2240+
// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]]
2241+
// CHECK-NOT: memref.store
2242+
// CHECK: return
2243+
func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref<?xindex>) {
2244+
%a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index
2245+
%b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
2246+
%c:2 = affine.delinearize_index %arg0 into (20) : index, index
2247+
2248+
%t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0]
2249+
memref.store %t1, %arg1[%t1] : memref<?xindex>
2250+
2251+
%t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0]
2252+
memref.store %t2, %arg1[%t2] : memref<?xindex>
2253+
2254+
%t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1]
2255+
memref.store %t3, %arg1[%t3] : memref<?xindex>
2256+
2257+
%t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0]
2258+
memref.store %t4, %arg1[%t4] : memref<?xindex>
2259+
2260+
%t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0]
2261+
memref.store %t5, %arg1[%t5] : memref<?xindex>
2262+
2263+
%t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0]
2264+
memref.store %t6, %arg1[%t5] : memref<?xindex>
2265+
2266+
return
2267+
}
2268+
2269+
// -----
2270+
2271+
// CHECK-LABEL: func @delin_apply_cancel_exact_dim
2272+
// CHECK: affine.for %[[arg1:.+]] = 0 to 256
2273+
// CHECK: memref.store %[[arg1]]
2274+
// CHECK: return
2275+
func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) {
2276+
affine.for %arg1 = 0 to 256 {
2277+
%a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index
2278+
%i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1)
2279+
memref.store %i, %arg0[%i] : memref<?xindex>
2280+
}
2281+
return
2282+
}
2283+
2284+
// -----
2285+
2286+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)>
2287+
// CHECK-LABEL: func @delin_apply_cancel_const_term
2288+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
2289+
// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
2290+
// CHECK: return
2291+
func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref<?xindex>) {
2292+
%a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
2293+
2294+
%t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1]
2295+
memref.store %t1, %arg1[%t1] : memref<?xindex>
2296+
2297+
return
2298+
}
2299+
2300+
// -----
2301+
2302+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)>
2303+
// CHECK-LABEL: func @delin_apply_cancel_var_term
2304+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index)
2305+
// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]]
2306+
// CHECK: return
2307+
func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref<?xindex>, %arg2: index) {
2308+
%a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
2309+
2310+
%t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2]
2311+
memref.store %t1, %arg1[%t1] : memref<?xindex>
2312+
2313+
return
2314+
}
2315+
2316+
// -----
2317+
2318+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)>
2319+
// CHECK-LABEL: func @delin_apply_cancel_nested_exprs
2320+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
2321+
// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
2322+
// CHECK: return
2323+
func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref<?xindex>) {
2324+
%a:2 = affine.delinearize_index %arg0 into (20) : index, index
2325+
2326+
%t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0]
2327+
memref.store %t1, %arg1[%t1] : memref<?xindex>
2328+
2329+
return
2330+
}
2331+
2332+
// -----
2333+
2334+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
2335+
// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation
2336+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
2337+
// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20)
2338+
// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]]
2339+
// CHECK: return
2340+
func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref<?xindex>) {
2341+
%a:2 = affine.delinearize_index %arg0 into (20) : index, index
2342+
2343+
%t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0]
2344+
memref.store %t1, %arg1[%t1] : memref<?xindex>
2345+
2346+
return
2347+
}
2348+
2349+
// -----
2350+
2351+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)>
2352+
// CHECK-LABEL: func @delin_apply_dont_cancel_partial
2353+
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
2354+
// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5)
2355+
// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1]
2356+
// CHECK: return
2357+
func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref<?xindex>) {
2358+
%a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
2359+
2360+
%t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1]
2361+
memref.store %t1, %arg1[%t1] : memref<?xindex>
2362+
2363+
return
2364+
}
2365+
2366+
// -----
2367+
22382368
// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
22392369
// CHECK-SAME: (%[[ARG0:.*]]: index)
22402370
// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index

0 commit comments

Comments
 (0)