
[mlir][canonicalize] Add filter-dialects option #193041

Open
joker-eph wants to merge 1 commit into llvm:main from joker-eph:canonicalize-dialects

@joker-eph (Collaborator)

Add a new `filter-dialects` list option to the canonicalize pass. When provided, only canonicalization patterns from the listed dialects are collected, and the named dialects are force-loaded via `getDependentDialects`.
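A minimal sketch of the selection rule in plain C++ rather than MLIR. `DialectInfo` and `collectPatterns` are hypothetical stand-ins, and dialects are matched here by namespace string for simplicity; the actual pass instead resolves each listed name to a loaded `Dialect`, compares `TypeID`s, and fails `initialize` when a listed name is not loaded.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a loaded dialect: its namespace plus the names of
// the canonicalization patterns it would contribute.
struct DialectInfo {
  std::string ns;
  std::vector<std::string> patterns;
};

// Models the filter: an empty filter keeps every loaded dialect; otherwise
// only dialects whose namespace is listed contribute patterns.
std::vector<std::string>
collectPatterns(const std::vector<DialectInfo> &loaded,
                const std::vector<std::string> &filterDialects) {
  std::set<std::string> allowed(filterDialects.begin(), filterDialects.end());
  std::vector<std::string> out;
  for (const DialectInfo &d : loaded) {
    if (!allowed.empty() && !allowed.count(d.ns))
      continue; // Not listed in the filter: skip this dialect's patterns.
    out.insert(out.end(), d.patterns.begin(), d.patterns.end());
  }
  return out;
}
```

An empty filter keeps the previous behavior of collecting patterns from every loaded dialect.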

Loading flow: the Canonicalizer's `getDependentDialects` override calls `registry.insert(StringRef)` for each filter-dialect name, which records the name in a new `nameOnlyDependencies` list on `DialectRegistry`. The PassManager's existing pipeline-init loop then iterates `getDialectNames()`, which now surfaces those names, and calls `context->getOrLoadDialect(name)` on each one. The real allocator is resolved from the context's own registry (registered by the tool), and the dialect is loaded before multi-threaded execution begins.
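The loading flow can be modeled roughly as follows. This is a hypothetical, simplified sketch (`NameOnlyRegistry`, `ToyDialect`, `ToyContext`, and `initPipeline` are not MLIR classes): the pass only records names, and the context resolves each name against allocators the tool registered, leaving unknown names unloaded.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical registry that only records names (deduplicated via std::set).
struct NameOnlyRegistry {
  std::set<std::string> nameOnly;
  void insert(const std::string &name) { nameOnly.insert(name); }
  std::vector<std::string> getDialectNames() const {
    return std::vector<std::string>(nameOnly.begin(), nameOnly.end());
  }
};

struct ToyDialect {
  std::string ns;
};

struct ToyContext {
  // Allocators registered by the tool (e.g. by an mlir-opt style main).
  std::map<std::string, std::function<std::unique_ptr<ToyDialect>()>> allocators;
  std::map<std::string, std::unique_ptr<ToyDialect>> loaded;

  // Returns the loaded dialect, loading it from the context's own allocators
  // if needed; unknown names yield nullptr and nothing is loaded.
  ToyDialect *getOrLoadDialect(const std::string &name) {
    auto it = loaded.find(name);
    if (it != loaded.end())
      return it->second.get();
    auto alloc = allocators.find(name);
    if (alloc == allocators.end())
      return nullptr;
    std::unique_ptr<ToyDialect> &slot = loaded[name];
    slot = alloc->second();
    return slot.get();
  }
};

// The pass's getDependentDialects step followed by the pipeline-init loop.
void initPipeline(ToyContext &ctx, const std::vector<std::string> &filter) {
  NameOnlyRegistry deps;
  for (const std::string &name : filter)
    deps.insert(name);
  for (const std::string &name : deps.getDialectNames())
    ctx.getOrLoadDialect(name);
}
```

In the real pass, an unresolved name such as `does_not_exist` is then reported by `Canonicalizer::initialize`, which is what the `ERR` run in the test checks.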

DialectRegistry changes:

  • New `insert(StringRef)` overload for name-only dependencies.
  • `getDialectNames()` now returns a `SmallVector<StringRef>` that includes both real entries and name-only dependencies.
  • `appendTo`/`isSubsetOf` updated to carry name-only entries through.
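The `DialectRegistry` bookkeeping above can be sketched with standard containers. This is a simplified model under stated assumptions: `ToyRegistry` is hypothetical, an `int` stands in for the allocator, and the extension checks that the real `isSubsetOf` performs are omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical model of the DialectRegistry changes. `registry` maps a dialect
// name to a placeholder allocator id; extensions are left out.
struct ToyRegistry {
  std::map<std::string, int> registry;
  std::vector<std::string> nameOnlyDependencies;

  bool hasNameOnly(const std::string &name) const {
    return std::find(nameOnlyDependencies.begin(), nameOnlyDependencies.end(),
                     name) != nameOnlyDependencies.end();
  }

  // Registration with an allocator (stand-in for insert(TypeID, name, ctor)).
  void insert(const std::string &name, int allocatorId) {
    registry.emplace(name, allocatorId);
  }

  // Name-only dependency: ignored if the name already has an allocator or
  // was already recorded.
  void insert(const std::string &name) {
    if (registry.count(name) || hasNameOnly(name))
      return;
    nameOnlyDependencies.push_back(name);
  }

  // Names with allocators first, then name-only dependencies.
  std::vector<std::string> getDialectNames() const {
    std::vector<std::string> names;
    names.reserve(registry.size() + nameOnlyDependencies.size());
    for (const auto &entry : registry)
      names.push_back(entry.first);
    for (const std::string &name : nameOnlyDependencies)
      names.push_back(name);
    return names;
  }

  // Copy both kinds of entries into `dest`.
  void appendTo(ToyRegistry &dest) const {
    for (const auto &entry : registry)
      dest.insert(entry.first, entry.second);
    for (const std::string &name : nameOnlyDependencies)
      dest.insert(name);
  }

  // Allocator-backed names need an allocator in `rhs`; name-only
  // dependencies only need to be known to `rhs` in either form.
  bool isSubsetOf(const ToyRegistry &rhs) const {
    for (const auto &entry : registry)
      if (!rhs.registry.count(entry.first))
        return false;
    for (const std::string &name : nameOnlyDependencies)
      if (!rhs.registry.count(name) && !rhs.hasNameOnly(name))
        return false;
    return true;
  }
};
```

The sketch returns `std::string` copies; the PR returns `SmallVector<StringRef>`, whose elements refer to strings owned by the registry.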

Assisted-by: Claude Code

joker-eph requested a review from ftynse on April 20, 2026 17:51
llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir labels on Apr 20, 2026
@llvmbot (Member) commented Apr 20, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/193041.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/DialectRegistry.h (+24-5)
  • (modified) mlir/include/mlir/Transforms/Passes.td (+5-1)
  • (modified) mlir/lib/IR/Dialect.cpp (+20-2)
  • (modified) mlir/lib/Transforms/Canonicalizer.cpp (+28-2)
  • (added) mlir/test/Transforms/canonicalize-filter-dialects.mlir (+29)
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index b7d3e5d67e6d7..6e8d0c8609635 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -17,8 +17,10 @@
 #include "mlir/Support/TypeID.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
 
 #include <map>
+#include <string>
 #include <tuple>
 
 namespace mlir {
@@ -172,6 +174,11 @@ class DialectRegistry {
   void insert(TypeID typeID, StringRef name,
               const DialectAllocatorFunction &ctor);
 
+  /// Add a dialect dependency by name only.
+  /// This is useful for passes that learn about required dialects from
+  /// string-valued options.
+  void insert(StringRef name);
+
   /// Add a new dynamic dialect constructor in the registry. The constructor
   /// provides as argument the created dynamic dialect, and is expected to
   /// register the dialect types, attributes, and ops, using the
@@ -190,17 +197,25 @@ class DialectRegistry {
       destination.insert(nameAndRegistrationIt.second.first,
                          nameAndRegistrationIt.first,
                          nameAndRegistrationIt.second.second);
+    for (const std::string &name : nameOnlyDependencies)
+      destination.insert(StringRef(name));
     // Merge the extensions.
     for (const auto &extension : extensions)
       destination.extensions.try_emplace(extension.first,
                                          extension.second->clone());
   }
 
-  /// Return the names of dialects known to this registry.
-  auto getDialectNames() const {
-    return llvm::map_range(
-        registry,
-        [](const MapTy::value_type &item) -> StringRef { return item.first; });
+  /// Return the names of dialects known to this registry. This includes both
+  /// dialects with registered allocators and name-only dependencies added via
+  /// `insert(StringRef)`.
+  SmallVector<StringRef> getDialectNames() const {
+    SmallVector<StringRef> names;
+    names.reserve(registry.size() + nameOnlyDependencies.size());
+    for (const auto &item : registry)
+      names.push_back(item.first);
+    for (const std::string &name : nameOnlyDependencies)
+      names.push_back(name);
+    return names;
   }
 
   /// Apply any held extensions that require the given dialect. Users are not
@@ -261,6 +276,10 @@ class DialectRegistry {
 
 private:
   MapTy registry;
+  /// Name-only dialect dependencies. The allocator for these names is
+  /// expected to be resolved from the MLIRContext's own registry when the
+  /// dependency is loaded (e.g. via MLIRContext::getOrLoadDialect).
+  SmallVector<std::string> nameOnlyDependencies;
   llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
 };
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1474e580cfc03..aaa6ff99a5e99 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -48,7 +48,11 @@ def CanonicalizerPass : Pass<"canonicalize"> {
     Option<"maxNumRewrites", "max-num-rewrites", "int64_t", /*default=*/"-1",
            "Max. number of pattern rewrites within an iteration">,
     Option<"testConvergence", "test-convergence", "bool", /*default=*/"false",
-           "Test only: Fail pass on non-convergence to detect cyclic pattern">
+           "Test only: Fail pass on non-convergence to detect cyclic pattern">,
+    ListOption<"filterDialects", "filter-dialects", "std::string",
+               "If non-empty, only collect canonicalization patterns from the"
+               " dialects with the given namespaces. The listed dialects are"
+               " force-loaded into the context as dependent dialects.">
   ] # RewritePassUtils.options;
 }
 
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 952619b4477a7..669cc80cf8f34 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -225,6 +225,16 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
   }
 }
 
+void DialectRegistry::insert(StringRef name) {
+  // If we already have an allocator for this name, nothing to do: the existing
+  // registration will take care of loading the dialect.
+  if (registry.count(name))
+    return;
+  if (llvm::is_contained(nameOnlyDependencies, name))
+    return;
+  nameOnlyDependencies.emplace_back(name);
+}
+
 void DialectRegistry::insertDynamic(
     StringRef name, const DynamicDialectPopulationFunction &ctor) {
   // This TypeID marks dynamic dialects. We cannot give a TypeID for the
@@ -326,6 +336,14 @@ bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
     return false;
 
   // Check that the current dialects fully overlap with the dialects in 'rhs'.
-  return llvm::all_of(
-      registry, [&](const auto &it) { return rhs.registry.count(it.first); });
+  if (!llvm::all_of(registry, [&](const auto &it) {
+        return rhs.registry.count(it.first);
+      }))
+    return false;
+
+  // Check that all name-only dependencies are known in 'rhs'.
+  return llvm::all_of(nameOnlyDependencies, [&](const std::string &name) {
+    return rhs.registry.count(name) ||
+           llvm::is_contained(rhs.nameOnlyDependencies, name);
+  });
 }
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 9f9bad1c2a678..8642facf73a06 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,8 +13,10 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/IR/DialectRegistry.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseSet.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CANONICALIZERPASS
@@ -39,6 +41,13 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
     this->enabledPatterns = enabledPatterns;
   }
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Force-load any dialects named via the `filter-dialects` option. The
+    // allocator is resolved later from the MLIRContext's own registry.
+    for (const std::string &name : filterDialects)
+      registry.insert(StringRef(name));
+  }
+
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
@@ -48,11 +57,28 @@ struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
     config.setMaxIterations(maxIterations);
     config.setMaxNumRewrites(maxNumRewrites);
 
+    llvm::DenseSet<TypeID> allowedDialects;
+    for (const std::string &name : filterDialects) {
+      Dialect *dialect = context->getLoadedDialect(name);
+      if (!dialect) {
+        return emitError(UnknownLoc::get(context))
+               << "canonicalize filter-dialects: dialect '" << name
+               << "' is not loaded in the context";
+      }
+      allowedDialects.insert(dialect->getTypeID());
+    }
+    auto isAllowed = [&](Dialect *dialect) {
+      return allowedDialects.empty() ||
+             allowedDialects.contains(dialect->getTypeID());
+    };
+
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
-      dialect->getCanonicalizationPatterns(owningPatterns);
+      if (isAllowed(dialect))
+        dialect->getCanonicalizationPatterns(owningPatterns);
     for (RegisteredOperationName op : context->getRegisteredOperations())
-      op.getCanonicalizationPatterns(owningPatterns, context);
+      if (isAllowed(&op.getDialect()))
+        op.getCanonicalizationPatterns(owningPatterns, context);
 
     patterns = std::make_shared<FrozenRewritePatternSet>(
         std::move(owningPatterns), disabledPatterns, enabledPatterns);
diff --git a/mlir/test/Transforms/canonicalize-filter-dialects.mlir b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
new file mode 100644
index 0000000000000..a6b9a38339f6d
--- /dev/null
+++ b/mlir/test/Transforms/canonicalize-filter-dialects.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=arith}))' | FileCheck %s --check-prefix=ARITH
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=func}))' | FileCheck %s --check-prefix=FUNC
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s --check-prefix=ALL
+// RUN: not mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{filter-dialects=does_not_exist}))' 2>&1 | FileCheck %s --check-prefix=ERR
+
+// The `SubIRHSAddConstant` arith pattern rewrites `subi(addi(x, c0), c1)` into
+// `addi(x, c0 - c1)`. The pattern only fires when arith canonicalizations are
+// loaded.
+
+// ARITH-LABEL: func @pattern_test
+// ARITH-NOT:     arith.subi
+// ARITH:         arith.addi %{{.*}}, %[[C:.*]]
+
+// FUNC-LABEL: func @pattern_test
+// FUNC:         arith.addi
+// FUNC:         arith.subi
+
+// ALL-LABEL: func @pattern_test
+// ALL-NOT:     arith.subi
+// ALL:         arith.addi %{{.*}}, %[[C:.*]]
+
+// ERR: canonicalize filter-dialects: dialect 'does_not_exist' is not loaded in the context
+func.func @pattern_test(%a: i32) -> i32 {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2 : i32
+  %add = arith.addi %a, %c1 : i32
+  %sub = arith.subi %add, %c2 : i32
+  return %sub : i32
+}
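The test relies on `subi(addi(x, c0), c1)` and `addi(x, c0 - c1)` agreeing for every `i32` input. A small C++ check of that arithmetic, assuming the default wrapping (modulo 2^32) semantics of `arith.addi`/`arith.subi` without overflow flags; `original` and `rewritten` are hypothetical helpers, not the pattern's implementation.

```cpp
#include <cassert>
#include <cstdint>

// Unsigned 32-bit arithmetic wraps modulo 2^32 in C++, matching i32
// add/sub without overflow flags.
std::uint32_t original(std::uint32_t a) {
  std::uint32_t add = a + 1u; // %add = arith.addi %a, %c1
  return add - 2u;            // %sub = arith.subi %add, %c2
}

std::uint32_t rewritten(std::uint32_t a) {
  std::uint32_t c = 1u - 2u; // folded constant 1 - 2 (0xFFFFFFFF, i.e. -1)
  return a + c;              // arith.addi %a, %c
}
```

Because both sides wrap the same way, the rewrite is valid even when `%a + 1` overflows.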
