From 203e3e766b3b8ed2760a1fae13b5766291f80054 Mon Sep 17 00:00:00 2001 From: James Molloy Date: Sat, 12 Oct 2024 12:14:53 -0700 Subject: [PATCH] [xls][mlir] Speed up lower_counted_for On large modules this pass could dominate due to SymbolTable lookups. Before: 2m05s. After: 2s. PiperOrigin-RevId: 685238926 --- .../mlir/testdata/lower_counted_for.mlir | 8 ++--- .../mlir/transforms/lower_counted_for.cc | 36 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/xls/contrib/mlir/testdata/lower_counted_for.mlir b/xls/contrib/mlir/testdata/lower_counted_for.mlir index 16f5b9da96..177ea562e3 100644 --- a/xls/contrib/mlir/testdata/lower_counted_for.mlir +++ b/xls/contrib/mlir/testdata/lower_counted_for.mlir @@ -65,7 +65,7 @@ func.func @reduce_arity_2(%arg0: i32, %arg1: i32) -> i32 attributes {xls = true} return %0#0 : i32 } -// CHECK-LABEL: func.func private @for_body_3( +// CHECK-LABEL: func.func private @for_body_1( // CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> i32 // CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_0]] : i32 to index // CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : i32 @@ -74,10 +74,10 @@ func.func @reduce_arity_2(%arg0: i32, %arg1: i32) -> i32 attributes {xls = true} // CHECK-LABEL: func.func private @for_body_2( // CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32) -> i32 // CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_0]] : i32 to index -// CHECK: %[[VAL_4:.*]] = "xls.counted_for"(%[[VAL_1]], %[[VAL_1]], %[[VAL_2]]) <{stride = 1 : i64, to_apply = @for_body_3, trip_count = 1024 : i64}> : (i32, i32, i32) -> i32 +// CHECK: %[[VAL_4:.*]] = "xls.counted_for"(%[[VAL_1]], %[[VAL_1]], %[[VAL_2]]) <{stride = 1 : i64, to_apply = @for_body_1, trip_count = 1024 : i64}> : (i32, i32, i32) -> i32 // CHECK: return %[[VAL_4]] : i32 -// CHECK-LABEL: func.func private @for_body_1( +// CHECK-LABEL: func.func private @for_body_3( // CHECK-SAME: 
%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32) -> i32 // CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_0]] : i32 to index // CHECK: %[[VAL_4:.*]] = "xls.counted_for"(%[[VAL_1]], %[[VAL_2]]) <{stride = 1 : i64, to_apply = @for_body_2, trip_count = 1024 : i64}> : (i32, i32) -> i32 @@ -89,7 +89,7 @@ func.func @reduce_arity_2(%arg0: i32, %arg1: i32) -> i32 attributes {xls = true} // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1024 : index // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_5:.*]] = "xls.counted_for"(%[[VAL_4]], %[[VAL_0]]) <{stride = 1 : i64, to_apply = @for_body_1, trip_count = 1024 : i64}> : (i32, i32) -> i32 +// CHECK: %[[VAL_5:.*]] = "xls.counted_for"(%[[VAL_4]], %[[VAL_0]]) <{stride = 1 : i64, to_apply = @for_body_3, trip_count = 1024 : i64}> : (i32, i32) -> i32 // CHECK: return %[[VAL_5]] : i32 func.func @triple_nest(%arg0: i32) -> i32 attributes {xls = true} { %c0 = arith.constant 0 : index diff --git a/xls/contrib/mlir/transforms/lower_counted_for.cc b/xls/contrib/mlir/transforms/lower_counted_for.cc index d1cab53a4f..d09efd032d 100644 --- a/xls/contrib/mlir/transforms/lower_counted_for.cc +++ b/xls/contrib/mlir/transforms/lower_counted_for.cc @@ -48,9 +48,8 @@ namespace mlir::xls { #define GEN_PASS_DEF_LOWERCOUNTEDFORPASS #include "xls/contrib/mlir/transforms/passes.h.inc" // IWYU pragma: keep -using ::llvm::SmallVector; - namespace { +using ::llvm::SmallVector; namespace fixed { // TODO(jmolloy): This is a copy of the one in SCF utils. But that version is @@ -229,24 +228,25 @@ class TuplifyRewrite : public OpConversionPattern { } }; -std::string createUniqueName(Operation *op, std::string prefix) { - // TODO(jpienaar): This could be made more efficient. Current approach does - // work that could be cached and reused. 
- mlir::Operation *symbolTableOp = - op->getParentWithTrait(); - if (mlir::SymbolTable::lookupSymbolIn(symbolTableOp, prefix) == nullptr) { - return prefix; +StringAttr createUniqueName(MLIRContext &context, SymbolTable &symbolTable, + DenseSet &addedSymbols, + StringRef prefix) { + if (symbolTable.lookup(prefix) == nullptr && !addedSymbols.contains(prefix)) { + addedSymbols.insert(prefix); + return StringAttr::get(&context, prefix); } unsigned uniquingCounter = 0; llvm::SmallString<128> name = SymbolTable::generateSymbolName<128>( prefix, [&](llvm::StringRef candidate) { - return mlir::SymbolTable::lookupSymbolIn(symbolTableOp, candidate) != - nullptr; + return symbolTable.lookup(candidate) || + addedSymbols.contains(candidate.str()); }, uniquingCounter); - return std::string(name.str()); + auto result = StringAttr::get(&context, name); + addedSymbols.insert(result); + return result; } class ForToCountedForRewrite : public OpConversionPattern { @@ -267,9 +267,7 @@ class ForToCountedForRewrite : public OpConversionPattern { // not be updated until after rewrites have completed (meaning // createUniqueName would always return the same value in the same rewrite // cycle causing clashes). - std::string preferredName = - cast(op->getAttr(kPreferredNameAttr)).str(); - std::string name = createUniqueName(op, preferredName); + std::string name = cast(op->getAttr(kPreferredNameAttr)).str(); mlir::func::CallOp callOp; auto func = fixed::outlineSingleBlockRegion( @@ -293,10 +291,12 @@ class LowerCountedForPass private: void runOnOperation() override { // See comment in ForToCountedForRewrite for why we do this. + DenseSet addedSymbols; + SymbolTable symbolTable(getOperation()); getOperation().walk([&](ForOp op) { - op->setAttr( - kPreferredNameAttr, - StringAttr::get(op->getContext(), createUniqueName(op, "for_body"))); + op->setAttr(kPreferredNameAttr, + createUniqueName(getContext(), symbolTable, addedSymbols, + "for_body")); }); ConversionTarget target(getContext());