@@ -4040,47 +4040,48 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
4040
4040
SDValue AMDGPUTargetLowering::performShlCombine (SDNode *N,
4041
4041
DAGCombinerInfo &DCI) const {
4042
4042
EVT VT = N->getValueType (0 );
4043
-
4044
- ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
4045
- if (!RHS)
4046
- return SDValue ();
4047
-
4048
4043
SDValue LHS = N->getOperand (0 );
4049
- unsigned RHSVal = RHS->getZExtValue ();
4050
- if (!RHSVal)
4051
- return LHS;
4052
-
4044
+ SDValue RHS = N->getOperand (1 );
4045
+ ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
4053
4046
SDLoc SL (N);
4054
4047
SelectionDAG &DAG = DCI.DAG ;
4055
4048
4056
- switch (LHS->getOpcode ()) {
4057
- default :
4058
- break ;
4059
- case ISD::ZERO_EXTEND:
4060
- case ISD::SIGN_EXTEND:
4061
- case ISD::ANY_EXTEND: {
4062
- SDValue X = LHS->getOperand (0 );
4063
-
4064
- if (VT == MVT::i32 && RHSVal == 16 && X.getValueType () == MVT::i16 &&
4065
- isOperationLegal (ISD::BUILD_VECTOR, MVT::v2i16)) {
4066
- // Prefer build_vector as the canonical form if packed types are legal.
4067
- // (shl ([asz]ext i16:x), 16 -> build_vector 0, x
4068
- SDValue Vec = DAG.getBuildVector (MVT::v2i16, SL,
4069
- { DAG.getConstant (0 , SL, MVT::i16), LHS->getOperand (0 ) });
4070
- return DAG.getNode (ISD::BITCAST, SL, MVT::i32, Vec);
4071
- }
4049
+ unsigned RHSVal;
4050
+ if (CRHS) {
4051
+ RHSVal = CRHS->getZExtValue ();
4052
+ if (!RHSVal)
4053
+ return LHS;
4072
4054
4073
- // shl (ext x) => zext (shl x), if shift does not overflow int
4074
- if (VT != MVT::i64)
4075
- break ;
4076
- KnownBits Known = DAG.computeKnownBits (X);
4077
- unsigned LZ = Known.countMinLeadingZeros ();
4078
- if (LZ < RHSVal)
4055
+ switch (LHS->getOpcode ()) {
4056
+ default :
4079
4057
break ;
4080
- EVT XVT = X.getValueType ();
4081
- SDValue Shl = DAG.getNode (ISD::SHL, SL, XVT, X, SDValue (RHS, 0 ));
4082
- return DAG.getZExtOrTrunc (Shl, SL, VT);
4083
- }
4058
+ case ISD::ZERO_EXTEND:
4059
+ case ISD::SIGN_EXTEND:
4060
+ case ISD::ANY_EXTEND: {
4061
+ SDValue X = LHS->getOperand (0 );
4062
+
4063
+ if (VT == MVT::i32 && RHSVal == 16 && X.getValueType () == MVT::i16 &&
4064
+ isOperationLegal (ISD::BUILD_VECTOR, MVT::v2i16)) {
4065
+ // Prefer build_vector as the canonical form if packed types are legal.
4066
+ // (shl ([asz]ext i16:x), 16) -> build_vector 0, x
4067
+ SDValue Vec = DAG.getBuildVector (
4068
+ MVT::v2i16, SL,
4069
+ {DAG.getConstant (0 , SL, MVT::i16), LHS->getOperand (0 )});
4070
+ return DAG.getNode (ISD::BITCAST, SL, MVT::i32, Vec);
4071
+ }
4072
+
4073
+ // shl (ext x) => zext (shl x), if shift does not overflow int
4074
+ if (VT != MVT::i64)
4075
+ break ;
4076
+ KnownBits Known = DAG.computeKnownBits (X);
4077
+ unsigned LZ = Known.countMinLeadingZeros ();
4078
+ if (LZ < RHSVal)
4079
+ break ;
4080
+ EVT XVT = X.getValueType ();
4081
+ SDValue Shl = DAG.getNode (ISD::SHL, SL, XVT, X, SDValue (CRHS, 0 ));
4082
+ return DAG.getZExtOrTrunc (Shl, SL, VT);
4083
+ }
4084
+ }
4084
4085
}
4085
4086
4086
4087
if (VT != MVT::i64)
@@ -4091,18 +4092,34 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
4091
4092
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
4092
4093
// common case, splitting this into a move and a 32-bit shift is faster and
4093
4094
// the same code size.
4094
- if (RHSVal < 32 )
4095
+ EVT TargetType = VT.getHalfSizedIntegerVT (*DAG.getContext ());
4096
+ EVT TargetVecPairType = EVT::getVectorVT (*DAG.getContext (), TargetType, 2 );
4097
+ KnownBits Known = DAG.computeKnownBits (RHS);
4098
+
4099
+ if (Known.getMinValue ().getZExtValue () < TargetType.getSizeInBits ())
4095
4100
return SDValue ();
4101
+ SDValue ShiftAmt;
4096
4102
4097
- SDValue ShiftAmt = DAG.getConstant (RHSVal - 32 , SL, MVT::i32);
4103
+ if (CRHS) {
4104
+ ShiftAmt =
4105
+ DAG.getConstant (RHSVal - TargetType.getSizeInBits (), SL, TargetType);
4106
+ } else {
4107
+ SDValue truncShiftAmt = DAG.getNode (ISD::TRUNCATE, SL, TargetType, RHS);
4108
+ const SDValue ShiftMask =
4109
+ DAG.getConstant (TargetType.getSizeInBits () - 1 , SL, TargetType);
4110
+ // This AND instruction will clamp out of bounds shift values.
4111
+ // It will also be removed during later instruction selection.
4112
+ ShiftAmt = DAG.getNode (ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
4113
+ }
4098
4114
4099
- SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, MVT::i32, LHS);
4100
- SDValue NewShift = DAG.getNode (ISD::SHL, SL, MVT::i32, Lo, ShiftAmt);
4115
+ SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, TargetType, LHS);
4116
+ SDValue NewShift =
4117
+ DAG.getNode (ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags ());
4101
4118
4102
- const SDValue Zero = DAG.getConstant (0 , SL, MVT::i32 );
4119
+ const SDValue Zero = DAG.getConstant (0 , SL, TargetType );
4103
4120
4104
- SDValue Vec = DAG.getBuildVector (MVT::v2i32 , SL, {Zero, NewShift});
4105
- return DAG.getNode (ISD::BITCAST, SL, MVT::i64 , Vec);
4121
+ SDValue Vec = DAG.getBuildVector (TargetVecPairType , SL, {Zero, NewShift});
4122
+ return DAG.getNode (ISD::BITCAST, SL, VT , Vec);
4106
4123
}
4107
4124
4108
4125
SDValue AMDGPUTargetLowering::performSraCombine (SDNode *N,
0 commit comments