Skip to content

Commit 235cada

Browse files
committed
[LoongArch] Pass 'half' in the lower 16 bits of an f32 value when F extension is enabled
LoongArch currently lacks a hardware extension for the fp16 data type, and the ABI documentation does not explicitly define how to handle fp16. Future revisions of the LoongArch specification will include conventions to address fp16 requirements. Previously, we maintained the 'half' type in its 16-bit format between operations. Regardless of whether the F extension is enabled, the value would be passed in the lower 16 bits of a GPR in its 'half' format. With this patch, depending on the ABI in use, the value will be passed either in an FPR or a GPR in 'half' format. This ensures consistency with the bit location used when the fp16 hardware extension is enabled.
1 parent b5cdb03 commit 235cada

File tree

4 files changed

+1214
-72
lines changed

4 files changed

+1214
-72
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

+140-6
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
181181
setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
182182
setOperationAction(ISD::FPOW, MVT::f32, Expand);
183183
setOperationAction(ISD::FREM, MVT::f32, Expand);
184-
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Expand);
185-
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
184+
setOperationAction(ISD::FP16_TO_FP, MVT::f32, Custom);
185+
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Custom);
186186

187187
if (Subtarget.is64Bit())
188188
setOperationAction(ISD::FRINT, MVT::f32, Legal);
@@ -219,7 +219,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
219219
setOperationAction(ISD::FPOW, MVT::f64, Expand);
220220
setOperationAction(ISD::FREM, MVT::f64, Expand);
221221
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
222-
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
222+
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
223223

224224
if (Subtarget.is64Bit())
225225
setOperationAction(ISD::FRINT, MVT::f64, Legal);
@@ -427,6 +427,10 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
427427
return lowerBUILD_VECTOR(Op, DAG);
428428
case ISD::VECTOR_SHUFFLE:
429429
return lowerVECTOR_SHUFFLE(Op, DAG);
430+
case ISD::FP_TO_FP16:
431+
return lowerFP_TO_FP16(Op, DAG);
432+
case ISD::FP16_TO_FP:
433+
return lowerFP16_TO_FP(Op, DAG);
430434
}
431435
return SDValue();
432436
}
@@ -1354,6 +1358,40 @@ SDValue LoongArchTargetLowering::lowerVECTOR_SHUFFLE(SDValue Op,
13541358
return SDValue();
13551359
}
13561360

1361+
SDValue LoongArchTargetLowering::lowerFP_TO_FP16(SDValue Op,
1362+
SelectionDAG &DAG) const {
1363+
// Custom lower to ensure the libcall return is passed in an FPR on hard
1364+
// float ABIs.
1365+
SDLoc DL(Op);
1366+
MakeLibCallOptions CallOptions;
1367+
SDValue Op0 = Op.getOperand(0);
1368+
SDValue Chain = SDValue();
1369+
RTLIB::Libcall LC = RTLIB::getFPROUND(Op0.getValueType(), MVT::f16);
1370+
SDValue Res;
1371+
std::tie(Res, Chain) =
1372+
makeLibCall(DAG, LC, MVT::f32, Op0, CallOptions, DL, Chain);
1373+
if (Subtarget.is64Bit())
1374+
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Res);
1375+
return DAG.getBitcast(MVT::i32, Res);
1376+
}
1377+
1378+
SDValue LoongArchTargetLowering::lowerFP16_TO_FP(SDValue Op,
1379+
SelectionDAG &DAG) const {
1380+
// Custom lower to ensure the libcall argument is passed in an FPR on hard
1381+
// float ABIs.
1382+
SDLoc DL(Op);
1383+
MakeLibCallOptions CallOptions;
1384+
SDValue Op0 = Op.getOperand(0);
1385+
SDValue Chain = SDValue();
1386+
SDValue Arg = Subtarget.is64Bit() ? DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64,
1387+
DL, MVT::f32, Op0)
1388+
: DAG.getBitcast(MVT::f32, Op0);
1389+
SDValue Res;
1390+
std::tie(Res, Chain) = makeLibCall(DAG, RTLIB::FPEXT_F16_F32, MVT::f32, Arg,
1391+
CallOptions, DL, Chain);
1392+
return Res;
1393+
}
1394+
13571395
static bool isConstantOrUndef(const SDValue Op) {
13581396
if (Op->isUndef())
13591397
return true;
@@ -1656,16 +1694,19 @@ SDValue LoongArchTargetLowering::lowerFP_TO_SINT(SDValue Op,
16561694
SelectionDAG &DAG) const {
16571695

16581696
SDLoc DL(Op);
1697+
SDValue Op0 = Op.getOperand(0);
1698+
1699+
if (Op0.getValueType() == MVT::f16)
1700+
Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);
16591701

16601702
if (Op.getValueSizeInBits() > 32 && Subtarget.hasBasicF() &&
16611703
!Subtarget.hasBasicD()) {
1662-
SDValue Dst =
1663-
DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op.getOperand(0));
1704+
SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, MVT::f32, Op0);
16641705
return DAG.getNode(LoongArchISD::MOVFR2GR_S_LA64, DL, MVT::i64, Dst);
16651706
}
16661707

