Skip to content
63 changes: 52 additions & 11 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3648,6 +3648,37 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
Sub = VecOp->getDefiningRecipe();
VecOp = Tmp;
}

// If ValB is a constant and can be safely extended, truncate it to the same
// type as ExtA's operand, then extend it to the same type as ExtA. This
// creates two uniform extends that can more easily be matched by the rest of
// the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
// replaced with the new extend of the constant.
auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
VPWidenCastRecipe *&ExtB,
VPValue *&ValB, VPWidenRecipe *Mul) {
if (!ExtA || ExtB || !ValB->isLiveIn())
return;
Type *NarrowTy = Ctx.Types.inferScalarType(ExtA->getOperand(0));
Instruction::CastOps ExtOpc = ExtA->getOpcode();
const APInt *Const;
if (!match(ValB, m_APInt(Const)) ||
!llvm::canConstantBeExtended(
Const, NarrowTy, TTI::getPartialReductionExtendKind(ExtOpc)))
return;
// The truncate ensures that the type of each extended operand is the
// same, and it's been proven that the constant can be extended from
// NarrowTy safely. Necessary since ExtA's extended operand would be
// e.g. an i8, while the const will likely be an i32. This will be
// elided by later optimisations.
VPBuilder Builder(Mul);
auto *Trunc =
Builder.createWidenCast(Instruction::CastOps::Trunc, ValB, NarrowTy);
Type *WideTy = Ctx.Types.inferScalarType(ExtA);
ValB = ExtB = Builder.createWidenCast(ExtOpc, Trunc, WideTy);
Mul->setOperand(1, ExtB);
};

// Try to match reduce.add(mul(...)).
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
auto *RecipeA =
Expand All @@ -3656,6 +3687,9 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());

// Convert reduce.add(mul(ext, const)) to reduce.add(mul(ext, ext(const)))
ExtendAndReplaceConstantOp(RecipeA, RecipeB, B, Mul);

// Match reduce.add/sub(mul(ext, ext)).
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
Expand All @@ -3665,7 +3699,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
cast<VPWidenRecipe>(Sub), Red);
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
}
// Match reduce.add(mul).
// TODO: Add an expression type for this variant with a negated mul
if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
return new VPExpressionRecipe(Mul, Red);
Expand All @@ -3674,18 +3707,26 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
// variants.
if (Sub)
return nullptr;
// Match reduce.add(ext(mul(ext(A), ext(B)))).
// All extend recipes must have same opcode or A == B
// which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
m_ZExtOrSExt(m_VPValue()))))) {

// Match reduce.add(ext(mul(A, B))).
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
auto *Ext0 =
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
auto *Ext1 =
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());

// reduce.add(ext(mul(ext, const)))
// -> reduce.add(ext(mul(ext, ext(const))))
ExtendAndReplaceConstantOp(Ext0, Ext1, B, Mul);

