[mlir][canonicalize] Add filter-dialects option #193041
Add a new `filter-dialects` list option to the canonicalize pass. When provided, only canonicalization patterns from the listed dialects are collected, and the named dialects are force-loaded via getDependentDialects.

Loading flow: the Canonicalizer's getDependentDialects override calls registry.insert(StringRef) for each filter-dialect name, which records the name in a new `nameOnlyDependencies` list on DialectRegistry. The PassManager's existing pipeline-init loop then iterates getDialectNames(), which now surfaces those names, and calls context->getOrLoadDialect(name) on each one; the real allocator is resolved from the context's own registry (registered by the tool) and the dialect is loaded before multi-threaded execution begins.

DialectRegistry changes:
- New `insert(StringRef)` overload for name-only dependencies.
- getDialectNames() now returns a SmallVector<StringRef> that includes both real entries and name-only dependencies.
- appendTo/isSubsetOf updated to carry name-only entries through.

Assisted-by: Claude Code
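For readers who want to see the name-only bookkeeping in isolation, here is a minimal standalone sketch of the registry behavior described above. It does not use MLIR: `ToyRegistry`, `insertWithAllocator`, and the `int` stand-in for an allocator are hypothetical names invented for illustration, and the real `isSubsetOf` also checks extensions, which this model leaves out.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy model of the name-only dependency idea. This is NOT mlir::DialectRegistry;
// the "allocator" is just an int so the example stays self-contained.
class ToyRegistry {
public:
  // Register a dialect name together with a (fake) allocator.
  void insertWithAllocator(const std::string &name, int allocator) {
    registry.emplace(name, allocator);
  }

  // Record a dependency by name only. Nothing is recorded if the name already
  // has an allocator or was already recorded as name-only.
  void insert(const std::string &name) {
    if (registry.count(name))
      return;
    if (std::find(nameOnly.begin(), nameOnly.end(), name) != nameOnly.end())
      return;
    nameOnly.push_back(name);
  }

  // Names with allocators first, then name-only dependencies.
  std::vector<std::string> getDialectNames() const {
    std::vector<std::string> names;
    names.reserve(registry.size() + nameOnly.size());
    for (const auto &item : registry)
      names.push_back(item.first);
    for (const std::string &name : nameOnly)
      names.push_back(name);
    return names;
  }

  // Entries with allocators must have an allocator in rhs; name-only entries
  // may be satisfied by either kind of entry in rhs.
  bool isSubsetOf(const ToyRegistry &rhs) const {
    for (const auto &item : registry)
      if (!rhs.registry.count(item.first))
        return false;
    for (const std::string &name : nameOnly) {
      bool inNameOnly = std::find(rhs.nameOnly.begin(), rhs.nameOnly.end(),
                                  name) != rhs.nameOnly.end();
      if (!rhs.registry.count(name) && !inNameOnly)
        return false;
    }
    return true;
  }

private:
  std::map<std::string, int> registry;
  std::vector<std::string> nameOnly;
};

// Small usage walk-through: a pass asks for "arith" by name, and the tool's
// registry (which owns the allocator) covers that request.
inline void demoNameOnlyFlow() {
  ToyRegistry passDeps;
  passDeps.insert("arith");
  passDeps.insert("arith"); // duplicate request is ignored

  ToyRegistry tool;
  tool.insertWithAllocator("arith", 1);

  assert(passDeps.getDialectNames().size() == 1);
  assert(passDeps.isSubsetOf(tool));
}
```

The point of the model is the asymmetry: a name-only entry never carries an allocator, so something else (in the PR, the context's own registry) has to supply one at load time.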
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Full diff: https://github.com/llvm/llvm-project/pull/193041.diff

5 Files Affected:
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index b7d3e5d67e6d7..6e8d0c8609635 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -17,8 +17,10 @@
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
#include <map>
+#include <string>
#include <tuple>
namespace mlir {
@@ -172,6 +174,11 @@ class DialectRegistry {
void insert(TypeID typeID, StringRef name,
const DialectAllocatorFunction &ctor);
+ /// Add a dialect dependency by name only.
+ /// This is useful for passes that learn about required dialects from
+ /// string-valued options.
+ void insert(StringRef name);
+
/// Add a new dynamic dialect constructor in the registry. The constructor
/// provides as argument the created dynamic dialect, and is expected to
/// register the dialect types, attributes, and ops, using the
@@ -190,17 +197,25 @@ class DialectRegistry {
destination.insert(nameAndRegistrationIt.second.first,
nameAndRegistrationIt.first,
nameAndRegistrationIt.second.second);
+ for (const std::string &name : nameOnlyDependencies)
+ destination.insert(StringRef(name));
// Merge the extensions.
for (const auto &extension : extensions)
destination.extensions.try_emplace(extension.first,
extension.second->clone());
}
- /// Return the names of dialects known to this registry.
- auto getDialectNames() const {
- return llvm::map_range(
- registry,
- [](const MapTy::value_type &item) -> StringRef { return item.first; });
+ /// Return the names of dialects known to this registry. This includes both
+ /// dialects with registered allocators and name-only dependencies added via
+ /// `insert(StringRef)`.
+ SmallVector<StringRef> getDialectNames() const {
+ SmallVector<StringRef> names;
+ names.reserve(registry.size() + nameOnlyDependencies.size());
+ for (const auto &item : registry)
+ names.push_back(item.first);
+ for (const std::string &name : nameOnlyDependencies)
+ names.push_back(name);
+ return names;
}
/// Apply any held extensions that require the given dialect. Users are not
@@ -261,6 +276,10 @@ class DialectRegistry {
private:
MapTy registry;
+ /// Name-only dialect dependencies. The allocator for these names is
+ /// expected to be resolved from the MLIRContext's own registry when the
+ /// dependency is loaded (e.g. via MLIRContext::getOrLoadDialect).
+ SmallVector<std::string> nameOnlyDependencies;
llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
};
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1474e580cfc03..aaa6ff99a5e99 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -48,7 +48,11 @@ def CanonicalizerPass : Pass<"canonicalize"> {
Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1",
"Max. number of pattern rewrites within an iteration">,
Option<"testConvergence", "test-convergence", "bool", /*default=*/"false",
- "Test only: Fail pass on non-convergence to detect cyclic pattern">
+ "Test only: Fail pass on non-convergence to detect cyclic pattern">,
+ ListOption<"filterDialects", "filter-dialects", "std::string",
+ "If non-empty, only collect canonicalization patterns from the"
+ " dialects with the given namespaces. The listed dialects are"
+ " force-loaded into the context as dependent dialects.">
] # RewritePassUtils.options;
}
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 952619b4477a7..669cc80cf8f34 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -225,6 +225,16 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
}
}
+void DialectRegistry::insert(StringRef name) {
+ // If we already have an allocator for this name, nothing to do: the existing
+ // registration will take care of loading the dialect.
+ if (registry.count(name))
+ return;
+ if (llvm::is_contained(nameOnlyDependencies, name))
+ return;
+ nameOnlyDependencies.emplace_back(name);
+}
+
void DialectRegistry::insertDynamic(
StringRef name, const DynamicDialectPopulationFunction &ctor) {
// This TypeID marks dynamic dialects. We cannot give a TypeID for the
@@ -326,6 +336,14 @@ bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
return false;
// Check that the current dialects fully overlap with the dialects in 'rhs'.
- return llvm::all_of(
- registry, [&](const auto &it) { return rhs.registry.count(it.first); });
+ if (!llvm::all_of(registry, [&](const auto &it) {
+ return rhs.registry.count(it.first);
+ }))
+ return false;
+
+ // Check that all name-only dependencies are known in 'rhs'.
+ return llvm::all_of(nameOnlyDependencies, [&](const std::string &name) {
+ return rhs.registry.count(name) ||
+ llvm::is_contained(rhs.nameOnlyDependencies, name);
+ });
}
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 9f9bad1c2a678..8642facf73a06 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,8 +13,10 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/DialectRegistry.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
namespace mlir {
#define GEN_PASS_DEF_CANONICALIZERPASS
@@ -39,6 +41,13 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
this->enabledPatterns = enabledPatterns;
}
+ void getDependentDialects(DialectRegistry &registry) const override {
+ // Force-load any dialects named via the `filter-dialects` option. The
+ // allocator is resolved later from the MLIRContext's own registry.
+ for (const std::string &name : filterDialects)
+ registry.insert(StringRef(name));
+ }
+
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
@@ -48,11 +57,28 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
config.setMaxIterations(maxIterations);
config.setMaxNumRewrites(maxNumRewrites);
+ llvm::DenseSet<TypeID> allowedDialects;
+ for (const std::string &name : filterDialects) {
+ Dialect *dialect = context->getLoadedDialect(name);
+ if (!dialect) {
+ return emitError(UnknownLoc::get(context))
+ << "canonicalize filter-dialects: dialect '" << name
+ << "' is not loaded in the context";
+ }
+ allowedDialects.insert(dialect->getTypeID());
+ }
+ auto isAllowed = [&](Dialect *dialect) {
+ return allowedDialects.empty() ||
+ allowedDialects.contains(dialect->getTypeID());
+ };
+
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
- dialect->getCanonicalizationPatterns(owningPatterns);
+ if (isAllowed(dialect))
+ dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
- op.getCanonicalizationPatterns(owningPatterns, context);
+ if (isAllowed(&op.getDialect()))
+ op.getCanonicalizationPatterns(owningPatterns, context);
patterns = std::make_shared<FrozenRewritePatternSet>(
std::move(owningPatterns), disabledPatterns, enabledPatterns);
diff --git a/mlir/test/Transforms/canonicalize-filter-dialects.mlir b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
new file mode 100644
index 0000000000000..a6b9a38339f6d
--- /dev/null
+++ b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=arith}))' | FileCheck %s --check-prefix=ARITH
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=func}))' | FileCheck %s --check-prefix=FUNC
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s --check-prefix=ALL
+// RUN: not mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=does_not_exist}))' 2>&1 | FileCheck %s --check-prefix=ERR
+
+// The `SubIRHSAddConstant` arith pattern rewrites `subi(addi(x, c0), c1)` into
+// `addi(x, c0 - c1)`. The pattern only fires when arith canonicalizations are
+// loaded.
+
+// ARITH-LABEL: func @pattern_test
+// ARITH-NOT: arith.subi
+// ARITH: arith.addi %{{.*}}, %[[C:.*]]
+
+// FUNC-LABEL: func @pattern_test
+// FUNC: arith.addi
+// FUNC: arith.subi
+
+// ALL-LABEL: func @pattern_test
+// ALL-NOT: arith.subi
+// ALL: arith.addi %{{.*}}, %[[C:.*]]
+
+// ERR: canonicalize filter-dialects: dialect 'does_not_exist' is not loaded in the context
+func.func @pattern_test(%a: i32) -> i32 {
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2 : i32
+ %add = arith.addi %a, %c1 : i32
+ %sub = arith.subi %add, %c2 : i32
+ return %sub : i32
+}
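The filtering logic that the test above exercises can be modeled outside MLIR as a rough sketch. Everything here is hypothetical scaffolding (`ToyContext`, `collectPatterns`, and the `ToyFuncPattern` name are invented for illustration); only the behavior follows the patch: an empty filter collects everything, a non-empty filter restricts collection to the named dialects, and an unknown name is reported as an error.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a context with loaded dialects:
// dialect name -> names of its canonicalization patterns.
using ToyContext = std::map<std::string, std::vector<std::string>>;

// Collect pattern names from `context`, restricted to `filterDialects` when
// that list is non-empty. Returns false and fills `error` if a requested
// dialect is not loaded, similar in spirit to the early return in the patch.
bool collectPatterns(const ToyContext &context,
                     const std::vector<std::string> &filterDialects,
                     std::vector<std::string> &patterns, std::string &error) {
  std::set<std::string> allowed;
  for (const std::string &name : filterDialects) {
    if (!context.count(name)) {
      error = "dialect '" + name + "' is not loaded in the context";
      return false;
    }
    allowed.insert(name);
  }

  for (const auto &entry : context) {
    // An empty allow-set means no filtering was requested.
    if (!allowed.empty() && !allowed.count(entry.first))
      continue;
    patterns.insert(patterns.end(), entry.second.begin(), entry.second.end());
  }
  return true;
}

// Usage walk-through: filtering on "arith" keeps only the arith pattern.
inline void demoFilter() {
  ToyContext ctx = {{"arith", {"SubIRHSAddConstant"}},
                    {"func", {"ToyFuncPattern"}}};
  std::vector<std::string> patterns;
  std::string error;
  bool ok = collectPatterns(ctx, {"arith"}, patterns, error);
  assert(ok && patterns.size() == 1);
}
```

In this model the FUNC run of the test corresponds to a filter of `{"func"}`, which leaves the arith pattern out and so leaves `arith.subi` untouched.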