16671708
EVT FPTy = EVT::getFloatingPointVT(Op.getValueSizeInBits());
1668-
SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op.getOperand(0));
1709+
SDValue Trunc = DAG.getNode(LoongArchISD::FTINT, DL, FPTy, Op0);
16691710
return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Trunc);
16701711
}
16711712

@@ -2848,6 +2889,10 @@ void LoongArchTargetLowering::ReplaceNodeResults(
28482889
EVT FVT = EVT::getFloatingPointVT(N->getValueSizeInBits(0));
28492890
if (getTypeAction(*DAG.getContext(), Src.getValueType()) !=
28502891
TargetLowering::TypeSoftenFloat) {
2892+
if (!isTypeLegal(Src.getValueType()))
2893+
return;
2894+
if (Src.getValueType() == MVT::f16)
2895+
Src = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
28512896
SDValue Dst = DAG.getNode(LoongArchISD::FTINT, DL, FVT, Src);
28522897
Results.push_back(DAG.getNode(ISD::BITCAST, DL, VT, Dst));
28532898
return;
@@ -4229,6 +4274,33 @@ performINTRINSIC_WO_CHAINCombine(SDNode *N, SelectionDAG &DAG,
42294274
return SDValue();
42304275
}
42314276

4277+
static SDValue performMOVGR2FR_WCombine(SDNode *N, SelectionDAG &DAG,
4278+
TargetLowering::DAGCombinerInfo &DCI,
4279+
const LoongArchSubtarget &Subtarget) {
4280+
// If the input to MOVGR2FR_W_LA64 is just MOVFR2GR_S_LA64 then the
4281+
// conversion is unnecessary and can be replaced with the
4282+
// MOVFR2GR_S_LA64 operand.
4283+
SDValue Op0 = N->getOperand(0);
4284+
if (Op0.getOpcode() == LoongArchISD::MOVFR2GR_S_LA64)
4285+
return Op0.getOperand(0);
4286+
return SDValue();
4287+
}
4288+
4289+
static SDValue performMOVFR2GR_SCombine(SDNode *N, SelectionDAG &DAG,
4290+
TargetLowering::DAGCombinerInfo &DCI,
4291+
const LoongArchSubtarget &Subtarget) {
4292+
// If the input to MOVFR2GR_S_LA64 is just MOVGR2FR_W_LA64 then the
4293+
// conversion is unnecessary and can be replaced with the MOVGR2FR_W_LA64
4294+
// operand.
4295+
SDValue Op0 = N->getOperand(0);
4296+
MVT VT = N->getSimpleValueType(0);
4297+
if (Op0->getOpcode() == LoongArchISD::MOVGR2FR_W_LA64) {
4298+
assert(Op0.getOperand(0).getValueType() == VT && "Unexpected value type!");
4299+
return Op0.getOperand(0);
4300+
}
4301+
return SDValue();
4302+
}
4303+
42324304
SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
42334305
DAGCombinerInfo &DCI) const {
42344306
SelectionDAG &DAG = DCI.DAG;
@@ -4247,6 +4319,10 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
42474319
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
42484320
case ISD::INTRINSIC_WO_CHAIN:
42494321
return performINTRINSIC_WO_CHAINCombine(N, DAG, DCI, Subtarget);
4322+
case LoongArchISD::MOVGR2FR_W_LA64:
4323+
return performMOVGR2FR_WCombine(N, DAG, DCI, Subtarget);
4324+
case LoongArchISD::MOVFR2GR_S_LA64:
4325+
return performMOVFR2GR_SCombine(N, DAG, DCI, Subtarget);
42504326
}
42514327
return SDValue();
42524328
}
@@ -6260,3 +6336,61 @@ bool LoongArchTargetLowering::shouldAlignPointerArgs(CallInst *CI,
62606336

62616337
return true;
62626338
}
6339+
6340+
bool LoongArchTargetLowering::splitValueIntoRegisterParts(
6341+
SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
6342+
unsigned NumParts, MVT PartVT, std::optional<CallingConv::ID> CC) const {
6343+
bool IsABIRegCopy = CC.has_value();
6344+
EVT ValueVT = Val.getValueType();
6345+
6346+
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
6347+
// Cast the f16 to i16, extend to i32, pad with ones to make a float
6348+
// nan, and cast to f32.
6349+
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i16, Val);
6350+
Val = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Val);
6351+
Val = DAG.getNode(ISD::OR, DL, MVT::i32, Val,
6352+
DAG.getConstant(0xFFFF0000, DL, MVT::i32));
6353+
Val = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Val);
6354+
Parts[0] = Val;
6355+
return true;
6356+
}
6357+
6358+
return false;
6359+
}
6360+
6361+
SDValue LoongArchTargetLowering::joinRegisterPartsIntoValue(
6362+
SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts, unsigned NumParts,
6363+
MVT PartVT, EVT ValueVT, std::optional<CallingConv::ID> CC) const {
6364+
bool IsABIRegCopy = CC.has_value();
6365+
6366+
if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
6367+
SDValue Val = Parts[0];
6368+
6369+
// Cast the f32 to i32, truncate to i16, and cast back to f16.
6370+
Val = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Val);
6371+
Val = DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Val);
6372+
Val = DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
6373+
return Val;
6374+
}
6375+
6376+
return SDValue();
6377+
}
6378+
6379+
MVT LoongArchTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
6380+
CallingConv::ID CC,
6381+
EVT VT) const {
6382+
// Use f32 to pass f16.
6383+
if (VT == MVT::f16 && Subtarget.hasBasicF())
6384+
return MVT::f32;
6385+
6386+
return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
6387+
}
6388+
6389+
unsigned LoongArchTargetLowering::getNumRegistersForCallingConv(
6390+
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
6391+
// Use f32 to pass f16.
6392+
if (VT == MVT::f16 && Subtarget.hasBasicF())
6393+
return 1;
6394+
6395+
return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
6396+
}

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

