@@ -181,8 +181,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
181181 setOperationAction (ISD::FSINCOS, MVT::f32 , Expand);
182182 setOperationAction (ISD::FPOW, MVT::f32 , Expand);
183183 setOperationAction (ISD::FREM, MVT::f32 , Expand);
184- setOperationAction (ISD::FP16_TO_FP, MVT::f32 , Expand );
185- setOperationAction (ISD::FP_TO_FP16, MVT::f32 , Expand );
184+ setOperationAction (ISD::FP16_TO_FP, MVT::f32 , Custom );
185+ setOperationAction (ISD::FP_TO_FP16, MVT::f32 , Custom );
186186
187187 if (Subtarget.is64Bit ())
188188 setOperationAction (ISD::FRINT, MVT::f32 , Legal);
@@ -219,7 +219,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
219219 setOperationAction (ISD::FPOW, MVT::f64 , Expand);
220220 setOperationAction (ISD::FREM, MVT::f64 , Expand);
221221 setOperationAction (ISD::FP16_TO_FP, MVT::f64 , Expand);
222- setOperationAction (ISD::FP_TO_FP16, MVT::f64 , Expand );
222+ setOperationAction (ISD::FP_TO_FP16, MVT::f64 , Custom );
223223
224224 if (Subtarget.is64Bit ())
225225 setOperationAction (ISD::FRINT, MVT::f64 , Legal);
@@ -427,6 +427,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
427427 return lowerBUILD_VECTOR (Op, DAG);
428428 case ISD::VECTOR_SHUFFLE:
429429 return lowerVECTOR_SHUFFLE (Op, DAG);
430+ case ISD::FP_TO_FP16:
431+ return lowerFP_TO_FP16 (Op, DAG);
432+ case ISD::FP16_TO_FP:
433+ return lowerFP16_TO_FP (Op, DAG);
430434 }
431435 return SDValue ();
432436}
@@ -1354,6 +1358,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
13541358 return SDValue ();
13551359}
13561360
1361+ SDValue LoongArchTargetLowering::lowerFP_TO_FP16 (SDValue Op,
1362+ SelectionDAG &DAG) const {
1363+ // Custom lower to ensure the libcall return is passed in an FPR on hard
1364+ // float ABIs.
1365+ SDLoc DL (Op);
1366+ MakeLibCallOptions CallOptions;
1367+ SDValue Op0 = Op.getOperand (0 );
1368+ SDValue Chain = SDValue ();
1369+ RTLIB::Libcall LC = RTLIB::getFPROUND (Op0.getValueType (), MVT::f16 );
1370+ SDValue Res;
1371+ std::tie (Res, Chain) =
1372+ makeLibCall (DAG, LC, MVT::f32 , Op0, CallOptions, DL, Chain);
1373+ if (Subtarget.is64Bit ())
1374+ return DAG.getNode (LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64 , Res);
1375+ return DAG.getBitcast (MVT::i32 , Res);
1376+ }
1377+
1378+ SDValue LoongArchTargetLowering::lowerFP16_TO_FP (SDValue Op,
1379+ SelectionDAG &DAG) const {
1380+ // Custom lower to ensure the libcall argument is passed in an FPR on hard
1381+ // float ABIs.
1382+ SDLoc DL (Op);
1383+ MakeLibCallOptions CallOptions;
1384+ SDValue Op0 = Op.getOperand (0 );
1385+ SDValue Chain = SDValue ();
1386+ SDValue Arg = Subtarget.is64Bit () ? DAG.getNode (LoongArchISD::MOVGR2FR_W_LA64,
1387+ DL, MVT::f32 , Op0)
1388+ : DAG.getBitcast (MVT::f32 , Op0);
1389+ SDValue Res;
1390+ std::tie (Res, Chain) = makeLibCall (DAG, RTLIB::FPEXT_F16_F32, MVT::f32 , Arg,
1391+ CallOptions, DL, Chain);
1392+ return Res;
1393+ }
1394+
13571395static bool isConstantOrUndef (const SDValue Op) {
13581396 if (Op->isUndef ())
13591397 return true ;
@@ -1656,16 +1694,20 @@ SDValue LoongArchTargetLowering::lowerFP_TO_SINT(SDValue Op,
16561694 SelectionDAG &DAG) const {
16571695
16581696 SDLoc DL (Op);
1697+ SDValue Op0 = Op.getOperand (0 );
1698+
1699+ if (Op0.getValueType () == MVT::f16 )
1700+ Op0 = DAG.getNode (ISD::FP_EXTEND, DL, MVT::f32 , Op0);
16591701
16601702 if (Op.getValueSizeInBits () > 32 && Subtarget.hasBasicF () &&
16611703 !Subtarget.hasBasicD ()) {
16621704 SDValue Dst =
1663- DAG.getNode (LoongArchISD::FTINT, DL, MVT::f32 , Op. getOperand ( 0 ) );
1705+ DAG.getNode (LoongArchISD::FTINT, DL, MVT::f32 , Op0 );
16641706 return DAG.getNode (LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64 , Dst);
16651707 }
16661708
16671709 EVT FPTy = EVT::getFloatingPointVT (Op.getValueSizeInBits ());
1668- SDValue Trunc = DAG.getNode (LoongArchISD::FTINT, DL, FPTy, Op. getOperand ( 0 ) );
1710+ SDValue Trunc = DAG.getNode (LoongArchISD::FTINT, DL, FPTy, Op0 );
16691711 return DAG.getNode (ISD::BITCAST, DL, Op.getValueType (), Trunc);
16701712}
16711713
@@ -2848,6 +2890,10 @@ void LoongArchTargetLowering::ReplaceNodeResults(
28482890 EVT FVT = EVT::getFloatingPointVT (N->getValueSizeInBits (0 ));
28492891 if (getTypeAction (*DAG.getContext (), Src.getValueType ()) !=
28502892 TargetLowering::TypeSoftenFloat) {
2893+ if (!isTypeLegal (Src.getValueType ()))
2894+ return ;
2895+ if (Src.getValueType () == MVT::f16 )
2896+ Src = DAG.getNode (ISD::FP_EXTEND, DL, MVT::f32 , Src);
28512897 SDValue Dst = DAG.getNode (LoongArchISD::FTINT, DL, FVT, Src);
28522898 Results.push_back (DAG.getNode (ISD::BITCAST, DL, VT, Dst));
28532899 return ;
@@ -4229,6 +4275,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
42294275 return SDValue ();
42304276}
42314277
4278+ static SDValue performMOVGR2FR_WCombine (SDNode *N, SelectionDAG &DAG,
4279+ TargetLowering::DAGCombinerInfo &DCI,
4280+ const LoongArchSubtarget &Subtarget) {
4281+ // If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 the the
4282+ // conversion is unnecessary and can be replaced with the
4283+ // MOVFR2GR_S_LA64 operand.
4284+ SDValue Op0 = N->getOperand (0 );
4285+ if (Op0.getOpcode () == LoongArchISD::MOVFR2GR_S_LA64)
4286+ return Op0.getOperand (0 );
4287+ return SDValue ();
4288+ }
4289+
4290+ static SDValue performMOVFR2GR_SCombine (SDNode *N, SelectionDAG &DAG,
4291+ TargetLowering::DAGCombinerInfo &DCI,
4292+ const LoongArchSubtarget &Subtarget) {
4293+ // If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
4294+ // conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
4295+ // operand.
4296+ SDValue Op0 = N->getOperand (0 );
4297+ MVT VT = N->getSimpleValueType (0 );
4298+ if (Op0->getOpcode () == LoongArchISD::MOVGR2FR_W_LA64) {
4299+ assert (Op0.getOperand (0 ).getValueType () == VT && " Unexpected value type!" );
4300+ return Op0.getOperand (0 );
4301+ }
4302+ return SDValue ();
4303+ }
4304+
42324305SDValue LoongArchTargetLowering::PerformDAGCombine (SDNode *N,
42334306 DAGCombinerInfo &DCI) const {
42344307 SelectionDAG &DAG = DCI.DAG ;
@@ -4247,6 +4320,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
42474320 return performBITREV_WCombine (N, DAG, DCI, Subtarget);
42484321 case ISD::INTRINSIC_WO_CHAIN:
42494322 return performINTRINSIC_WO_CHAINCombine (N, DAG, DCI, Subtarget);
4323+ case LoongArchISD::MOVGR2FR_W_LA64:
4324+ return performMOVGR2FR_WCombine (N, DAG, DCI, Subtarget);
4325+ case LoongArchISD::MOVFR2GR_S_LA64:
4326+ return performMOVFR2GR_SCombine (N, DAG, DCI, Subtarget);
42504327 }
42514328 return SDValue ();
42524329}
@@ -6260,3 +6337,61 @@ bool LoongArchTargetLowering::shouldAlignPointerArgs(CallInst *CI,
62606337
62616338 return true ;
62626339}
6340+
6341+ bool LoongArchTargetLowering::splitValueIntoRegisterParts (
6342+ SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
6343+ unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
6344+ bool IsABIRegCopy = CC.has_value ();
6345+ EVT ValueVT = Val.getValueType ();
6346+
6347+ if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32 ) {
6348+ // Cast the f16 to i16, extend to i32, pad with ones to make a float
6349+ // nan, and cast to f32.
6350+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::i16 , Val);
6351+ Val = DAG.getNode (ISD::ANY_EXTEND, DL, MVT::i32 , Val);
6352+ Val = DAG.getNode (ISD::OR, DL, MVT::i32 , Val,
6353+ DAG.getConstant (0xFFFF0000 , DL, MVT::i32 ));
6354+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::f32 , Val);
6355+ Parts[0 ] = Val;
6356+ return true ;
6357+ }
6358+
6359+ return false ;
6360+ }
6361+
6362+ SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue (
6363+ SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
6364+ MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
6365+ bool IsABIRegCopy = CC.has_value ();
6366+
6367+ if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32 ) {
6368+ SDValue Val = Parts[0 ];
6369+
6370+ // Cast the f32 to i32, truncate to i16, and cast back to f16.
6371+ Val = DAG.getNode (ISD::BITCAST, DL, MVT::i32 , Val);
6372+ Val = DAG.getNode (ISD::TRUNCATE, DL, MVT::i16 , Val);
6373+ Val = DAG.getNode (ISD::BITCAST, DL, ValueVT, Val);
6374+ return Val;
6375+ }
6376+
6377+ return SDValue ();
6378+ }
6379+
6380+ MVT LoongArchTargetLowering::getRegisterTypeForCallingConv (LLVMContext &Context,
6381+ CallingConv::ID CC,
6382+ EVT VT) const {
6383+ // Use f32 to pass f16.
6384+ if (VT == MVT::f16 && Subtarget.hasBasicF ())
6385+ return MVT::f32 ;
6386+
6387+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC, VT);
6388+ }
6389+
6390+ unsigned LoongArchTargetLowering::getNumRegistersForCallingConv (
6391+ LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
6392+ // Use f32 to pass f16.
6393+ if (VT == MVT::f16 && Subtarget.hasBasicF ())
6394+ return 1 ;
6395+
6396+ return TargetLowering::getNumRegistersForCallingConv (Context, CC, VT);
6397+ }
0 commit comments