Skip to content

Commit 8c8ff30

Browse files
committed
[LoongArch] Pass 'half' in the lower 16 bits of an f32 value when F extension is enabled
LoongArch currently lacks a hardware extension for the fp16 data type, and the ABI documentation does not explicitly define how to handle fp16. Future revisions of the LoongArch specification will include conventions to address fp16 requirements. Previously, we maintained the 'half' type in its 16-bit format between operations. Regardless of whether the F extension is enabled, the value would be passed in the lower 16 bits of a GPR in its 'half' format. With this patch, depending on the ABI in use, the value will be passed either in an FPR or a GPR in 'half' format. This ensures consistency with the bit location used when the fp16 hardware extension is enabled.
1 parent b5cdb03 commit 8c8ff30

File tree

4 files changed

+1214
-71
lines changed

4 files changed

+1214
-71
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

+140-5
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
181181
setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
182182
setOperationAction(ISD::FPOW, MVT::f32, Expand);
183183
setOperationAction(ISD::FREM, MVT::f32, Expand);
184-
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);
185-
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
184+
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
185+
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);
186186

187187
if (Subtarget.is64Bit())
188188
setOperationAction(ISD::FRINT, MVT::f32, Legal);
@@ -219,7 +219,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
219219
setOperationAction(ISD::FPOW, MVT::f64, Expand);
220220
setOperationAction(ISD::FREM, MVT::f64, Expand);
221221
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
222-
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
222+
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
223223

224224
if (Subtarget.is64Bit())
225225
setOperationAction(ISD::FRINT, MVT::f64, Legal);
@@ -427,6 +427,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
427427
return lowerBUILD_VECTOR(Op, DAG);
428428
case ISD::VECTOR_SHUFFLE:
429429
return lowerVECTOR_SHUFFLE(Op, DAG);
430+
case ISD::FP_TO_FP16:
431+
return lowerFP_TO_FP16(Op, DAG);
432+
case ISD::FP16_TO_FP:
433+
return lowerFP16_TO_FP(Op, DAG);
430434
}
431435
return SDValue();
432436
}
@@ -1354,6 +1358,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
13541358
return SDValue();
13551359
}
13561360

1361+
/// Lower ISD::FP_TO_FP16 through the FPROUND-to-f16 runtime library call.
///
/// The libcall's return type is given as f32 (not an integer type) so that,
/// on hard-float ABIs, the rounded half value comes back in an FPR. The f32
/// result is then converted to an integer value: MOVFR2GR_S_LA64 producing
/// i64 on 64-bit subtargets, a bitcast to i32 otherwise.
SDValue LoongArchTargetLowering::lowerFP_TO_FP16(SDValue Op,
                                                 SelectionDAG &DAG) const {
  SDLoc DL(Op);
  SDValue Src = Op.getOperand(0);

  // Pick the rounding routine from the operand's floating-point type.
  const RTLIB::Libcall LC = RTLIB::getFPROUND(Src.getValueType(), MVT::f16);
  MakeLibCallOptions CallOptions;
  // Only the call's result is needed; the returned chain is discarded.
  SDValue Rounded =
      makeLibCall(DAG, LC, MVT::f32, Src, CallOptions, DL, SDValue()).first;

  if (!Subtarget.is64Bit())
    return DAG.getBitcast(MVT::i32, Rounded);
  return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Rounded);
}
1377+
1378+
/// Lower ISD::FP16_TO_FP through the FPEXT_F16_F32 runtime library call.
///
/// The integer operand holding the half bits is first turned into an f32
/// value (MOVGR2FR_W_LA64 on 64-bit subtargets, a bitcast otherwise) so that,
/// on hard-float ABIs, the libcall argument is passed in an FPR. The f32
/// result of the call is returned directly.
SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
                                                 SelectionDAG &DAG) const {
  SDLoc DL(Op);
  SDValue HalfBits = Op.getOperand(0);

  SDValue FPArg;
  if (Subtarget.is64Bit())
    FPArg = DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64, DL, MVT::f32, HalfBits);
  else
    FPArg = DAG.getBitcast(MVT::f32, HalfBits);

  MakeLibCallOptions CallOptions;
  // Only the call's result is needed; the returned chain is discarded.
  return makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, FPArg, CallOptions,
                     DL, SDValue())
      .first;
}
1394+
13571395
static bool isConstantOrUndef(const SDValue Op) {
13581396
if (Op->isUndef())
13591397
return true;
@@ -1656,16 +1694,20 @@ SDValue LoongArchTargetLowering::lowerFP_TO_SINT(SDValue Op,
16561694
SelectionDAG &DAG) const {
16571695

16581696
SDLoc DL(Op);
1697+
SDValue Op0 = Op.getOperand(0);
1698+
1699+
if (Op0.getValueType() == MVT::f16)
1700+
Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);
16591701

16601702
if (Op.getValueSizeInBits() > 32 && Subtarget.hasBasicF() &&
16611703
!Subtarget.hasBasicD()) {
16621704
SDValue Dst =
1663-
DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op.getOperand(0));
1705+
DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op0);
16641706
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Dst);
16651707
}
16661708

