diff --git a/llvm/docs/AArch64SME.rst b/llvm/docs/AArch64SME.rst index b5a01cb204b81..ac8ce32ddb9e6 100644 --- a/llvm/docs/AArch64SME.rst +++ b/llvm/docs/AArch64SME.rst @@ -213,12 +213,14 @@ Instruction Selection Nodes .. code-block:: none - AArch64ISD::SMSTART Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask] - AArch64ISD::SMSTOP Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask] - -The ``SMSTART/SMSTOP`` nodes take ``CurrentState`` and ``ExpectedState`` operand for -the case of a conditional SMSTART/SMSTOP. The instruction will only be executed -if CurrentState != ExpectedState. + AArch64ISD::SMSTART Chain, [SM|ZA|Both][, RegMask] + AArch64ISD::SMSTOP Chain, [SM|ZA|Both][, RegMask] + AArch64ISD::COND_SMSTART Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask] + AArch64ISD::COND_SMSTOP Chain, [SM|ZA|Both], CurrentState, ExpectedState[, RegMask] + +The ``COND_SMSTART/COND_SMSTOP`` nodes additionally take ``CurrentState`` and +``ExpectedState`` operands; in this case the instruction will only be executed if +``CurrentState != ExpectedState``. When ``CurrentState`` and ``ExpectedState`` can be evaluated at compile-time (i.e. 
they are both constants) then an unconditional ``smstart/smstop`` diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 293292d47dd48..d1000dd64bdf7 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2726,6 +2726,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::VG_RESTORE) MAKE_CASE(AArch64ISD::SMSTART) MAKE_CASE(AArch64ISD::SMSTOP) + MAKE_CASE(AArch64ISD::COND_SMSTART) + MAKE_CASE(AArch64ISD::COND_SMSTOP) MAKE_CASE(AArch64ISD::RESTORE_ZA) MAKE_CASE(AArch64ISD::RESTORE_ZT) MAKE_CASE(AArch64ISD::SAVE_ZT) @@ -6033,14 +6035,12 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op, return DAG.getNode( AArch64ISD::SMSTART, DL, MVT::Other, Op->getOperand(0), // Chain - DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32), - DAG.getConstant(AArch64SME::Always, DL, MVT::i64)); + DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32)); case Intrinsic::aarch64_sme_za_disable: return DAG.getNode( AArch64ISD::SMSTOP, DL, MVT::Other, Op->getOperand(0), // Chain - DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32), - DAG.getConstant(AArch64SME::Always, DL, MVT::i64)); + DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32)); } } @@ -8913,18 +8913,22 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL, SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask()); SDValue MSROp = DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32); - SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64); - SmallVector Ops = {Chain, MSROp, ConditionOp}; + SmallVector Ops = {Chain, MSROp}; + unsigned Opcode; if (Condition != AArch64SME::Always) { + SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64); + Opcode = Enable ? 
AArch64ISD::COND_SMSTART : AArch64ISD::COND_SMSTOP; assert(PStateSM && "PStateSM should be defined"); + Ops.push_back(ConditionOp); Ops.push_back(PStateSM); + } else { + Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP; } Ops.push_back(RegMask); if (InGlue) Ops.push_back(InGlue); - unsigned Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP; return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops); } @@ -9189,9 +9193,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (DisableZA) Chain = DAG.getNode( - AArch64ISD::SMSTOP, DL, MVT::Other, Chain, - DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32), - DAG.getConstant(AArch64SME::Always, DL, MVT::i64)); + AArch64ISD::SMSTOP, DL, DAG.getVTList(MVT::Other, MVT::Glue), Chain, + DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32)); // Adjust the stack pointer for the new arguments... // These operations are automatically eliminated by the prolog/epilog pass @@ -9668,9 +9671,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (CallAttrs.requiresEnablingZAAfterCall()) // Unconditionally resume ZA. 
Result = DAG.getNode( - AArch64ISD::SMSTART, DL, MVT::Other, Result, - DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32), - DAG.getConstant(AArch64SME::Always, DL, MVT::i64)); + AArch64ISD::SMSTART, DL, DAG.getVTList(MVT::Other, MVT::Glue), Result, + DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32)); if (ShouldPreserveZT0) Result = diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index c1e6d70099fa5..59a9d7d179778 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -73,6 +73,8 @@ enum NodeType : unsigned { SMSTART, SMSTOP, + COND_SMSTART, + COND_SMSTOP, RESTORE_ZA, RESTORE_ZT, SAVE_ZT, diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td index 363ecee49c0f2..e7482da001074 100644 --- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td @@ -10,12 +10,20 @@ // //===----------------------------------------------------------------------===// -def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 2, - [SDTCisInt<0>, SDTCisInt<0>]>, +def AArch64_smstart : SDNode<"AArch64ISD::SMSTART", SDTypeProfile<0, 1, + [SDTCisInt<0>]>, [SDNPHasChain, SDNPSideEffect, SDNPVariadic, SDNPOptInGlue, SDNPOutGlue]>; -def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 2, - [SDTCisInt<0>, SDTCisInt<0>]>, +def AArch64_smstop : SDNode<"AArch64ISD::SMSTOP", SDTypeProfile<0, 1, + [SDTCisInt<0>]>, + [SDNPHasChain, SDNPSideEffect, SDNPVariadic, + SDNPOptInGlue, SDNPOutGlue]>; +def AArch64_cond_smstart : SDNode<"AArch64ISD::COND_SMSTART", SDTypeProfile<0, 3, + [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>, + [SDNPHasChain, SDNPSideEffect, SDNPVariadic, + SDNPOptInGlue, SDNPOutGlue]>; +def AArch64_cond_smstop : SDNode<"AArch64ISD::COND_SMSTOP", SDTypeProfile<0, 3, + [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>, 
[SDNPHasChain, SDNPSideEffect, SDNPVariadic, SDNPOptInGlue, SDNPOutGlue]>; def AArch64_restore_za : SDNode<"AArch64ISD::RESTORE_ZA", SDTypeProfile<0, 3, @@ -305,15 +313,15 @@ def MSRpstatePseudo : let Defs = [VG]; } -def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition)), - (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition)>; -def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition)), - (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition)>; +def : Pat<(AArch64_cond_smstart (i32 svcr_op:$pstate), (i64 timm0_31:$condition), (i64 GPR64:$pstatesm)), + (MSRpstatePseudo svcr_op:$pstate, 0b1, timm0_31:$condition, GPR64:$pstatesm)>; +def : Pat<(AArch64_cond_smstop (i32 svcr_op:$pstate), (i64 timm0_31:$condition), (i64 GPR64:$pstatesm)), + (MSRpstatePseudo svcr_op:$pstate, 0b0, timm0_31:$condition, GPR64:$pstatesm)>; // Unconditional start/stop -def : Pat<(AArch64_smstart (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)), +def : Pat<(AArch64_smstart (i32 svcr_op:$pstate)), (MSRpstatesvcrImm1 svcr_op:$pstate, 0b1)>; -def : Pat<(AArch64_smstop (i32 svcr_op:$pstate), (i64 /*AArch64SME::Always*/0)), +def : Pat<(AArch64_smstop (i32 svcr_op:$pstate)), (MSRpstatesvcrImm1 svcr_op:$pstate, 0b0)>;