Skip to content

[InstCombine] Fix infinite loop due to bitcast <-> phi transforms #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 21, 2020
73 changes: 55 additions & 18 deletions llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2217,6 +2217,31 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
}
}

// Check that each user of each old PHI node is something that we can
// rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
for (auto *OldPN : OldPhiNodes) {
for (User *V : OldPN->users()) {
if (auto *SI = dyn_cast<StoreInst>(V)) {
if (!SI->isSimple() || SI->getOperand(0) != OldPN)
return nullptr;
} else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
// Verify it's a B->A cast.
Type *TyB = BCI->getOperand(0)->getType();
Type *TyA = BCI->getType();
if (TyA != DestTy || TyB != SrcTy)
return nullptr;
} else if (auto *PHI = dyn_cast<PHINode>(V)) {
// As long as the user is another old PHI node, then even if we don't
// rewrite it, the PHI web we're considering won't have any users
// outside itself, so it'll be dead.
if (OldPhiNodes.count(PHI) == 0)
return nullptr;
} else {
return nullptr;
}
}
}

// For each old PHI node, create a corresponding new PHI node with a type A.
SmallDenseMap<PHINode *, PHINode *> NewPNodes;
for (auto *OldPN : OldPhiNodes) {
Expand All @@ -2234,9 +2259,14 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
if (auto *C = dyn_cast<Constant>(V)) {
NewV = ConstantExpr::getBitCast(C, DestTy);
} else if (auto *LI = dyn_cast<LoadInst>(V)) {
Builder.SetInsertPoint(LI->getNextNode());
NewV = Builder.CreateBitCast(LI, DestTy);
Worklist.Add(LI);
// Explicitly perform load combine to make sure no opposing transform
// can remove the bitcast in the meantime and trigger an infinite loop.
Builder.SetInsertPoint(LI);
NewV = combineLoadToNewType(*LI, DestTy);
// Remove the old load and its use in the old phi, which itself becomes
// dead once the whole transform finishes.
replaceInstUsesWith(*LI, UndefValue::get(LI->getType()));
eraseInstFromFunction(*LI);
} else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
NewV = BCI->getOperand(0);
} else if (auto *PrevPN = dyn_cast<PHINode>(V)) {
Expand All @@ -2259,26 +2289,33 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
Instruction *RetVal = nullptr;
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
for (User *V : OldPN->users()) {
for (auto It = OldPN->user_begin(), End = OldPN->user_end(); It != End; ) {
User *V = *It;
// We may remove this user, advance to avoid iterator invalidation.
++It;
if (auto *SI = dyn_cast<StoreInst>(V)) {
if (SI->isSimple() && SI->getOperand(0) == OldPN) {
Builder.SetInsertPoint(SI);
auto *NewBC =
cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
SI->setOperand(0, NewBC);
Worklist.Add(SI);
assert(hasStoreUsersOnly(*NewBC));
}
assert(SI->isSimple() && SI->getOperand(0) == OldPN);
Builder.SetInsertPoint(SI);
auto *NewBC =
cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
SI->setOperand(0, NewBC);
Worklist.Add(SI);
assert(hasStoreUsersOnly(*NewBC));
}
else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
// Verify it's a B->A cast.
Type *TyB = BCI->getOperand(0)->getType();
Type *TyA = BCI->getType();
if (TyA == DestTy && TyB == SrcTy) {
Instruction *I = replaceInstUsesWith(*BCI, NewPN);
if (BCI == &CI)
RetVal = I;
}
assert(TyA == DestTy && TyB == SrcTy);
(void) TyA;
(void) TyB;
Instruction *I = replaceInstUsesWith(*BCI, NewPN);
if (BCI == &CI)
RetVal = I;
} else if (auto *PHI = dyn_cast<PHINode>(V)) {
assert(OldPhiNodes.count(PHI) > 0);
(void) PHI;
} else {
llvm_unreachable("all uses should be handled");
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner
/// \return true if successful.
bool replacePointer(Instruction &I, Value *V);

/// Create a load of type \p NewTy from the address loaded by \p LI.
///
/// The pointer operand is bitcast to a pointer to \p NewTy in the same
/// address space, unless it already is a bitcast of such a pointer. The new
/// load takes \p LI's alignment, volatility, atomic ordering and sync scope,
/// and is named after \p LI with \p Suffix appended. All instructions are
/// created at the Builder's current insert point, so callers must set it
/// beforehand.
///
/// Asserts that an atomic \p LI is only converted to a supported atomic type.
/// The visible callers replace \p LI's uses and erase it themselves.
///
/// \return the newly created load.
LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy,
const Twine &Suffix = "");

private:
bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
bool shouldChangeType(Type *From, Type *To) const;
Expand Down
27 changes: 13 additions & 14 deletions llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,8 @@ static bool isSupportedAtomicType(Type *Ty) {
///
/// Note that this will create all of the instructions with whatever insert
/// point the \c InstCombiner currently is using.
static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewTy,
const Twine &Suffix = "") {
LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy,
const Twine &Suffix) {
assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) &&
"can't fold an atomic load to requested type");

Expand All @@ -462,9 +462,9 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT
if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) &&
NewPtr->getType()->getPointerElementType() == NewTy &&
NewPtr->getType()->getPointerAddressSpace() == AS))
NewPtr = IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS));
NewPtr = Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS));

