1414#include " mlir/Transforms/Passes.h"
1515
1616#include " mlir/IR/SymbolTable.h"
17+ #include " llvm/Support/Debug.h"
1718
1819namespace mlir {
1920#define GEN_PASS_DEF_SYMBOLDCE
@@ -22,6 +23,8 @@ namespace mlir {
2223
2324using namespace mlir ;
2425
26+ #define DEBUG_TYPE " symbol-dce"
27+
2528namespace {
2629struct SymbolDCE : public impl ::SymbolDCEBase<SymbolDCE> {
2730 void runOnOperation () override ;
@@ -84,6 +87,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
8487 SymbolTableCollection &symbolTable,
8588 bool symbolTableIsHidden,
8689 DenseSet<Operation *> &liveSymbols) {
90+ LLVM_DEBUG (llvm::dbgs () << " computeLiveness: " << symbolTableOp->getName ()
91+ << " \n " );
8792 // A worklist of live operations to propagate uses from.
8893 SmallVector<Operation *, 16 > worklist;
8994
@@ -105,36 +110,69 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
105110 }
106111
107112 // Process the set of symbols that were known to be live, adding new symbols
108- // that are referenced within.
113+ // that are referenced within. For operations that are not symbol tables, it
114+ // considers the liveness with respect to the op itself rather than scope of
115+ // nested symbol tables by enqueuing all the top level operations for
116+ // consideration.
109117 while (!worklist.empty ()) {
110118 Operation *op = worklist.pop_back_val ();
119+ LLVM_DEBUG (llvm::dbgs () << " processing: " << op->getName () << " \n " );
111120
112121 // If this is a symbol table, recursively compute its liveness.
113122 if (op->hasTrait <OpTrait::SymbolTable>()) {
114123 // The internal symbol table is hidden if the parent is, if its not a
115124 // symbol, or if it is a private symbol.
116125 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117126 bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate ();
127+ LLVM_DEBUG (llvm::dbgs () << " \t symbol table: " << op->getName ()
128+ << " is hidden: " << symIsHidden << " \n " );
118129 if (failed (computeLiveness (op, symbolTable, symIsHidden, liveSymbols)))
119130 return failure ();
131+ } else {
132+ LLVM_DEBUG (llvm::dbgs ()
133+ << " \t non-symbol table: " << op->getName () << " \n " );
134+ // If the op is not a symbol table, then, unless op itself is dead which
135+ // would be handled by DCE, we need to check all the regions and blocks
136+ // within the op to find the uses (e.g., consider visibility within op as
137+ // if top level rather than relying on pure symbol table visibility). This
138+ // is more conservative than SymbolTable::walkSymbolTables in the case
139+ // where there is again SymbolTable information to take advantage of.
140+ for (auto ®ion : op->getRegions ())
141+ for (auto &block : region.getBlocks ())
142+ for (Operation &op : block)
143+ worklist.push_back (&op);
120144 }
121145
146+ // Get the first parent symbol table op. Note: due to enqueueing of
147+ // top-level ops, we may not have a symbol table parent here, but if we do
148+ // not, then we also don't have a symbol.
149+ Operation *parentOp = op->getParentOp ();
150+ if (!parentOp->hasTrait <OpTrait::SymbolTable>())
151+ continue ;
152+
122153 // Collect the uses held by this operation.
123154 std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses (op);
124155 if (!uses) {
125156 return op->emitError ()
126- << " operation contains potentially unknown symbol table, "
127- " meaning that we can't reliable compute symbol uses" ;
157+ << " operation contains potentially unknown symbol table, meaning "
158+ << " that we can't reliable compute symbol uses" ;
128159 }
129160
130161 SmallVector<Operation *, 4 > resolvedSymbols;
162+ LLVM_DEBUG (llvm::dbgs () << " uses of " << op->getName () << " \n " );
131163 for (const SymbolTable::SymbolUse &use : *uses) {
164+ LLVM_DEBUG (llvm::dbgs () << " \t use: " << use.getUser () << " \n " );
132165 // Lookup the symbols referenced by this use.
133166 resolvedSymbols.clear ();
134- if (failed (symbolTable.lookupSymbolIn (
135- op-> getParentOp (), use. getSymbolRef (), resolvedSymbols)))
167+ if (failed (symbolTable.lookupSymbolIn (parentOp, use. getSymbolRef (),
168+ resolvedSymbols)))
136169 // Ignore references to unknown symbols.
137170 continue ;
171+ LLVM_DEBUG ({
172+ llvm::dbgs () << " \t\t resolved symbols: " ;
173+ llvm::interleaveComma (resolvedSymbols, llvm::dbgs ());
174+ llvm::dbgs () << " \n " ;
175+ });
138176
139177 // Mark each of the resolved symbols as live.
140178 for (Operation *resolvedSymbol : resolvedSymbols)
0 commit comments