- Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][memref] Canonicalize memref.reinterpret_cast when offset/sizes/strides are constants. #163505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Ming Yan (NexMing) ChangesImplement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations. Full diff: https://github.com/llvm/llvm-project/pull/163505.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda296da5..f914b292eba83 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern<ReinterpretCastOp> { +public: + using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(), + op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { + return isa<Value>(ofr) && getConstantIntValue(ofr); + })) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(), + op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides()); + + rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context); + results.add<ReinterpretCastOpExtractStridedMetadataFolder, + ReinterpretCastOpConstantFolder>(context); } FailureOr<std::optional<SmallVector<Value>>> diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 16b7a5c8bcb08..7160b52af6353 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> { // ----- +// CHECK-LABEL: func @reinterpret_constant_fold +// CHECK-SAME: (%[[ARG:.*]]: memref<f32>) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], 
strides: [100, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] +func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> + return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>> +} + +// ----- + // CHECK-LABEL: func @reinterpret_of_reinterpret // CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] @@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x // when the strides don't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> @@ -1011,11 
+1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // when the offset doesn't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> |
@llvm/pr-subscribers-mlir-memref Author: Ming Yan (NexMing) ChangesImplement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations. Full diff: https://github.com/llvm/llvm-project/pull/163505.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda296da5..f914b292eba83 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern<ReinterpretCastOp> { +public: + using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(), + op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { + return isa<Value>(ofr) && getConstantIntValue(ofr); + })) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(), + op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides()); + + rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context); + results.add<ReinterpretCastOpExtractStridedMetadataFolder, + ReinterpretCastOpConstantFolder>(context); } FailureOr<std::optional<SmallVector<Value>>> diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 16b7a5c8bcb08..7160b52af6353 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> { // ----- +// CHECK-LABEL: func @reinterpret_constant_fold +// CHECK-SAME: (%[[ARG:.*]]: memref<f32>) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], 
strides: [100, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] +func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> + return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>> +} + +// ----- + // CHECK-LABEL: func @reinterpret_of_reinterpret // CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] @@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x // when the strides don't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> @@ -1011,11 
+1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // when the offset doesn't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> |
…/strides are constants. Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.
92416e8
to f62bd0f
Compare This test does not trigger the verifier before the canonicalization pattern runs, but it does afterwards. Should the verifier be stricter, or should the pattern fail on such an op?
|
The size of a memref is not permitted to be negative Statically negative memref sizes have every right to be an error |
That is - option 3, the test is UB |
Ok so it could be enforced in the verifier. We are experimenting with some FIR to MemRef passes, and FIR represents the dynamic size as
The canonicalization pattern from this PR must be updated: if the pattern would generate an op that does not verify, it must abort. |
Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.