16671709
EVT FPTy = EVT::getFloatingPointVT(Op.getValueSizeInBits());
1668-
SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op.getOperand(0));
1710+
SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op0);
16691711
return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Trunc);
16701712
}
16711713

@@ -2848,6 +2890,10 @@ void LoongArchTargetLowering::ReplaceNodeResults(
28482890
EVT FVT = EVT::getFloatingPointVT(N->getValueSizeInBits(0));
28492891
if (getTypeAction(*DAG.getContext(), Src.getValueType()) !=
28502892
TargetLowering::TypeSoftenFloat) {
2893+
if (!isTypeLegal(Src.getValueType()))
2894+
return;
2895+
if (Src.getValueType() == MVT::f16)
2896+
Src = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
28512897
SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, FVT, Src);
28522898
Results.push_back(DAG.getNode(ISD::BITCAST, DL, VT, Dst));
28532899
return;
@@ -4229,6 +4275,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
42294275
return SDValue();
42304276
}
42314277

4278+
/// Fold MOVGR2FR_W_LA64(MOVFR2GR_S_LA64(X)) to X.
///
/// If the input to MOVGR2FR_W_LA64 is just a MOVFR2GR_S_LA64, then the pair
/// of conversions is unnecessary and the node can be replaced with the
/// MOVFR2GR_S_LA64 operand. Returns an empty SDValue when the pattern does
/// not match. DAG, DCI and Subtarget are not used; they are accepted so the
/// signature matches the other perform*Combine helpers.
static SDValue performMOVGR2FR_WCombine(SDNode *N, SelectionDAG &DAG,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const LoongArchSubtarget &Subtarget) {
  SDValue Op0 = N->getOperand(0);
  if (Op0.getOpcode() != LoongArchISD::MOVFR2GR_S_LA64)
    return SDValue();

  // The replacement value must have the type of the node it replaces; check
  // it here in the same way performMOVFR2GR_SCombine does.
  assert(Op0.getOperand(0).getValueType() == N->getValueType(0) &&
         "Unexpected value type!");
  return Op0.getOperand(0);
}
4289+
4290+
/// Fold MOVFR2GR_S_LA64(MOVGR2FR_W_LA64(X)) to X.
///
/// When the input of MOVFR2GR_S_LA64 is itself a MOVGR2FR_W_LA64, the two
/// conversions are unnecessary and the node is replaced with the inner
/// node's operand. Returns an empty SDValue when the pattern does not match.
static SDValue performMOVFR2GR_SCombine(SDNode *N, SelectionDAG &DAG,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const LoongArchSubtarget &Subtarget) {
  const MVT ResultVT = N->getSimpleValueType(0);
  SDValue Inner = N->getOperand(0);
  if (Inner.getOpcode() != LoongArchISD::MOVGR2FR_W_LA64)
    return SDValue();

  SDValue Source = Inner.getOperand(0);
  // The value substituted for N must already have N's result type.
  assert(Source.getValueType() == ResultVT && "Unexpected value type!");
  return Source;
}
4304+
42324305
SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
42334306
DAGCombinerInfo &DCI) const {
42344307
SelectionDAG &DAG = DCI.DAG;
@@ -4247,6 +4320,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
42474320
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
42484321
case ISD::INTRINSIC_WO_CHAIN:
42494322
return performINTRINSIC_WO_CHAINCombine(N, DAG, DCI, Subtarget);
4323+
case LoongArchISD::MOVGR2FR_W_LA64:
4324+
return performMOVGR2FR_WCombine(N, DAG, DCI, Subtarget);
4325+
case LoongArchISD::MOVFR2GR_S_LA64:
4326+
return performMOVFR2GR_SCombine(N, DAG, DCI, Subtarget);
42504327
}
42514328
return SDValue();
42524329
}
@@ -6260,3 +6337,61 @@ bool LoongArchTargetLowering::shouldAlignPointerArgs(CallInst *CI,
62606337

62616338
return true;
62626339
}
6340+
6341+
/// Custom splitting of a value into register parts.
///
/// Handles exactly one case: an ABI register copy (a calling convention is
/// supplied) of an f16 value into a single f32 part. The half bits are placed
/// in the low 16 bits of an i32 and the upper 16 bits are filled with ones,
/// which makes the resulting f32 a NaN; that f32 is stored in Parts[0].
///
/// \returns true if the value was handled here, false otherwise.
bool LoongArchTargetLowering::splitValueIntoRegisterParts(
    SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
    unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
  const bool IsF16InF32ABICopy =
      CC.has_value() && Val.getValueType() == MVT::f16 && PartVT == MVT::f32;
  if (!IsF16InF32ABICopy)
    return false;

  // f16 -> i16 -> i32, then set the high half to all ones.
  SDValue Bits = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
  Bits = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Bits);
  SDValue HighOnes = DAG.getConstant(0xFFFF0000, DL, MVT::i32);
  Bits = DAG.getNode(ISD::OR, DL, MVT::i32, Bits, HighOnes);

  Parts[0] = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Bits);
  return true;
}
6361+
6362+
/// Custom joining of register parts back into a value.
///
/// Handles exactly one case: an ABI register copy (a calling convention is
/// supplied) where an f16 value arrives in a single f32 part. The f32 in
/// Parts[0] is reinterpreted as i32, truncated to its low 16 bits and
/// reinterpreted as ValueVT.
///
/// \returns the reconstructed value, or an empty SDValue if the case is not
/// handled here.
SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
    SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
    MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
  const bool IsF16InF32ABICopy =
      CC.has_value() && ValueVT == MVT::f16 && PartVT == MVT::f32;
  if (!IsF16InF32ABICopy)
    return SDValue();

  // f32 -> i32 -> i16 -> f16.
  SDValue AsI32 = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Parts[0]);
  SDValue AsI16 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, AsI32);
  return DAG.getNode(ISD::BITCAST, DL, ValueVT, AsI16);
}
6379+
6380+
/// Return the register type used to pass a value of type VT under calling
/// convention CC. An f16 is passed as f32 when the subtarget has the basic F
/// extension; every other case is delegated to the TargetLowering default.
MVT LoongArchTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
                                                           CallingConv::ID CC,
                                                           EVT VT) const {
  if (VT != MVT::f16 || !Subtarget.hasBasicF())
    return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);

  // Use f32 to pass f16.
  return MVT::f32;
}
6389+
6390+
/// Return the number of registers used to pass a value of type VT under
/// calling convention CC. An f16 occupies a single (f32) register when the
/// subtarget has the basic F extension; every other case is delegated to the
/// TargetLowering default.
unsigned LoongArchTargetLowering::getNumRegistersForCallingConv(
    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
  if (VT != MVT::f16 || !Subtarget.hasBasicF())
    return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);

  // Use f32 to pass f16.
  return 1;
}

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