// reduce.add(ext(mul(ext(A), ext(B))))
// -> reduce.add(mul(wider_ext(A), wider_ext(B)))
// The inner extends must either have the same opcode as the outer extend or
// be the same, in which case the multiply can never result in a negative
// value and the outer extend can be folded away by doing wider
// extends for the operands of the mul.
if (Ext0 && Ext1 &&
(Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
Ext0->getOpcode() == Ext1->getOpcode() &&
IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) {
auto *NewExt0 = new VPWidenCastRecipe(
Expand Down
82 changes: 82 additions & 0 deletions llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2800,6 +2800,88 @@ exit:
ret i64 %r.0.lcssa
}

define i32 @reduction_expression_ext_mulacc_livein(ptr %a, i16 %c) {
; CHECK-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
; CHECK-SAME: ptr [[A:%.*]], i16 [[C:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: br label %[[VECTOR_PH:.*]]
; CHECK: [[VECTOR_PH]]:
; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK: [[VECTOR_BODY]]:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP5:%.*]], %[[VECTOR_BODY]] ]
; CHECK-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
; CHECK-NEXT: [[TMP2:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = zext <4 x i16> [[TMP2]] to <4 x i32>
; CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
; CHECK-NEXT: [[TMP5]] = add i32 [[VEC_PHI]], [[TMP4]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-NEXT: br i1 [[TMP6]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
; CHECK: [[MIDDLE_BLOCK]]:
; CHECK-NEXT: br label %[[FOR_EXIT:.*]]
; CHECK: [[FOR_EXIT]]:
; CHECK-NEXT: ret i32 [[TMP5]]
;
; CHECK-INTERLEAVED-LABEL: define i32 @reduction_expression_ext_mulacc_livein(
; CHECK-INTERLEAVED-SAME: ptr [[A:%.*]], i16 [[C:%.*]]) {
; CHECK-INTERLEAVED-NEXT: [[ENTRY:.*:]]
; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_PH:.*]]
; CHECK-INTERLEAVED: [[VECTOR_PH]]:
; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i16> poison, i16 [[C]], i64 0
; CHECK-INTERLEAVED-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16> [[BROADCAST_SPLATINSERT]], <4 x i16> poison, <4 x i32> zeroinitializer
; CHECK-INTERLEAVED-NEXT: br label %[[VECTOR_BODY:.*]]
; CHECK-INTERLEAVED: [[VECTOR_BODY]]:
; CHECK-INTERLEAVED-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
; CHECK-INTERLEAVED-NEXT: [[VEC_PHI:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP8:%.*]], %[[VECTOR_BODY]] ]
; CHECK-INTERLEAVED-NEXT: [[VEC_PHI1:%.*]] = phi i32 [ 0, %[[VECTOR_PH]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
; CHECK-INTERLEAVED-NEXT: [[TMP0:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
; CHECK-INTERLEAVED-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[TMP0]], i32 4
; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP0]], align 1
; CHECK-INTERLEAVED-NEXT: [[WIDE_LOAD2:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
; CHECK-INTERLEAVED-NEXT: [[TMP2:%.*]] = zext <4 x i8> [[WIDE_LOAD]] to <4 x i16>
; CHECK-INTERLEAVED-NEXT: [[TMP3:%.*]] = zext <4 x i8> [[WIDE_LOAD2]] to <4 x i16>
; CHECK-INTERLEAVED-NEXT: [[TMP4:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP2]]
; CHECK-INTERLEAVED-NEXT: [[TMP5:%.*]] = mul <4 x i16> [[BROADCAST_SPLAT]], [[TMP3]]
; CHECK-INTERLEAVED-NEXT: [[TMP6:%.*]] = zext <4 x i16> [[TMP4]] to <4 x i32>
; CHECK-INTERLEAVED-NEXT: [[TMP7:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP6]])
; CHECK-INTERLEAVED-NEXT: [[TMP8]] = add i32 [[VEC_PHI]], [[TMP7]]
; CHECK-INTERLEAVED-NEXT: [[TMP9:%.*]] = zext <4 x i16> [[TMP5]] to <4 x i32>
; CHECK-INTERLEAVED-NEXT: [[TMP10:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP9]])
; CHECK-INTERLEAVED-NEXT: [[TMP11]] = add i32 [[VEC_PHI1]], [[TMP10]]
; CHECK-INTERLEAVED-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
; CHECK-INTERLEAVED-NEXT: [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
; CHECK-INTERLEAVED-NEXT: br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP32:![0-9]+]]
; CHECK-INTERLEAVED: [[MIDDLE_BLOCK]]:
; CHECK-INTERLEAVED-NEXT: [[BIN_RDX:%.*]] = add i32 [[TMP11]], [[TMP8]]
; CHECK-INTERLEAVED-NEXT: br label %[[FOR_EXIT:.*]]
; CHECK-INTERLEAVED: [[FOR_EXIT]]:
; CHECK-INTERLEAVED-NEXT: ret i32 [[BIN_RDX]]
;
entry:
br label %for.body

for.body: ; preds = %for.body, %entry
%iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
%accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
%gep.a = getelementptr i8, ptr %a, i64 %iv
%load.a = load i8, ptr %gep.a, align 1
%ext.a = zext i8 %load.a to i16
%mul = mul i16 %c, %ext.a
%mul.ext = zext i16 %mul to i32
%add = add i32 %mul.ext, %accum
%iv.next = add i64 %iv, 1
%exitcond.not = icmp eq i64 %iv.next, 1024
br i1 %exitcond.not, label %for.exit, label %for.body

for.exit: ; preds = %for.body
ret i32 %add
}

declare float @llvm.fmuladd.f32(float, float, float)

!6 = distinct !{!6, !7, !8}
Expand Down
Loading
Loading