Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ bool isLinearizableVector(VectorType type);
///
/// Note: all read offsets are set to 0.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
ArrayRef<int64_t> inputVectorSizes,
std::optional<Value> padValue,
bool useInBoundsInsteadOfMasking = false,
ArrayRef<bool> inputScalableVecDims = {});

Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1771,11 +1771,6 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,

Location loc = packOp.getLoc();
auto padValue = packOp.getPaddingValue();
if (!padValue) {
padValue = arith::ConstantOp::create(
rewriter, loc,
rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
}

// If the input vector sizes are not provided, then the vector sizes are
// determined by the result tensor shape. In case the vector sizes aren't
Expand All @@ -1798,7 +1793,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), inputShape, padValue,
rewriter, loc, packOp.getSource(), inputShape,
padValue ? std::optional<Value>(padValue) : std::nullopt,
useInBoundsInsteadOfMasking,
/*inputScalableVecSizes=*/{});

Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ bool vector::isLinearizableVector(VectorType type) {
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
std::optional<Value> padValue,
bool useInBoundsInsteadOfMasking,
ArrayRef<bool> inputScalableVecDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
Expand All @@ -328,9 +328,11 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
inputScalableVecDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
auto vectorType =
VectorType::get(inputVectorSizes, sourceShapedType.getElementType(),
inputScalableVecDims);
assert((!padValue.has_value() ||
padValue.value().getType() == sourceShapedType.getElementType()) &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func.func @test_vectorize_pack(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = ub.poison : f32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32>
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1308,7 +1308,7 @@ func.func @test_vectorize_pack(%src: tensor<32x8x16xf32>, %dest: tensor<4x1x32x1
%pack = linalg.pack %src outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %dest : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32>
return %pack : tensor<4x1x32x16x2xf32>
}
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST:.*]] = ub.poison : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[CST]]
// CHECK-SAME: {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32>
Expand Down Expand Up @@ -1376,7 +1376,7 @@ func.func @test_vectorize_dynamic_pack(%src: tensor<?x?xf32>, %dest: tensor<?x?x
return %pack : tensor<?x?x16x2xf32>
}

// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST:.*]] = ub.poison : f32
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
Expand Down Expand Up @@ -1417,7 +1417,7 @@ func.func @test_vectorize_pack_no_vector_sizes(%src: tensor<64x4xf32>, %dest: te
%pack = linalg.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 2] into %dest : tensor<64x4xf32> -> tensor<2x4x16x2xf32>
return %pack : tensor<2x4x16x2xf32>
}
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST:.*]] = ub.poison : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CST]]
// CHECK-SAME: {in_bounds = [true, true]} : tensor<64x4xf32>, vector<64x4xf32>
Expand Down