Skip to content

Conversation

NexMing
Copy link
Contributor

@NexMing NexMing commented Oct 15, 2025

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir

Author: Ming Yan (NexMing)

Changes

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.


Full diff: https://github.com/llvm/llvm-project/pull/163505.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+26-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+21-9)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda296da5..f914b292eba83 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern<ReinterpretCastOp> { +public: + using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(), + op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { + return isa<Value>(ofr) && getConstantIntValue(ofr); + })) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(), + op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides()); + + rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context); + results.add<ReinterpretCastOpExtractStridedMetadataFolder, + ReinterpretCastOpConstantFolder>(context); } FailureOr<std::optional<SmallVector<Value>>> diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 16b7a5c8bcb08..7160b52af6353 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> { // ----- +// CHECK-LABEL: func @reinterpret_constant_fold +// CHECK-SAME: (%[[ARG:.*]]: memref<f32>) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] +func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> + return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>> +} + +// ----- + // CHECK-LABEL: func @reinterpret_of_reinterpret // CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] @@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x // when the strides don't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> @@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // when the offset doesn't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> 
@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir-memref

Author: Ming Yan (NexMing)

Changes

Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.


Full diff: https://github.com/llvm/llvm-project/pull/163505.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+26-1)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+21-9)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index e9bdcda296da5..f914b292eba83 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2158,11 +2158,36 @@ struct ReinterpretCastOpExtractStridedMetadataFolder return success(); } }; + +struct ReinterpretCastOpConstantFolder + : public OpRewritePattern<ReinterpretCastOp> { +public: + using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ReinterpretCastOp op, + PatternRewriter &rewriter) const override { + if (!llvm::any_of(llvm::concat<OpFoldResult>(op.getMixedOffsets(), + op.getMixedSizes(), + op.getMixedStrides()), + [](OpFoldResult ofr) { + return isa<Value>(ofr) && getConstantIntValue(ofr); + })) + return failure(); + + auto newReinterpretCast = ReinterpretCastOp::create( + rewriter, op->getLoc(), op.getSource(), op.getConstifiedMixedOffset(), + op.getConstifiedMixedSizes(), op.getConstifiedMixedStrides()); + + rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast); + return success(); + } +}; } // namespace void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context); + results.add<ReinterpretCastOpExtractStridedMetadataFolder, + ReinterpretCastOpConstantFolder>(context); } FailureOr<std::optional<SmallVector<Value>>> diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 16b7a5c8bcb08..7160b52af6353 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> { // ----- +// CHECK-LABEL: func @reinterpret_constant_fold +// CHECK-SAME: (%[[ARG:.*]]: memref<f32>) +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] +func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c100 = arith.constant 100 : index + %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> + return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>> +} + +// ----- + // CHECK-LABEL: func @reinterpret_of_reinterpret // CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index) // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1] @@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x // when the strides don't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> @@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // when the offset doesn't match. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] -// CHECK: return %[[RES]] +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1] +// CHECK: %[[CAST:.*]] = memref.cast %[[RES]] +// CHECK: return %[[CAST]] func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>> 
…/strides are constants. Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.
@NexMing NexMing force-pushed the dev/reinterpret-constant-fold branch from 92416e8 to f62bd0f Compare October 15, 2025 06:43
@NexMing NexMing enabled auto-merge (squash) October 17, 2025 10:11
@NexMing NexMing merged commit c988bf8 into llvm:main Oct 17, 2025
10 checks passed
@NexMing NexMing deleted the dev/reinterpret-constant-fold branch October 17, 2025 10:19
@clementval
Copy link
Contributor

This test is not triggering the verifier before the canonicalization pattern but after it does. Should the verifier be stricter or should the pattern fails on such op?

func.func @reinterpret_constant_fold2(%arg0: memref<?x?x?xi32>, %arg1 : index) -> memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>> { %c0 = arith.constant 0 : index %c-1 = arith.constant -1 : index %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%arg1, %arg1, %c-1], strides: [%arg1, %arg1, %arg1] : memref<?x?x?xi32> to memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>> return %reinterpret_cast : memref<?x?x?xi32, strided<[?, ?, ?], offset: ?>> } 
@krzysz00
Copy link
Contributor

The size of a memref is not permitted to be negative

Statically negative memref sizes have every right to be an error

@krzysz00
Copy link
Contributor

That is - option 3, the test is UB

@clementval
Copy link
Contributor

The size of a memref is not permitted to be negative

Statically negative memref sizes have every right to be an error

Ok so it could be enforced in the verifier. We are experiencing with some FIR to MemRef passes and Fir represents the dynamic size as -1 where MemRef represents it as std::numeric_limits<int64_t>::min(). So I guess we need to align our representation.

@matthias-springer
Copy link
Member

The canonicalization pattern from this PR must be updated: if the pattern would generate an op that does not verify, it must abort.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

5 participants