+24
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ class LoongArchTargetLowering : public TargetLowering {
315315
SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
316316
SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
317317
SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
318+
SDValue lowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
319+
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
318320

319321
bool isFPImmLegal(const APFloat &Imm, EVT VT,
320322
bool ForCodeSize) const override;
@@ -339,6 +341,28 @@ class LoongArchTargetLowering : public TargetLowering {
339341
const SmallVectorImpl<CCValAssign> &ArgLocs) const;
340342

341343
bool softPromoteHalfType() const override { return true; }
344+
345+
bool
346+
splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
347+
SDValue *Parts, unsigned NumParts, MVT PartVT,
348+
std::optional<CallingConv::ID> CC) const override;
349+
350+
SDValue
351+
joinRegisterPartsIntoValue(SelectionDAG &DAG, const SDLoc &DL,
352+
const SDValue *Parts, unsigned NumParts,
353+
MVT PartVT, EVT ValueVT,
354+
std::optional<CallingConv::ID> CC) const override;
355+
356+
/// Return the register type used to pass a value of type VT for the given
357+
/// calling convention; f16 is passed as f32 when the F extension is enabled.
358+
MVT getRegisterTypeForCallingConv(LLVMContext &Context, CallingConv::ID CC,
359+
EVT VT) const override;
360+
361+
/// Return the number of registers used to pass a value of type VT for the
362+
/// given calling convention; f16 uses one register when F is enabled.
363+
unsigned getNumRegistersForCallingConv(LLVMContext &Context,
364+
CallingConv::ID CC,
365+
EVT VT) const override;
342366
};
343367

344368
} // end namespace llvm

0 commit comments

Comments
 (0)