Skip to content
33 changes: 25 additions & 8 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7915,6 +7915,26 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
(!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
}

// Check that all partial reductions in a chain are only used by other
// partial reductions with the same scale factor. Otherwise we end up creating
// users of scaled reductions where the types of the other operands don't
// match.
for (const auto &[Chain, Scale] : PartialReductionChains) {
auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
auto *UI = cast<Instruction>(U);
if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
return all_of(UI->users(), [ScaleVal, this](const User *U) {
auto *UI = cast<Instruction>(U);
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
});
}
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
!OrigLoop->contains(UI->getParent());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

|| !OrigLoop->contains(UI->getParent())

Is this part of the condition covered by a test-case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this is the cover the exit-user of the reduction chain.

};
if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx))
ScaledReductionMap.erase(Chain.Reduction);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be done in ExtendIsOnlyUsedByPartialReductions, rather than a loop that removes these? (from what I can see, all the information to make this decision is available in PartialReductionChains)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use the information from PartialReductionChains, but this would include entries that will get rejected later, by the ExtendIsOnlyUsedByPartialReductions. I kept it as 2 separate loops for now.

}
}

bool VPRecipeBuilder::getScaledReductions(
Expand Down Expand Up @@ -8098,11 +8118,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
return tryToWidenMemory(Instr, Operands, Range);

if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr)) {
if (auto PartialRed =
tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value()))
return PartialRed;
}
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr))
return tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value());

if (!shouldWiden(Instr, Range))
return nullptr;
Expand Down Expand Up @@ -8136,9 +8153,9 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
isa<VPPartialReductionRecipe>(BinOpRecipe))
std::swap(BinOp, Accumulator);

if (ScaleFactor !=
vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()))
return nullptr;
assert(ScaleFactor ==
vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()) &&
"all accumulators in chain must have same scale factor");

unsigned ReductionOpcode = Reduction->getOpcode();
if (ReductionOpcode == Instruction::Sub) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,28 @@ loop:
exit:
ret i32 %red.next
}

define i16 @test_incomplete_chain_without_mul(ptr noalias %dst, ptr %A, ptr %B) #0 {
entry:
br label %loop

loop:
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
%red = phi i16 [ 0, %entry ], [ %red.next, %loop ]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the two scale factors in this test?

%l.a = load i8, ptr %A, align 1
%a.ext = zext i8 %l.a to i16
store i16 %a.ext, ptr %dst, align 2
%l.b = load i8, ptr %B, align 1
%b.ext = zext i8 %l.b to i16
%add = add i16 %red, %b.ext
%add.1 = add i16 %add, %a.ext
%red.next = add i16 %add.1, %b.ext
%iv.next = add i64 %iv, 1
%ec = icmp ult i64 %iv, 1024
br i1 %ec, label %loop, label %exit

exit:
ret i16 %red.next
}

attributes #0 = { "target-cpu"="grace" }