+24
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ class LoongArchTargetLowering : public TargetLowering {
315315
SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
316316
SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
317317
SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
318+
SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
319+
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
318320

319321
bool isFPImmLegal(const APFloat &Imm, EVT VT,
320322
bool ForCodeSize) const override;
@@ -339,6 +341,28 @@ class LoongArchTargetLowering : public TargetLowering {
339341
const SmallVectorImpl<CCValAssign> &ArgLocs) const;
340342

341343
bool softPromoteHalfType() const override { return true; }
344+
345+
/// Target hook for splitting a value into register parts. The LoongArch
/// implementation only handles an ABI register copy of an f16 value into an
/// f32 part (half bits in the low 16 bits, high 16 bits set to ones) and
/// returns true in that case; it returns false for everything else.
bool
346+
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
347+
SDValue *Parts, unsigned NumParts, MVT PartVT,
348+
std::optional<CallingConv::ID> CC) const override;
349+
350+
/// Target hook for joining register parts into a value. The LoongArch
/// implementation only handles an ABI register copy where an f16 value
/// arrives in an f32 part (the low 16 bits are extracted and reinterpreted
/// as ValueVT); it returns an empty SDValue for everything else.
SDValue
351+
joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
352+
const SDValue *Parts, unsigned NumParts,
353+
MVT PartVT, EVT ValueVT,
354+
std::optional<CallingConv::ID> CC) const override;
355+
356+
/// Return the register type used to pass a value of type VT under calling
357+
/// convention CC; f16 is passed as f32 when the basic F extension is present.
358+
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
359+
EVT VT) const override;
360+
361+
/// Return the number of registers used to pass a value of type VT under
362+
/// calling convention CC; f16 uses one register when basic F is present.
363+
unsigned getNumRegistersForCallingConv(LLVMContext &Context,
364+
CallingConv::ID CC,
365+
EVT VT) const override;
342366
};
343367

344368
} // end namespace llvm

0 commit comments

Comments
 (0)