@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18051805 inputShape[innerDimsPos[idx]] *= size;
18061806 auto maskedRead = vector::createReadOrMaskedRead (
18071807 rewriter, loc, packOp.getSource (), inputShape, padValue,
1808- useInBoundsInsteadOfMasking);
1808+ useInBoundsInsteadOfMasking,
1809+ /* inputScalableVecSizes=*/ {});
18091810
18101811 // Create ShapeCastOp.
18111812 SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1878,118 +1879,99 @@ static VectorType getCollapsedVecType(VectorType type,
18781879 return VectorType::get (newShape, type.getElementType (), newScalableFlags);
18791880}
18801881
1881- // / Vectorize a `linalg::UnPackOp` to these 4 Ops:
1882- // / Vector::TransferReadOp - Reads a vector from the source tensor
1883- // / vector::TransposeOp - Transpose the Source tensor
1884- // / ShapeCastOp - Reshape the data based on the target.
1885- // / vector::TransferWriteOp. - Write the result vector back to the destination
1886- // / tensor.
1887- // / If the vector sizes are not provided:
1888- // / * the vector sizes are determined by the input operand and attributes,
1889- // / * update the inBounds attribute instead of masking.
1882+ // / Vectorize `linalg.unpack` as:
1883+ // / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1884+ // /
1885+ // / The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886+ // / for the xfer_read operation). This is sufficient to infer the other vector
1887+ // / sizes required here.
1888+ // /
1889+ // / If the vector sizes are not provided:
1890+ // / * the vector sizes are determined from the input tensor static shape.
1891+ // / * the inBounds attribute is used instead of masking.
1892+ // /
1893+ // / EXAMPLE (no vector sizes):
1894+ // / ```
1895+ // / %unpack = linalg.unpack %src
1896+ // / inner_dims_pos = [0, 1]
1897+ // / inner_tiles = [8, 8]
1898+ // / into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1899+ // / ```
1900+ // / is vectorized as:
1901+ // / ```
1902+ // / %read = vector.transfer_read %src
1903+ // / : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1904+ // / %tr = vector.transpose %read, [0, 2, 1, 3]
1905+ // / : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1906+ // / %sc = vector.shape_cast %tr
1907+ // / : vector<1x8x1x8xf32> to vector<8x8xf32>
1908+ // / %vector = vector.transfer_write %sc into %dest
1909+ // / : vector<8x8xf32>, tensor<8x8xf32>
1910+ // / ```
18901911static LogicalResult
18911912vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18921913 ArrayRef<int64_t > inputVectorSizes,
1914+ ArrayRef<bool > inputScalableVecDims,
18931915 SmallVectorImpl<Value> &newResults) {
1916+ if (!inputVectorSizes.empty ()) {
1917+ assert (inputVectorSizes.size () == unpackOp.getSourceRank () &&
1918+ " Invalid number of input vector sizes!" );
1919+ assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1920+ " Incompatible number of vector sizes and vector scalable flags!" );
1921+ }
18941922
18951923 // TODO: Introduce a parent class that will handle the insertion point update.
18961924 OpBuilder::InsertionGuard g (rewriter);
18971925 rewriter.setInsertionPoint (unpackOp);
18981926
18991927 RankedTensorType unpackTensorType = unpackOp.getSourceType ();
19001928
1901- ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
1902- ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
19031929 ArrayRef<int64_t > sourceShape = unpackTensorType.getShape ();
19041930 bool useInBoundsInsteadOfMasking = false ;
1905- ArrayRef<int64_t > outerDimsPerm = unpackOp.getOuterDimsPerm ();
1906-
1907- auto destSize = unpackOp.getDestRank ();
1908-
1909- if (!inputVectorSizes.empty ())
1910- assert (inputVectorSizes.size () == destSize &&
1911- " Incorrect number of input vector sizes" );
1912-
1913- // vectorSizes is the shape of the vector that will be used to do final
1914- // write on the destination tensor. It is set like this: Let's say the
1915- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1916- // Thus:
1917- // 1. vectorSizes = sourceShape.take_front(N)
1918- // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1919- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1920- // innerTiles attribute value.
1921- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1922- if (vectorSizes.empty ()) {
1923- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1924- if (!outerDimsPerm.empty ())
1925- applyPermutationToVector (vectorSizes, outerDimsPerm);
1926- for (auto [i, pos] : llvm::enumerate (innerDimPos))
1927- vectorSizes[pos] *= innerTiles[i];
19281931
1929- useInBoundsInsteadOfMasking = true ;
1930- }
1932+ Location loc = unpackOp->getLoc ();
19311933
1932- // readVectorSizes is the size of tensor used to read and apply mask. It is
1933- // set like this: Let's say the vectorSize (VS) array is size 'N' and
1934- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1935- // size M-N
1936- // Thus:
1937- // - initially: readVectorSizes = vectorInputSizes
1938- // - Divide all the readMaskShape locations pointed by innerDimPos
1939- // by the innerTileSize attribute value.
1940- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1941- // - Append the remaining shape from SS
1942- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1943- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1944- // 128] and outer_dims_perm is [1, 0] then read shape is:
1945- // ReadVectorSizes(initial): [512, 128]
1946- // Final Value(after innerDim Adjustment): [512/32, 128/16]
1947- // = [16, 8]
1948- // After applying outer_dims_perm: [8, 16]
1949- // After appending the rest of the sourceShape: [8, 16, 32, 16]
1950-
1951- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1952-
1953- for (auto [index, size] : enumerate(innerTiles)) {
1954- readVectorSizes[innerDimPos[index]] =
1955- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1956- }
1957- if (!outerDimsPerm.empty ()) {
1958- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1959- }
1960- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1961- sourceShape.end ());
1934+ // Obtain vector sizes for the read operation.
1935+ SmallVector<int64_t > readVectorSizes (inputVectorSizes);
1936+ SmallVector<bool > readScalableVectorFlags (inputScalableVecDims);
19621937
1963- Location loc = unpackOp->getLoc ();
1938+ // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1939+ if (inputVectorSizes.empty ()) {
1940+ if (ShapedType::isDynamicShape (sourceShape))
1941+ return failure ();
1942+
1943+ readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
1944+ useInBoundsInsteadOfMasking = true ;
1945+ }
19641946
1947+ // -- Generate the read operation --
19651948 auto padValue = arith::ConstantOp::create (
19661949 rewriter, loc,
19671950 rewriter.getZeroAttr (unpackOp.getSourceType ().getElementType ()));
1968-
1969- // Read result, mask if necessary. If transferReadOp shape is not equal
1970- // to shape of source, then a mask is necessary.
19711951 Value readResult = vector::createReadOrMaskedRead (
19721952 rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1973- /* useInBoundsInsteadOfMasking= */ false );
1953+ useInBoundsInsteadOfMasking, readScalableVectorFlags );
19741954
1955+ // -- Generate the transpose operation --
19751956 PackingMetadata packMetadata;
19761957 SmallVector<int64_t > lastDimToInsertPosPerm =
19771958 getUnPackInverseSrcPerm (unpackOp, packMetadata);
1978- // Transpose the appropriate rows to match output.
19791959 vector::TransposeOp transposeOp = vector::TransposeOp::create (
19801960 rewriter, loc, readResult, lastDimToInsertPosPerm);
19811961
1982- // Collapse the vector to the size required by result.
1962+ // -- Generate the shape_cast operation --
19831963 VectorType collapsedVecType = getCollapsedVecType (
19841964 transposeOp.getType (),
19851965 getSymbolLessAffineMaps (convertReassociationIndicesToExprs (
19861966 rewriter.getContext (), packMetadata.reassociations )));
19871967 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create (
19881968 rewriter, loc, collapsedVecType, transposeOp->getResult (0 ));
19891969
1970+ // -- Generate the write operation --
19901971 Operation *write = createWriteOrMaskedWrite (
19911972 rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
19921973 /* writeIndices=*/ {}, useInBoundsInsteadOfMasking);
1974+
19931975 newResults.push_back (write->getResult (0 ));
19941976 return success ();
19951977}
@@ -2016,7 +1998,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
20161998 assert (succeeded (status) && " failed to reify result shapes" );
20171999 auto maskedRead = vector::createReadOrMaskedRead (
20182000 rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
2019- /* useInBoundsInsteadOfMasking=*/ false );
2001+ /* useInBoundsInsteadOfMasking=*/ false , /* inputScalableVecSizes= */ {} );
20202002
20212003 // Create Xfer write Op
20222004 Value dest = tensor::EmptyOp::create (rewriter, loc, reifiedReturnShapes[0 ],
@@ -2095,24 +2077,34 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
20952077 return success ();
20962078}
20972079
2098- // / Need to check if the inner-tiles are static/constant.
2080+ // / This hook considers two cases:
2081+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2082+ // / inferred. This is only possible when all shapes are static.
2083+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2084+ // / carry out basic sanity-checking.
20992085static LogicalResult
21002086vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
21012087 ArrayRef<int64_t > inputVectorSizes) {
2088+ // If there are no input vector sizes and all shapes are static, there is
2089+ // nothing left to check.
2090+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2091+ unpackOp.getSourceType ().hasStaticShape ())
2092+ return success ();
21022093
2103- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2104- return !getConstantIntValue (res).has_value ();
2105- })) {
2106- LDBG () << " Inner-tiles must be constant: " << unpackOp;
2094+ // The number of input vector sizes must be equal to:
2095+ // * read-vector-rank
2096+ if (!inputVectorSizes.empty () &&
2097+ (inputVectorSizes.size () != unpackOp.getSourceRank ())) {
2098+ LDBG () << " Incorrect number of input vector sizes" ;
21072099 return failure ();
21082100 }
2109- ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
2110- bool satisfyEmptyCond = inputVectorSizes.empty () &&
2111- unpackOp.getDestType ().hasStaticShape () &&
2112- unpackOp.getSourceType ().hasStaticShape ();
2113- if (!satisfyEmptyCond &&
2114- failed (vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2101+
2102+ // Check the vector sizes for the read operation.
2103+ if (failed (vector::isValidMaskedInputVector (
2104+ unpackOp.getSourceType ().getShape (), inputVectorSizes))) {
2105+ LDBG () << " Invalid vector sizes for the read operation" ;
21152106 return failure ();
2107+ }
21162108
21172109 return success ();
21182110}
@@ -2436,6 +2428,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24362428 LDBG () << " pad value is not constant: " << packOp;
24372429 return failure ();
24382430 }
2431+
24392432 ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
24402433 bool satisfyEmptyCond = true ;
24412434 if (inputVectorSizes.empty ()) {
@@ -2499,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
24992492 return success ();
25002493}
25012494
2502- // / Preconditions for scalable vectors. This is quite restrictive - it models
2503- // / the fact that in practice we would only make selected dimensions scalable.
2495+ // / Preconditions for scalable vectors.
2496+ // /
2497+ // / For Ops implementing the LinalgOp interface, this is quite restrictive - it
2498+ // / models the fact that in practice we would only make selected dimensions
2499+ // / scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2500+ // / unconditionally - we are yet to identify meaningful conditions.
25042501static LogicalResult
25052502vectorizeScalableVectorPrecondition (Operation *op,
25062503 ArrayRef<int64_t > inputVectorSizes,
@@ -2516,10 +2513,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
25162513
25172514 auto linalgOp = dyn_cast<LinalgOp>(op);
25182515
2519- // Cond 1: There's been no need for scalable vectorisation of
2520- // non-linalg Ops so far
2521- if (!linalgOp)
2522- return failure ();
2516+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2517+ // exception of UnpackOp for which there is a dedicated hook.
2518+ if (!linalgOp) {
2519+ return success (isa<linalg::UnPackOp>(op));
2520+ }
25232521
25242522 // Cond 2: There's been no need for more than 2 scalable dims so far
25252523 if (numOfScalableDims > 2 )
@@ -2750,7 +2748,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
27502748 })
27512749 .Case <linalg::UnPackOp>([&](auto unpackOp) {
27522750 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2753- inputVectorSizes, results);
2751+ inputVectorSizes,
2752+ inputScalableVecDims, results);
27542753 })
27552754 .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
27562755 return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -3142,7 +3141,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31423141 vecType.getRank (), arith::ConstantIndexOp::create (rewriter, loc, 0 ));
31433142 Value read = mlir::vector::createReadOrMaskedRead (
31443143 rewriter, loc, source, vecType.getShape (), padValue,
3145- /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
3144+ /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty (),
3145+ /* inputScalableVecSizes=*/ {});
31463146
31473147 // Create write
31483148 auto writeIndices =
0 commit comments