
Conversation

@FranklandJack
Contributor

Extend the load-of-expand-shape rewrite pattern to support folding a memref.expand_shape into a vector.transfer_read when the permutation map on the vector.transfer_read is a minor identity.

Extend the load-of-expand-shape rewrite pattern to support folding a `memref.expand_shape` into a `vector.transfer_read` when the permutation map on the `vector.transfer_read` is a minor identity.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Jack Frankland (FranklandJack)

Changes

Extend the load-of-expand-shape rewrite pattern (LoadOpOfExpandShapeOpFolder) to support folding a memref.expand_shape into a vector.transfer_read when the permutation map on the vector.transfer_read is a minor identity.
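In IR terms, the new fold linearizes the expanded indices back onto the original memref. A minimal before/after sketch, mirroring the new test case (value names are illustrative):

// Before the fold:
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
%1 = vector.transfer_read %0[%i, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32>

// After the fold:
%idx = affine.linearize_index [%i, %c0] by (4, 8) : index
%1 = vector.transfer_read %arg0[%idx], %pad {in_bounds = [true]} : memref<32xf32>, vector<8xf32>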


Full diff: https://github.com/llvm/llvm-project/pull/167679.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+24-2)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+34)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 214410f78e51c..30df10c1deedc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -347,28 +347,49 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
           loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
           isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
-  llvm::TypeSwitch<Operation *, void>(loadOp)
+
+  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
         rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices);
+        return success();
       })
       .Case([&](memref::LoadOp op) {
         rewriter.replaceOpWithNewOp<memref::LoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices,
             op.getNontemporal());
+        return success();
       })
       .Case([&](vector::LoadOp op) {
         rewriter.replaceOpWithNewOp<vector::LoadOp>(
             op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
             op.getNontemporal());
+        return success();
       })
       .Case([&](vector::MaskedLoadOp op) {
         rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
             op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
             op.getMask(), op.getPassThru());
+        return success();
+      })
+      .Case([&](vector::TransferReadOp op) {
+        // We only support minor identity maps in the permutation attribute.
+        if (!op.getPermutationMap().isMinorIdentity())
+          return failure();
+
+        // We need to construct a new minor identity map since we will have
+        // lost some dimensions in folding away the expand shape.
+        auto minorIdMap = AffineMap::getMinorIdentityMap(
+            sourceIndices.size(), op.getVectorType().getRank(),
+            op.getContext());
+
+        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+            op, op.getVectorType(), expandShapeOp.getViewSource(),
+            sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
+            op.getInBounds());
+        return success();
       })
       .DefaultUnreachable("unexpected operation");
-  return success();
 }
 
 template <typename OpTy>
@@ -659,6 +680,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
                LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
                LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
+               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
                StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
                StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 106652623933f..87f23457644ae 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -992,6 +992,40 @@ func.func @fold_vector_maskedstore_expand_shape(
 
 // -----
 
+func.func @fold_vector_transfer_read_expand_shape(
+    %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = ub.poison : f32
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.transfer_read %0[%arg1, %c0], %pad {in_bounds = [true]} : memref<4x8xf32>, vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[C0:.*]] = arith.constant 0
+//       CHECK:   %[[PAD:.*]] = ub.poison : f32
+//       CHECK:   %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+//       CHECK:   vector.transfer_read %[[ARG0]][%[[IDX]]], %[[PAD]] {in_bounds = [true]}
+
+// -----
+
+func.func @fold_vector_transfer_read_with_perm_map(
+    %arg0 : memref<32xf32>, %arg1 : index) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = ub.poison : f32
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.transfer_read %0[%arg1, %c0], %pad { permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<4x8xf32>, vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_transfer_read_with_perm_map
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//       CHECK:   memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+
+// -----
+
 func.func @fold_vector_load_collapse_shape(
     %arg0 : memref<4x8xf32>, %arg1 : index) -> vector<8xf32> {
   %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
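For reference, the permutation map on vector.transfer_read is a minor identity when it is the identity on the trailing (minor) dimensions of the source. A sketch of the distinction using the shapes from the tests above; the first map is what the transfer_read in the first test carries implicitly, since the default map is the minor identity:

// Minor identity for reading vector<8xf32> from memref<4x8xf32>: the fold
// applies, and the pattern rebuilds affine_map<(d0) -> (d0)> for the rank-1
// folded source via AffineMap::getMinorIdentityMap.
permutation_map = affine_map<(d0, d1) -> (d1)>

// A transpose, not a minor identity: the pattern returns failure() and the
// memref.expand_shape survives, as checked by the second test.
permutation_map = affine_map<(d0, d1) -> (d1, d0)>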