LoadInst *NewLoad = IC.Builder.CreateAlignedLoad(
LoadInst *NewLoad = Builder.CreateAlignedLoad(
NewTy, NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix);
NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
MDBuilder MDB(NewLoad->getContext());
Expand Down Expand Up @@ -505,7 +505,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT
NewLoad->setMetadata(ID, N);
break;
case LLVMContext::MD_range:
copyRangeMetadata(IC.getDataLayout(), LI, N, *NewLoad);
copyRangeMetadata(getDataLayout(), LI, N, *NewLoad);
break;
}
}
Expand Down Expand Up @@ -639,9 +639,8 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) {
return SI && SI->getPointerOperand() != &LI &&
!SI->getPointerOperand()->isSwiftError();
})) {
LoadInst *NewLoad = combineLoadToNewType(
IC, LI,
Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty)));
LoadInst *NewLoad = IC.combineLoadToNewType(
LI, Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty)));
// Replace all the stores with stores of the newly loaded value.
for (auto UI = LI.user_begin(), UE = LI.user_end(); UI != UE;) {
auto *SI = cast<StoreInst>(*UI++);
Expand All @@ -663,7 +662,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) {
if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
if (CI->isNoopCast(DL))
if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) {
LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy());
LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy());
CI->replaceAllUsesWith(NewLoad);
IC.eraseInstFromFunction(*CI);
return &LI;
Expand Down Expand Up @@ -691,8 +690,8 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
// If the struct has only one element, we unpack.
auto NumElements = ST->getNumElements();
if (NumElements == 1) {
LoadInst *NewLoad = combineLoadToNewType(IC, LI, ST->getTypeAtIndex(0U),
".unpack");
LoadInst *NewLoad = IC.combineLoadToNewType(LI, ST->getTypeAtIndex(0U),
".unpack");
AAMDNodes AAMD;
LI.getAAMetadata(AAMD);
NewLoad->setAAMetadata(AAMD);
Expand Down Expand Up @@ -741,7 +740,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
auto *ET = AT->getElementType();
auto NumElements = AT->getNumElements();
if (NumElements == 1) {
LoadInst *NewLoad = combineLoadToNewType(IC, LI, ET, ".unpack");
LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack");
AAMDNodes AAMD;
LI.getAAMetadata(AAMD);
NewLoad->setAAMetadata(AAMD);
Expand Down Expand Up @@ -1377,8 +1376,8 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
return false;

IC.Builder.SetInsertPoint(LI);
LoadInst *NewLI = combineLoadToNewType(
IC, *LI, LoadAddr->getType()->getPointerElementType());
LoadInst *NewLI = IC.combineLoadToNewType(
*LI, LoadAddr->getType()->getPointerElementType());
// Replace all the stores with stores of the newly loaded value.
for (auto *UI : LI->users()) {
auto *USI = cast<StoreInst>(UI);
Expand Down
18 changes: 17 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ STATISTIC(NumReassoc , "Number of reassociations");
DEBUG_COUNTER(VisitCounter, "instcombine-visit",
"Controls which instructions are visited");

// Default iteration limits for the instruction-combining fixpoint loop.
// NOTE(review): InstCombineDefaultMaxIterations is not referenced in any of
// the visible hunks -- confirm it is used elsewhere, otherwise it is dead.
static constexpr unsigned InstCombineDefaultMaxIterations = 1000;
// Initial value of the -instcombine-infinite-loop-threshold option.
static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000;

static cl::opt<bool>
EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"),
cl::init(true));
Expand All @@ -129,6 +132,12 @@ static cl::opt<bool>
EnableExpensiveCombines("expensive-combines",
cl::desc("Enable expensive instruction combines"));

// Hidden option: once the per-function iteration counter in
// combineInstructionsOverFunction exceeds this value, the pass calls
// report_fatal_error, treating the run as stuck in an infinite combine loop.
static cl::opt<unsigned> InfiniteLoopDetectionThreshold(
"instcombine-infinite-loop-threshold",
cl::desc("Number of instruction combining iterations considered an "
"infinite loop"),
cl::init(InstCombineDefaultInfiniteLoopThreshold), cl::Hidden);

static cl::opt<unsigned>
MaxArraySize("instcombine-maxarray-size", cl::init(1024),
cl::desc("Maximum array size considered when doing a combine"));
Expand Down Expand Up @@ -3508,9 +3517,16 @@ static bool combineInstructionsOverFunction(
MadeIRChange = LowerDbgDeclare(F);

// Iterate while there is work to do.
int Iteration = 0;
unsigned Iteration = 0;
while (true) {
++Iteration;

if (Iteration > InfiniteLoopDetectionThreshold) {
report_fatal_error(
"Instruction Combining seems stuck in an infinite loop after " +
Twine(InfiniteLoopDetectionThreshold) + " iterations.");
}

LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on "
<< F.getName() << "\n");

Expand Down
33 changes: 33 additions & 0 deletions llvm/test/Transforms/InstCombine/bitcast-phi-uselistorder.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -instcombine < %s | FileCheck %s

@Q = internal unnamed_addr global double 1.000000e+00, align 8

; %phi (an i64 merging constant 0 with a load of @Q through an i64* bitcast)
; has two users: a store to %p and a bitcast to double. The CHECK lines expect
; the phi to be rewritten as a double phi, the load to become a direct double
; load from @Q, and the store to go through a double* bitcast of %p.
; The uselistorder directive at the end permutes %phi's use-list; presumably
; that ordering is what triggered the bitcast <-> phi infinite loop -- confirm.
define double @test(i1 %c, i64* %p) {
; CHECK-LABEL: @test(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[END:%.*]]
; CHECK: if:
; CHECK-NEXT: [[LOAD1:%.*]] = load double, double* @Q, align 8
; CHECK-NEXT: br label [[END]]
; CHECK: end:
; CHECK-NEXT: [[TMP0:%.*]] = phi double [ 0.000000e+00, [[ENTRY:%.*]] ], [ [[LOAD1]], [[IF]] ]
; CHECK-NEXT: [[TMP1:%.*]] = bitcast i64* [[P:%.*]] to double*
; CHECK-NEXT: store double [[TMP0]], double* [[TMP1]], align 8
; CHECK-NEXT: ret double [[TMP0]]
;
entry:
br i1 %c, label %if, label %end

if:
%load = load i64, i64* bitcast (double* @Q to i64*), align 8
br label %end

end:
%phi = phi i64 [ 0, %entry ], [ %load, %if ]
store i64 %phi, i64* %p, align 8
%cast = bitcast i64 %phi to double
ret double %cast

uselistorder i64 %phi, { 1, 0 }
}
Loading