@@ -1523,6 +1523,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15231523 setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
15241524 setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
15251525 setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1526+ setOperationAction(ISD::OR, VT, Custom);
15261527
15271528 setOperationAction(ISD::SELECT_CC, VT, Expand);
15281529 setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
@@ -13808,8 +13809,128 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
1380813809 return ResultSLI;
1380913810}
1381013811
13812+ /// Try to lower the construction of a pointer alias mask to a WHILEWR.
13813+ /// The mask's enabled lanes represent the elements that will not overlap across
13814+ /// one loop iteration. This tries to match:
13815+ /// or (splat (setcc_lt (sub ptrA, ptrB), -(element_size - 1))),
13816+ /// (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size))
13817+ SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG,
13818+ const AArch64Subtarget &Subtarget) {
13819+ if (!Subtarget.hasSVE2())
13820+ return SDValue();
13821+ SDValue LaneMask = Op.getOperand(0);
13822+ SDValue Splat = Op.getOperand(1);
13823+
13824+ if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
13825+ std::swap(LaneMask, Splat);
13826+
13827+ if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13828+ LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask ||
13829+ Splat.getOpcode() != ISD::SPLAT_VECTOR)
13830+ return SDValue();
13831+
13832+ SDValue Cmp = Splat.getOperand(0);
13833+ if (Cmp.getOpcode() != ISD::SETCC)
13834+ return SDValue();
13835+
13836+ CondCodeSDNode *Cond = cast<CondCodeSDNode>(Cmp.getOperand(2));
13837+
13838+ auto ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1));
13839+ if (!ComparatorConst || ComparatorConst->getSExtValue() > 0 ||
13840+ Cond->get() != ISD::CondCode::SETLT)
13841+ return SDValue();
13842+ unsigned CompValue = std::abs(ComparatorConst->getSExtValue());
13843+ unsigned EltSize = CompValue + 1;
13844+ if (!isPowerOf2_64(EltSize) || EltSize > 8)
13845+ return SDValue();
13846+
13847+ SDValue Diff = Cmp.getOperand(0);
13848+ if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64)
13849+ return SDValue();
13850+
13851+ if (!isNullConstant(LaneMask.getOperand(1)) ||
13852+ (EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA))
13853+ return SDValue();
13854+
13855+ // The number of elements that alias is calculated by dividing the positive
13856+ // difference between the pointers by the element size. An alias mask for i8
13857+ // elements omits the division because it would just divide by 1
13858+ if (EltSize > 1) {
13859+ SDValue DiffDiv = LaneMask.getOperand(2);
13860+ auto DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1));
13861+ if (!DiffDivConst || DiffDivConst->getZExtValue() != Log2_64(EltSize))
13862+ return SDValue();
13863+ if (EltSize > 2) {
13864+ // When masking i32 or i64 elements, the positive value of the
13865+ // possibly-negative difference comes from a select of the difference if
13866+ // it's positive, otherwise the difference plus the element size if it's
13867+ // negative: pos_diff = diff < 0 ? (diff + 7) : diff
13868+ SDValue Select = DiffDiv.getOperand(0);
13869+ // Make sure the difference is being compared by the select
13870+ if (Select.getOpcode() != ISD::SELECT_CC || Select.getOperand(3) != Diff)
13871+ return SDValue();
13872+ // Make sure it's checking if the difference is less than 0
13873+ if (!isNullConstant(Select.getOperand(1)) ||
13874+ cast<CondCodeSDNode>(Select.getOperand(4))->get() !=
13875+ ISD::CondCode::SETLT)
13876+ return SDValue();
13877+ // An add creates a positive value from the negative difference
13878+ SDValue Add = Select.getOperand(2);
13879+ if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
13880+ return SDValue();
13881+ if (auto *AddConst = dyn_cast<ConstantSDNode>(Add.getOperand(1));
13882+ !AddConst || AddConst->getZExtValue() != EltSize - 1)
13883+ return SDValue();
13884+ } else {
13885+ // When masking i16 elements, this positive value comes from adding the
13886+ // difference's sign bit to the difference itself. This is equivalent to
13887+ // the 32 bit and 64 bit case: pos_diff = diff + sign_bit (diff)
13888+ SDValue Add = DiffDiv.getOperand(0);
13889+ if (Add.getOpcode() != ISD::ADD || Add.getOperand(0) != Diff)
13890+ return SDValue();
13891+ // A logical right shift by 63 extracts the sign bit from the difference
13892+ SDValue Shift = Add.getOperand(1);
13893+ if (Shift.getOpcode() != ISD::SRL || Shift.getOperand(0) != Diff)
13894+ return SDValue();
13895+ if (auto *ShiftConst = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
13896+ !ShiftConst || ShiftConst->getZExtValue() != 63)
13897+ return SDValue();
13898+ }
13899+ } else if (LaneMask.getOperand(2) != Diff)
13900+ return SDValue();
13901+
13902+ SDValue StorePtr = Diff.getOperand(0);
13903+ SDValue ReadPtr = Diff.getOperand(1);
13904+
13905+ unsigned IntrinsicID = 0;
13906+ switch (EltSize) {
13907+ case 1:
13908+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_b;
13909+ break;
13910+ case 2:
13911+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_h;
13912+ break;
13913+ case 4:
13914+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_s;
13915+ break;
13916+ case 8:
13917+ IntrinsicID = Intrinsic::aarch64_sve_whilewr_d;
13918+ break;
13919+ default:
13920+ return SDValue();
13921+ }
13922+ SDLoc DL(Op);
13923+ SDValue ID = DAG.getConstant(IntrinsicID, DL, MVT::i32);
13924+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), ID,
13925+ StorePtr, ReadPtr);
13926+ }
13927+
1381113928SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
1381213929 SelectionDAG &DAG) const {
13930+ if (SDValue SV =
13931+ tryWhileWRFromOR(Op, DAG, DAG.getSubtarget<AArch64Subtarget>()))
13932+ return SV;
13933+
1381313934 if (useSVEForFixedLengthVectorVT(Op.getValueType(),
1381413935 !Subtarget->isNeonAvailable()))
1381513936 return LowerToScalableOp(Op, DAG);
0 commit comments