[linalg] Vectorization failure - masked scalar read + broadcast #116197

@banach-space

REPRODUCER

func.func @vectorization_test(%extracted_slice : tensor<1x1x3xi32>, %arg0: index, %arg2: index, %3: tensor<2x4xi32>, %4: tensor<1x3x2x4xi32>) -> tensor<1x1x3xi32>{
  %c0 = arith.constant 0 :index
  %8 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%extracted_slice : tensor<1x1x3xi32>) {
  ^bb0(%out: i32):
    %9 = linalg.index 0 : index
    %extracted = tensor.extract %3[%9, %c0] : tensor<2x4xi32>
    %14 = arith.index_cast %extracted : i32 to index
    %extracted_2 = tensor.extract %4[%c0, %14, %14, %14] : tensor<1x3x2x4xi32>
    linalg.yield %extracted_2 : i32
  } -> tensor<1x1x3xi32>
  return %8 : tensor<1x1x3xi32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
    transform.yield
  }
}

ERROR LOG

../file.mlir:10:18: error: 'vector.mask' op expects a 'vector<i1>' mask for the maskable operation
    %extracted = tensor.extract %3[%9, %c0] : tensor<2x4xi32>
                 ^
../file.mlir:10:18: note: see current operation:
%17 = "vector.mask"(%6) ({
  %34 = "vector.transfer_read"(%arg3, %15, %0, %16) <{in_bounds = [true, true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}> : (tensor<2x4xi32>, index, index, i32) -> vector<1x1x4xi32>
  "vector.yield"(%34) : (vector<1x1x4xi32>) -> ()
}) : (vector<1x1x4xi1>) -> vector<1x1x4xi32>

ANALYSIS

  1. Type of vectorization: masked
  2. Op: tensor.extract --> %extracted = tensor.extract %3[%9, %c0] : tensor<2x4xi32> (effectively a scalar read + broadcast).
  3. The Vectorizer output (generated by vectorizeAsTensorExtract):
 %34 = "vector.transfer_read"(%arg3, %15, %0, %16) <{in_bounds = [true, true, true], operandSegmentSizes = array<i32: 1, 2, 1, 0>, permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}> : (tensor<2x4xi32>, index, index, i32) -> vector<1x1x4xi32>

It looks like the mask generated by the vectorizer doesn't match the "expected mask" computed by MaskOp::verify:

  • The Linalg vectorizer, when creating vector.mask, generates a mask based on the static loop sizes and input vector sizes. This gives: vector<1x1x4xi1>.

  • The Op verifier uses inferTransferOpMaskType, which has no access to the LinalgOp information and instead looks at the permutation map of the masked op, vector.transfer_read. That yields vector<i1> (based on permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>).

To me, the output from the Vectorizer is correct.
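
The mismatch can be reproduced outside MLIR with a toy model. The Python sketch below only approximates the two shape computations described above and is not the actual MLIR code: `vectorizer_mask_shape`, `inferred_mask_shape` and `as_mask_type` are hypothetical helper names, and the permutation map is encoded as a list in which `None` stands for a broadcast (`0`) result.

```python
def vectorizer_mask_shape(vector_sizes):
    """Mask shape as the issue describes the Linalg vectorizer's choice:
    one mask lane per lane of the requested vector sizes."""
    return tuple(vector_sizes)


def inferred_mask_shape(perm_map_results, vector_shape):
    """Simplified model of the verifier-side inference (an assumption).

    perm_map_results[i] is the source dim feeding vector dim i,
    or None when that result is the broadcast constant 0.
    Broadcast dims are dropped; the rest are ordered by source dim.
    """
    kept = [(dim, size)
            for dim, size in zip(perm_map_results, vector_shape)
            if dim is not None]
    return tuple(size for _, size in sorted(kept))


def as_mask_type(shape):
    """Render a shape as an MLIR-style i1 vector type string."""
    return "vector<" + "".join(f"{s}x" for s in shape) + "i1>"


# Reproducer: permutation_map = (d0, d1) -> (0, 0, 0), result vector<1x1x4xi32>.
built = vectorizer_mask_shape([1, 1, 4])
expected = inferred_mask_shape([None, None, None], [1, 1, 4])
print(as_mask_type(built))     # vector<1x1x4xi1>
print(as_mask_type(expected))  # vector<i1>
```

With every result of the permutation map being a broadcast, nothing survives the inference step, which is how the verifier ends up expecting vector<i1> while the vectorizer supplies vector<1x1x4xi1>.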

RELEVANT DATA POINT

Looking at this example from "invalid.mlir":

func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
  %c1 = arith.constant 1 : i1
  %c3 = arith.constant 3 : index
  %cst = arith.constant 3.0 : f32
  // expected-note@+1 {{prior use here}}
  %mask = vector.splat %c1 : vector<3x8x7xi1>
  // expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
  %0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
}

To me, the error is wrong and the example is in fact correct (as in, vector<3x8x7xi1> as a mask makes sense to me). Specifically, as per the docs (emphasis mine):

The masked-off lanes in the result vector are taken from the corresponding lanes of the pass-thru argument, if provided, or left unmodified, otherwise.

Doesn't this mean that the mask shape should always match the result vector shape?
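
One way to read the quoted sentence is as a per-lane select between the loaded value and the pass-thru value. The sketch below is a toy Python model of that reading, not MLIR's implementation; `masked_broadcast_read` is a hypothetical name, and the values are made up. It is applied to the last dimension of the reproducer, where the loop size is 3 and the vector size is 4.

```python
def masked_broadcast_read(scalar, mask, passthru):
    """Broadcast `scalar` to len(mask) lanes, then pick per lane:
    the broadcast value where the mask is set, pass-thru elsewhere."""
    if len(mask) != len(passthru):
        raise ValueError("mask and pass-thru need one entry per result lane")
    return [scalar if m else p for m, p in zip(mask, passthru)]


# Loop size 3 inside a 4-lane vector: the last lane is masked off.
mask = [i < 3 for i in range(4)]
result = masked_broadcast_read(7, mask, [0, 0, 0, 0])
print(result)  # [7, 7, 7, 0]
```

Under this reading, a single i1 (vector<i1>) could only switch all four lanes on or off together, so it could not describe the masked-off fourth lane.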

PROPOSED SOLUTION

  • Either fix how the vectorizer handles broadcast dims, or
  • Update inferTransferOpMaskType and, in general, the semantics of broadcast dims when masking is used.

CC @dcaballe
