1818#include " mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1919#include " mlir/IR/BuiltinAttributes.h"
2020#include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/OpDefinition.h"
2122#include " mlir/IR/TypeUtilities.h"
2223#include " mlir/IR/Value.h"
2324#include " mlir/Transforms/DialectConversion.h"
@@ -37,16 +38,17 @@ using namespace mlir;
3738
3839// / Returns a compressed mask. The mask value is set only if any mask is present
3940// / in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
40- // / equals to 2, the following mask:
41+ // / equals to 1 (intraDataOffset strictly smaller than scale), the following
42+ // / mask:
4143// /
42- // / %mask = [1, 1, 1 , 0, 0, 0]
44+ // / %mask = [1, 1, 0 , 0, 0, 0]
4345// /
4446// / will first be padded with number of `intraDataOffset` zeros:
45- // / %mask = [0, 0 , 1, 1, 1 , 0, 0, 0]
47+ // / %mask = [0, 1 , 1, 0, 0 , 0, 0, 0]
4648// /
4749// / then it will return the following new compressed mask:
4850// /
49- // / %mask = [0 , 1, 1 , 0]
51+ // / %mask = [1 , 1, 0 , 0]
5052static FailureOr<Operation *> getCompressedMaskOp (OpBuilder &rewriter,
5153 Location loc, Value mask,
5254 int origElements, int scale,
@@ -75,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
7577 shape.back () = numElements;
7678 auto newMaskType = VectorType::get (shape, rewriter.getI1Type ());
7779 if (createMaskOp) {
78- // TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
79- if (intraDataOffset != 0 )
80- return failure ();
8180 OperandRange maskOperands = createMaskOp.getOperands ();
8281 size_t numMaskOperands = maskOperands.size ();
8382 AffineExpr s0;
@@ -129,26 +128,79 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
129128 return newMask;
130129}
131130
132- static Value extractSubvectorFrom (RewriterBase &rewriter, Location loc,
133- VectorType extractType, Value vector,
134- int64_t frontOffset, int64_t subvecSize) {
131+ // / Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
132+ // / emitting `vector.extract_strided_slice`.
133+ static Value staticallyExtractSubvector (OpBuilder &rewriter, Location loc,
134+ VectorType extractType, Value source,
135+ int64_t frontOffset,
136+ int64_t subvecSize) {
137+ auto vectorType = cast<VectorType>(source.getType ());
138+ assert ((vectorType.getRank () == 1 && extractType.getRank () == 1 ) &&
139+ " expected 1-D source and destination types" );
135140 auto offsets = rewriter.getI64ArrayAttr ({frontOffset});
136141 auto sizes = rewriter.getI64ArrayAttr ({subvecSize});
137142 auto strides = rewriter.getI64ArrayAttr ({1 });
138143 return rewriter
139- .create <vector::ExtractStridedSliceOp>(loc, extractType, vector , offsets,
144+ .create <vector::ExtractStridedSliceOp>(loc, extractType, source , offsets,
140145 sizes, strides)
141146 ->getResult (0 );
142147}
143148
144- static Value insertSubvectorInto (RewriterBase &rewriter, Location loc,
145- Value src, Value dest, int64_t offset) {
149+ // / Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
150+ // / at `offset`. it is a wrapper function for emitting
151+ // / `vector.insert_strided_slice`.
152+ static Value staticallyInsertSubvector (OpBuilder &rewriter, Location loc,
153+ Value src, Value dest, int64_t offset) {
154+ auto srcType = cast<VectorType>(src.getType ());
155+ auto destType = cast<VectorType>(dest.getType ());
156+ assert (srcType.getRank () == 1 && destType.getRank () == 1 &&
157+ " expected source and dest to be vector type" );
146158 auto offsets = rewriter.getI64ArrayAttr ({offset});
147159 auto strides = rewriter.getI64ArrayAttr ({1 });
148160 return rewriter.create <vector::InsertStridedSliceOp>(loc, dest.getType (), src,
149161 dest, offsets, strides);
150162}
151163
164+ // / Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
165+ // / and size `numElementsToExtract`, and inserts into the `dest` vector. This
166+ // / function emits multiple `vector.extract` and `vector.insert` ops, so only
167+ // / use it when `offset` cannot be folded into a constant value.
168+ static Value dynamicallyExtractSubVector (OpBuilder &rewriter, Location loc,
169+ TypedValue<VectorType> source,
170+ Value dest, OpFoldResult offset,
171+ int64_t numElementsToExtract) {
172+ for (int i = 0 ; i < numElementsToExtract; ++i) {
173+ Value extractLoc =
174+ (i == 0 ) ? offset.dyn_cast <Value>()
175+ : rewriter.create <arith::AddIOp>(
176+ loc, rewriter.getIndexType (), offset.dyn_cast <Value>(),
177+ rewriter.create <arith::ConstantIndexOp>(loc, i));
178+ auto extractOp =
179+ rewriter.create <vector::ExtractOp>(loc, source, extractLoc);
180+ dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, i);
181+ }
182+ return dest;
183+ }
184+
185+ // / Returns the op sequence for an emulated sub-byte data type vector load.
186+ // / specifically, use `emulatedElemType` for loading a vector of `origElemType`.
187+ // / The load location is given by `base` and `linearizedIndices`, and the
188+ // / load size is given by `numEmulatedElementsToLoad`.
189+ static TypedValue<VectorType>
190+ emulatedVectorLoad (OpBuilder &rewriter, Location loc, Value base,
191+ OpFoldResult linearizedIndices,
192+ int64_t numEmultedElementsToLoad, Type origElemType,
193+ Type emulatedElemType) {
194+ auto scale = emulatedElemType.getIntOrFloatBitWidth () /
195+ origElemType.getIntOrFloatBitWidth ();
196+ auto newLoad = rewriter.create <vector::LoadOp>(
197+ loc, VectorType::get (numEmultedElementsToLoad, emulatedElemType), base,
198+ getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
199+ return rewriter.create <vector::BitCastOp>(
200+ loc, VectorType::get (numEmultedElementsToLoad * scale, origElemType),
201+ newLoad);
202+ };
203+
152204namespace {
153205
154206// ===----------------------------------------------------------------------===//
@@ -380,25 +432,27 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
380432 ? getConstantIntValue (linearizedInfo.intraDataOffset )
381433 : 0 ;
382434
383- if (!foldedIntraVectorOffset) {
384- // unimplemented case for dynamic intra vector offset
385- return failure ();
386- }
387-
435+ // Always load enough elements which can cover the original elements.
436+ int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or (scale - 1 );
388437 auto numElements =
389- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
390- auto newLoad = rewriter.create <vector::LoadOp>(
391- loc, VectorType::get (numElements, newElementType), adaptor.getBase (),
392- getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
393-
394- Value result = rewriter.create <vector::BitCastOp>(
395- loc, VectorType::get (numElements * scale, oldElementType), newLoad);
396-
397- if (isUnalignedEmulation) {
398- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
399- *foldedIntraVectorOffset, origElements);
438+ llvm::divideCeil (maxintraDataOffset + origElements, scale);
439+ Value result =
440+ emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
441+ numElements, oldElementType, newElementType);
442+
443+ if (foldedIntraVectorOffset) {
444+ if (isUnalignedEmulation) {
445+ result =
446+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
447+ *foldedIntraVectorOffset, origElements);
448+ }
449+ } else {
450+ auto resultVector = rewriter.create <arith::ConstantOp>(
451+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
452+ result = dynamicallyExtractSubVector (
453+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
454+ linearizedInfo.intraDataOffset , origElements);
400455 }
401-
402456 rewriter.replaceOp (op, result);
403457 return success ();
404458 }
@@ -513,8 +567,8 @@ struct ConvertVectorMaskedLoad final
513567 // create an empty vector of the new type
514568 auto emptyVector = rewriter.create <arith::ConstantOp>(
515569 loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
516- passthru = insertSubvectorInto (rewriter, loc, passthru, emptyVector,
517- *foldedIntraVectorOffset);
570+ passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
571+ *foldedIntraVectorOffset);
518572 }
519573 auto newPassThru =
520574 rewriter.create <vector::BitCastOp>(loc, loadType, passthru);
@@ -537,16 +591,17 @@ struct ConvertVectorMaskedLoad final
537591 // TODO: can fold if op's mask is constant
538592 auto emptyVector = rewriter.create <arith::ConstantOp>(
539593 loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
540- mask = insertSubvectorInto (rewriter, loc, op.getMask (), emptyVector,
541- *foldedIntraVectorOffset);
594+ mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyVector,
595+ *foldedIntraVectorOffset);
542596 }
543597
544598 Value result =
545599 rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
546600
547601 if (isUnalignedEmulation) {
548- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
549- *foldedIntraVectorOffset, origElements);
602+ result =
603+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
604+ *foldedIntraVectorOffset, origElements);
550605 }
551606 rewriter.replaceOp (op, result);
552607
@@ -604,13 +659,10 @@ struct ConvertVectorTransferRead final
604659 ? getConstantIntValue (linearizedInfo.intraDataOffset )
605660 : 0 ;
606661
607- if (!foldedIntraVectorOffset) {
608- // unimplemented case for dynamic inra-vector offset
609- return failure ();
610- }
611-
662+ auto maxIntraVectorOffset =
663+ foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1 ;
612664 auto numElements =
613- llvm::divideCeil (*foldedIntraVectorOffset + origElements, scale);
665+ llvm::divideCeil (maxIntraVectorOffset + origElements, scale);
614666
615667 auto newRead = rewriter.create <vector::TransferReadOp>(
616668 loc, VectorType::get (numElements, newElementType), adaptor.getSource (),
@@ -621,9 +673,18 @@ struct ConvertVectorTransferRead final
621673 loc, VectorType::get (numElements * scale, oldElementType), newRead);
622674
623675 Value result = bitCast->getResult (0 );
624- if (isUnalignedEmulation) {
625- result = extractSubvectorFrom (rewriter, loc, op.getType (), result,
626- *foldedIntraVectorOffset, origElements);
676+ if (foldedIntraVectorOffset) {
677+ if (isUnalignedEmulation) {
678+ result =
679+ staticallyExtractSubvector (rewriter, loc, op.getType (), result,
680+ *foldedIntraVectorOffset, origElements);
681+ }
682+ } else {
683+ auto zeros = rewriter.create <arith::ConstantOp>(
684+ loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
685+ result = dynamicallyExtractSubVector (rewriter, loc, bitCast, zeros,
686+ linearizedInfo.intraDataOffset ,
687+ origElements);
627688 }
628689 rewriter.replaceOp (op, result);
629690
0 commit comments