Skip to content

Commit 5decab1

Browse files
authored
AMDGPU: Reduce shl64 to shl32 if shift range is [63-32] (#125574)
Reduce: DST = shl i64 X, Y where Y is in the range [63-32] to: DST = [0, shl i32 X, (Y & 32)] Alive2 analysis: https://alive2.llvm.org/ce/z/w_u5je --------- Signed-off-by: John Lu <[email protected]>
1 parent 2bdeeaa commit 5decab1

File tree

3 files changed

+635
-42
lines changed

3 files changed

+635
-42
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

+59-42
Original file line numberDiff line numberDiff line change
@@ -4040,47 +4040,48 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
40404040
SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40414041
DAGCombinerInfo &DCI) const {
40424042
EVT VT = N->getValueType(0);
4043-
4044-
ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
4045-
if (!RHS)
4046-
return SDValue();
4047-
40484043
SDValue LHS = N->getOperand(0);
4049-
unsigned RHSVal = RHS->getZExtValue();
4050-
if (!RHSVal)
4051-
return LHS;
4052-
4044+
SDValue RHS = N->getOperand(1);
4045+
ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
40534046
SDLoc SL(N);
40544047
SelectionDAG &DAG = DCI.DAG;
40554048

4056-
switch (LHS->getOpcode()) {
4057-
default:
4058-
break;
4059-
case ISD::ZERO_EXTEND:
4060-
case ISD::SIGN_EXTEND:
4061-
case ISD::ANY_EXTEND: {
4062-
SDValue X = LHS->getOperand(0);
4063-
4064-
if (VT == MVT::i32 && RHSVal == 16 && X.getValueType() == MVT::i16 &&
4065-
isOperationLegal(ISD::BUILD_VECTOR, MVT::v2i16)) {
4066-
// Prefer build_vector as the canonical form if packed types are legal.
4067-
// (shl ([asz]ext i16:x), 16 -> build_vector 0, x
4068-
SDValue Vec = DAG.getBuildVector(MVT::v2i16, SL,
4069-
{ DAG.getConstant(0, SL, MVT::i16), LHS->getOperand(0) });
4070-
return DAG.getNode(ISD::BITCAST, SL, MVT::i32, Vec);
4071-
}
4049+
unsigned RHSVal;
4050+
if (CRHS) {
4051+
RHSVal = CRHS->getZExtValue();
4052+
if (!RHSVal)
4053+
return LHS;
40724054

4073-
// shl (ext x) => zext (shl x), if shift does not overflow int
4074-
if (VT != MVT::i64)
4075-
break;
4076-
KnownBits Known = DAG.computeKnownBits(X);
4077-
unsigned LZ = Known.countMinLeadingZeros();
4078-
if (LZ < RHSVal)
4055+
switch (LHS->getOpcode()) {
4056+
default:
40794057
break;
4080-
EVT XVT = X.getValueType();
4081-
SDValue Shl = DAG.getNode(ISD::SHL, SL, XVT, X, SDValue(RHS, 0));
4082-
return DAG.getZExtOrTrunc(Shl, SL, VT);
4083-
}
4058+
case ISD::ZERO_EXTEND:
4059+
case ISD::SIGN_EXTEND:
4060+
case ISD::ANY_EXTEND: {
4061+
SDValue X = LHS->getOperand(0);
4062+
4063+
if (VT == MVT::i32 && RHSVal == 16 && X.getValueType() == MVT::i16 &&
4064+
isOperationLegal(ISD::BUILD_VECTOR, MVT::v2i16)) {
4065+
// Prefer build_vector as the canonical form if packed types are legal.
4066+
// (shl ([asz]ext i16:x), 16 -> build_vector 0, x
4067+
SDValue Vec = DAG.getBuildVector(
4068+
MVT::v2i16, SL,
4069+
{DAG.getConstant(0, SL, MVT::i16), LHS->getOperand(0)});
4070+
return DAG.getNode(ISD::BITCAST, SL, MVT::i32, Vec);
4071+
}
4072+
4073+
// shl (ext x) => zext (shl x), if shift does not overflow int
4074+
if (VT != MVT::i64)
4075+
break;
4076+
KnownBits Known = DAG.computeKnownBits(X);
4077+
unsigned LZ = Known.countMinLeadingZeros();
4078+
if (LZ < RHSVal)
4079+
break;
4080+
EVT XVT = X.getValueType();
4081+
SDValue Shl = DAG.getNode(ISD::SHL, SL, XVT, X, SDValue(CRHS, 0));
4082+
return DAG.getZExtOrTrunc(Shl, SL, VT);
4083+
}
4084+
}
40844085
}
40854086

40864087
if (VT != MVT::i64)
@@ -4091,18 +4092,34 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
40914092
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
40924093
// common case, splitting this into a move and a 32-bit shift is faster and
40934094
// the same code size.
4094-
if (RHSVal < 32)
4095+
EVT TargetType = VT.getHalfSizedIntegerVT(*DAG.getContext());
4096+
EVT TargetVecPairType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
4097+
KnownBits Known = DAG.computeKnownBits(RHS);
4098+
4099+
if (Known.getMinValue().getZExtValue() < TargetType.getSizeInBits())
40954100
return SDValue();
4101+
SDValue ShiftAmt;
40964102

4097-
SDValue ShiftAmt = DAG.getConstant(RHSVal - 32, SL, MVT::i32);
4103+
if (CRHS) {
4104+
ShiftAmt =
4105+
DAG.getConstant(RHSVal - TargetType.getSizeInBits(), SL, TargetType);
4106+
} else {
4107+
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
4108+
const SDValue ShiftMask =
4109+
DAG.getConstant(TargetType.getSizeInBits() - 1, SL, TargetType);
4110+
// This AND instruction will clamp out of bounds shift values.
4111+
// It will also be removed during later instruction selection.
4112+
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
4113+
}
40984114

4099-
SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, LHS);
4100-
SDValue NewShift = DAG.getNode(ISD::SHL, SL, MVT::i32, Lo, ShiftAmt);
4115+
SDValue Lo = DAG.getNode(ISD::TRUNCATE, SL, TargetType, LHS);
4116+
SDValue NewShift =
4117+
DAG.getNode(ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags());
41014118

4102-
const SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
4119+
const SDValue Zero = DAG.getConstant(0, SL, TargetType);
41034120

4104-
SDValue Vec = DAG.getBuildVector(MVT::v2i32, SL, {Zero, NewShift});
4105-
return DAG.getNode(ISD::BITCAST, SL, MVT::i64, Vec);
4121+
SDValue Vec = DAG.getBuildVector(TargetVecPairType, SL, {Zero, NewShift});
4122+
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
41064123
}
41074124

41084125
SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,

0 commit comments

Comments
 (0)