Skip to content

[MLIR] Initial upstream of polygeist dialect#199136

Open
wsmoses wants to merge 1 commit into
mainfrom
users/wsmoses/polygeist
Open

[MLIR] Initial upstream of polygeist dialect#199136
wsmoses wants to merge 1 commit into
mainfrom
users/wsmoses/polygeist

Conversation

@wsmoses
Copy link
Copy Markdown
Member

@wsmoses wsmoses commented May 21, 2026

Polygeist is an incubator project of LLVM (http://github.com/llvm/Polygeist), which contains many features, notably raising passes, including for backend retargettng, parallelism, and polyhedral analyses.

This PR is mostly just the creation of the dialect and upstreaming of first ops, pointer 2 memref and memref 2 pointer and relevant canonicalizations.

Passes and other operations to follow subsequently.

See https://discourse.llvm.org/t/rfc-add-scf-to-affine-conversion-pass-in-mlir/88036

@wsmoses wsmoses requested a review from ftynse May 21, 2026 23:44
@llvmorg-github-actions llvmorg-github-actions Bot added mlir bazel "Peripheral" support tier build system: utils/bazel labels May 21, 2026
@llvmorg-github-actions
Copy link
Copy Markdown

@llvm/pr-subscribers-mlir

Author: William Moses (wsmoses)

Changes

Polygeist is an incubator project of LLVM (http://github.com/llvm/Polygeist), which contains many features, notably raising passes, including for backend retargettng, parallelism, and polyhedral analyses.

This PR is mostly just the creation of the dialect and upstreaming of first ops, pointer 2 memref and memref 2 pointer and relevant canonicalizations.

Passes and other operations to follow subsequently.

See https://discourse.llvm.org/t/rfc-add-scf-to-affine-conversion-pass-in-mlir/88036


Patch is 29.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/199136.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/CMakeLists.txt (+1)
  • (added) mlir/include/mlir/Dialect/Polygeist/IR/CMakeLists.txt (+3)
  • (added) mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h (+14)
  • (added) mlir/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td (+16)
  • (added) mlir/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td (+42)
  • (modified) mlir/lib/Dialect/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Polygeist/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Polygeist/IR/CMakeLists.txt (+19)
  • (added) mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp (+15)
  • (added) mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp (+380)
  • (modified) mlir/lib/RegisterAllDialects.cpp (+2)
  • (added) mlir/test/Dialect/Polygeist/canonicalize-memref2pointer.mlir (+24)
  • (added) mlir/test/Dialect/Polygeist/canonicalize-pointer2memref.mlir (+133)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+66)
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index d2505877e2dd0..1750bcfe1da54 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -27,6 +27,7 @@ add_subdirectory(OpenACCMPCommon)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
+add_subdirectory(Polygeist)
 add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Polygeist/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..d21dc0f83acd1
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polygeist/IR/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS PolygeistOps.td)
+add_mlir_dialect(PolygeistOps polygeist)
+add_mlir_doc(PolygeistOps PolygeistOps Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h b/mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h
new file mode 100644
index 0000000000000..dc320e9093271
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h
@@ -0,0 +1,14 @@
+#ifndef MLIR_DIALECT_POLYGEIST_IR_POLYGEIST_H_
+#define MLIR_DIALECT_POLYGEIST_IR_POLYGEIST_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/Polygeist/IR/PolygeistOpsDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Polygeist/IR/PolygeistOps.h.inc"
+
+#endif // MLIR_DIALECT_POLYGEIST_IR_POLYGEIST_H_
diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td
new file mode 100644
index 0000000000000..11b5e14606fda
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistBase.td
@@ -0,0 +1,16 @@
+#ifndef POLYGEIST_BASE
+#define POLYGEIST_BASE
+
+include "mlir/IR/OpBase.td"
+
+def Polygeist_Dialect : Dialect {
+  let name = "polygeist";
+  let cppNamespace = "::mlir::polygeist";
+  let summary = "The Polygeist dialect.";
+  let description = [{
+    The Polygeist dialect contains operations for raising low-level code to higher-level forms, and performing parallel and device transformations (including polyhedral).
+  }];
+  let useDefaultTypePrinterParser = 1;
+}
+
+#endif // POLYGEIST_BASE
diff --git a/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td
new file mode 100644
index 0000000000000..48d6c7c637264
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Polygeist/IR/PolygeistOps.td
@@ -0,0 +1,42 @@
+#ifndef POLYGEIST_OPS
+#define POLYGEIST_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/Polygeist/IR/PolygeistBase.td"
+
+def Memref2PointerOp : Op<Polygeist_Dialect, "memref2pointer", [
+  ViewLikeOpInterface, Pure
+]> {
+  let summary = "Extract and LLVM pointer from a MemRef";
+
+  let arguments = (ins AnyMemRef : $source);
+  let results = (outs LLVM_AnyPointer:$result);
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+  
+  let extraClassDeclaration = [{
+    ::mlir::Value getViewSource() { return getSource(); }
+  }];
+}
+
+def Pointer2MemrefOp : Op<Polygeist_Dialect, "pointer2memref", [
+  ViewLikeOpInterface, Pure
+]> {
+  let summary = "Upgrade a pointer to a memref";
+
+  let arguments = (ins LLVM_AnyPointer:$source);
+  let results = (outs AnyMemRef : $result);
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+  
+  let extraClassDeclaration = [{
+    ::mlir::Value getViewSource() { return getSource(); }
+  }];
+}
+
+#endif // POLYGEIST_OPS
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 66f68c369f81f..ee2838d8aaa70 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -27,6 +27,7 @@ add_subdirectory(OpenACCMPCommon)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
 add_subdirectory(PDLInterp)
+add_subdirectory(Polygeist)
 add_subdirectory(Ptr)
 add_subdirectory(Quant)
 add_subdirectory(SCF)
diff --git a/mlir/lib/Dialect/Polygeist/CMakeLists.txt b/mlir/lib/Dialect/Polygeist/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/Polygeist/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Polygeist/IR/CMakeLists.txt b/mlir/lib/Dialect/Polygeist/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..647c83a18e6e3
--- /dev/null
+++ b/mlir/lib/Dialect/Polygeist/IR/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRPolygeistDialect
+  PolygeistOps.cpp
+  PolygeistDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Polygeist
+
+  DEPENDS
+  MLIRPolygeistOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialect
+  MLIRIR
+  MLIRMemRefDialect
+  MLIRLLVMDialect
+  MLIRArithDialect
+  MLIRAffineDialect
+  MLIRSCFDialect
+  )
diff --git a/mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp b/mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp
new file mode 100644
index 0000000000000..a50fe4e2aec01
--- /dev/null
+++ b/mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp
@@ -0,0 +1,15 @@
+#include "mlir/Dialect/Polygeist/IR/Polygeist.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+
+using namespace mlir;
+using namespace mlir::polygeist;
+
+#include "mlir/Dialect/Polygeist/IR/PolygeistOpsDialect.cpp.inc"
+
+void PolygeistDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Polygeist/IR/PolygeistOps.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp b/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp
new file mode 100644
index 0000000000000..d070ba90570b5
--- /dev/null
+++ b/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp
@@ -0,0 +1,380 @@
+#include "mlir/Dialect/Polygeist/IR/Polygeist.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
+
+using namespace mlir;
+using namespace mlir::polygeist;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Polygeist/IR/PolygeistOps.cpp.inc"
+
+namespace {
+/// Simplify pointer2memref(memref2pointer(x)) to cast(x)
+class Memref2Pointer2MemrefCast final
+    : public OpRewritePattern<Pointer2MemrefOp> {
+public:
+  using OpRewritePattern<Pointer2MemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Pointer2MemrefOp op,
+                                PatternRewriter &rewriter) const override {
+    auto src = op.getSource().getDefiningOp<Memref2PointerOp>();
+    if (!src)
+      return failure();
+    auto smt = cast<MemRefType>(src.getSource().getType());
+    auto omt = cast<MemRefType>(op.getType());
+    if (smt.getShape().size() != omt.getShape().size())
+      return failure();
+    for (unsigned i = 1; i < smt.getShape().size(); i++) {
+      if (smt.getShape()[i] != omt.getShape()[i])
+        return failure();
+    }
+    if (smt.getElementType() != omt.getElementType())
+      return failure();
+    if (smt.getMemorySpace() != omt.getMemorySpace())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(),
+                                                src.getSource());
+    return success();
+  }
+};
+
+/// Simplify memref2pointer(pointer2memref(x)) to cast(x)
+class Memref2PointerBitCast final public OpRewritePattern<LLVM::BitcastOp> {
+public:
+  using OpRewritePattern<LLVM::BitcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LLVM::BitcastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto src = op.getOperand().getDefiningOp<Memref2PointerOp>();
+    if (!src)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<Memref2PointerOp>(op, op.getType(),
+                                                  src.getSource());
+    return success();
+  }
+};
+
+/// Simplify load(pointer2memref(gep(...(x)))) to load(x, idx)
+template <typename T>
+class LoadStorePointer2MemrefGEP final : public OpRewritePattern<T> {
+public:
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  SmallVector<Value> newIndex(T op, Value finalIndex,
+                              PatternRewriter &rewriter) const;
+
+  void createNewOp(T op, Value baseMemref, SmallVector<Value> vals,
+                   PatternRewriter &rewriter) const;
+
+  Value getMemref(T op) const;
+
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getMemRefType().getRank() != 1)
+      return failure();
+
+    auto src =
+        getMemref(op).template getDefiningOp<Pointer2MemrefOp>();
+    if (!src)
+      return failure();
+
+    Type elementType = op.getMemRefType().getElementType();
+    unsigned elementSize = elementType.isIntOrFloat()
+                               ? elementType.getIntOrFloatBitWidth() / 8
+                               : 0;
+    if (elementSize == 0)
+      return failure();
+
+    SmallVector<std::pair<LLVM::GEPOp, unsigned>> gepOps;
+    Value ptr = src.getSource();
+
+    while (auto gep = ptr.getDefiningOp<LLVM::GEPOp>()) {
+      if (gep.getIndices().size() != 1)
+        break;
+
+      unsigned gepElemSize = 1;
+      auto elemTy = gep.getElemType();
+      if (elemTy.isIntOrFloat()) {
+        gepElemSize = elemTy.getIntOrFloatBitWidth() / 8;
+      } else if (auto arrayTy = dyn_cast<LLVM::LLVMArrayType>(elemTy)) {
+        auto baseTy = arrayTy.getElementType();
+        if (baseTy.isIntOrFloat()) {
+          gepElemSize =
+              (baseTy.getIntOrFloatBitWidth() / 8) * arrayTy.getNumElements();
+        } else {
+          break;
+        }
+      } else {
+        break;
+      }
+
+      gepOps.emplace_back(gep, gepElemSize);
+      ptr = gep.getBase();
+    }
+
+    if (gepOps.empty())
+      return failure();
+
+    Location loc = op.getLoc();
+    auto baseMemref = rewriter.create<Pointer2MemrefOp>(
+        loc, cast<MemRefType>(src.getType()), ptr);
+
+    Value finalIndex = nullptr;
+    for (auto [gep, gepElemSize] : llvm::reverse(gepOps)) {
+      PointerUnion<IntegerAttr, Value> rawIdx = gep.getIndices()[0];
+      Value idx = dyn_cast_if_present<Value>(rawIdx);
+      if (!idx)
+        idx = rewriter.create<arith::ConstantIndexOp>(
+            loc, cast<IntegerAttr>(rawIdx).getValue().getSExtValue());
+      
+      if (auto constIdx = idx.getDefiningOp<arith::ConstantIndexOp>()) {
+        if ((constIdx.value() * gepElemSize) % elementSize != 0) {
+          return failure();
+        }
+      }
+
+      if (!idx.getType().isIndex()) {
+        idx = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                 idx);
+      }
+
+      unsigned gcd = std::gcd(gepElemSize, elementSize);
+      unsigned scaledGep = gepElemSize / gcd;
+      unsigned scaledElement = elementSize / gcd;
+
+      Value scaledIdx =
+          (scaledGep != 1)
+              ? rewriter.create<arith::MulIOp>(
+                    loc, idx,
+                    rewriter.create<arith::ConstantIndexOp>(loc, scaledGep))
+              : idx;
+
+      Value elemOffset =
+          (scaledElement != 1)
+              ? rewriter.create<arith::DivSIOp>(loc, scaledIdx,
+                                       rewriter.create<arith::ConstantIndexOp>(
+                                           loc, scaledElement))
+              : scaledIdx;
+
+      if (finalIndex)
+        finalIndex =
+            rewriter.create<arith::AddIOp>(loc, finalIndex, elemOffset);
+      else
+        finalIndex = elemOffset;
+    }
+
+    createNewOp(op, baseMemref, newIndex(op, finalIndex, rewriter), rewriter);
+    return success();
+  }
+};
+
+template <>
+Value LoadStorePointer2MemrefGEP<memref::LoadOp>::getMemref(
+    memref::LoadOp op) const {
+  return op.getMemref();
+}
+
+template <>
+Value LoadStorePointer2MemrefGEP<memref::StoreOp>::getMemref(
+    memref::StoreOp op) const {
+  return op.getMemref();
+}
+
+template <>
+Value LoadStorePointer2MemrefGEP<affine::AffineLoadOp>::getMemref(
+    affine::AffineLoadOp op) const {
+  return op.getMemref();
+}
+
+template <>
+Value LoadStorePointer2MemrefGEP<affine::AffineStoreOp>::getMemref(
+    affine::AffineStoreOp op) const {
+  return op.getMemref();
+}
+
+template <>
+SmallVector<Value> LoadStorePointer2MemrefGEP<memref::LoadOp>::newIndex(
+    memref::LoadOp op, Value finalIndex, PatternRewriter &rewriter) const {
+  auto operands = llvm::to_vector(op.getIndices());
+  operands[0] =
+      rewriter.create<arith::AddIOp>(op.getLoc(), operands[0], finalIndex);
+  return operands;
+}
+
+template <>
+SmallVector<Value> LoadStorePointer2MemrefGEP<affine::AffineLoadOp>::newIndex(
+    affine::AffineLoadOp op, Value finalIndex,
+    PatternRewriter &rewriter) const {
+  auto apply = rewriter.create<affine::AffineApplyOp>(
+      op.getLoc(), op.getAffineMap(), op.getMapOperands());
+
+  SmallVector<Value> operands;
+  for (auto op : apply->getResults())
+    operands.push_back(op);
+  operands[0] =
+      rewriter.create<arith::AddIOp>(op.getLoc(), operands[0], finalIndex);
+  return operands;
+}
+
+template <>
+SmallVector<Value> LoadStorePointer2MemrefGEP<memref::StoreOp>::newIndex(
+    memref::StoreOp op, Value finalIndex, PatternRewriter &rewriter) const {
+  auto operands = llvm::to_vector(op.getIndices());
+  operands[0] =
+      rewriter.create<arith::AddIOp>(op.getLoc(), operands[0], finalIndex);
+  return operands;
+}
+
+template <>
+SmallVector<Value> LoadStorePointer2MemrefGEP<affine::AffineStoreOp>::newIndex(
+    affine::AffineStoreOp op, Value finalIndex,
+    PatternRewriter &rewriter) const {
+  auto apply = rewriter.create<affine::AffineApplyOp>(
+      op.getLoc(), op.getAffineMap(), op.getMapOperands());
+
+  SmallVector<Value> operands;
+  for (auto op : apply->getResults())
+    operands.push_back(op);
+  operands[0] =
+      rewriter.create<arith::AddIOp>(op.getLoc(), operands[0], finalIndex);
+  return operands;
+}
+
+template <>
+void LoadStorePointer2MemrefGEP<memref::LoadOp>::createNewOp(
+    memref::LoadOp op, Value baseMemref, SmallVector<Value> idxs,
+    PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, baseMemref, idxs);
+}
+
+template <>
+void LoadStorePointer2MemrefGEP<affine::AffineLoadOp>::createNewOp(
+    affine::AffineLoadOp op, Value baseMemref, SmallVector<Value> idxs,
+    PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, baseMemref, idxs);
+}
+
+template <>
+void LoadStorePointer2MemrefGEP<memref::StoreOp>::createNewOp(
+    memref::StoreOp op, Value baseMemref, SmallVector<Value> idxs,
+    PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, op.getValue(), baseMemref,
+                                               idxs);
+}
+
+template <>
+void LoadStorePointer2MemrefGEP<affine::AffineStoreOp>::createNewOp(
+    affine::AffineStoreOp op, Value baseMemref, SmallVector<Value> idxs,
+    PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, op.getValue(), baseMemref,
+                                               idxs);
+}
+
+/// Simplify cast(pointer2memref(x)) to pointer2memref(x)
+class Pointer2MemrefCast final public OpRewritePattern<memref::CastOp> {
+public:
+  using OpRewritePattern<memref::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto src = op.getSource().getDefiningOp<Pointer2MemrefOp>();
+    if (!src)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<Pointer2MemrefOp>(op, op.getType(),
+                                                             src.getSource());
+    return success();
+  }
+};
+
+/// Simplify memref2pointer(pointer2memref(x)) to cast(x)
+class Pointer2Memref2PointerCast final
+    : public OpRewritePattern<Memref2PointerOp> {
+public:
+  using OpRewritePattern<Memref2PointerOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Memref2PointerOp op,
+                                PatternRewriter &rewriter) const override {
+    auto src = op.getSource().getDefiningOp<Pointer2MemrefOp>();
+    if (!src)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, op.getType(),
+                                                 src.getSource());
+    return success();
+  }
+};
+
+} // namespace
+
+void Memref2PointerOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                   MLIRContext *context) {
+  results.insert<Memref2Pointer2MemrefCast, Memref2PointerBitCast>(context);
+}
+
+OpFoldResult Memref2PointerOp::fold(FoldAdaptor adaptor) {
+  /// Simplify memref2pointer(cast(x)) to memref2pointer(x)
+  if (auto mc = getSource().getDefiningOp<memref::CastOp>()) {
+    getSourceMutable().assign(mc.getSource());
+    return getResult();
+  }
+  if (auto mc = getSource().getDefiningOp<Pointer2MemrefOp>()) {
+    if (mc.getSource().getType() == getType()) {
+      return mc.getSource();
+    }
+  }
+  return nullptr;
+}
+
+void Pointer2MemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                   MLIRContext *context) {
+  results.insert<Pointer2MemrefCast, Pointer2Memref2PointerCast,
+                 LoadStorePointer2MemrefGEP<memref::LoadOp>,
+                 LoadStorePointer2MemrefGEP<affine::AffineLoadOp>,
+                 LoadStorePointer2MemrefGEP<memref::StoreOp>,
+                 LoadStorePointer2MemrefGEP<affine::AffineStoreOp>>(context);
+}
+
+OpFoldResult Pointer2MemrefOp::fold(FoldAdaptor adaptor) {
+  /// Simplify pointer2memref(cast(x)) to pointer2memref(x)
+  if (auto mc = getSource().getDefiningOp<LLVM::BitcastOp>()) {
+    getSourceMutable().assign(mc.getOperand());
+    return getResult();
+  }
+  if (auto mc = getSource().getDefiningOp<LLVM::AddrSpaceCastOp>()) {
+    getSourceMutable().assign(mc.getOperand());
+    return getResult();
+  }
+  if (auto mc = getSource().getDefiningOp<LLVM::GEPOp>()) {
+    for (auto idx : mc.getDynamicIndices()) {
+      assert(idx);
+      if (!matchPattern(idx, m_Zero()))
+        return nullptr;
+    }
+    auto staticIndices = mc.getRawConstantIndices();
+    for (auto pair : llvm::enumerate(staticIndices)) {
+      if (pair.value() != LLVM::GEPOp::kDynamicIndex)
+        if (pair.value() != 0)
+          return nullptr;
+    }
+
+    getSourceMutable().assign(mc.getBase());
+    return getResult();
+  }
+  if (auto mc = getSource().getDefiningOp<Memref2PointerOp>()) {
+    if (mc.getSource().getType() == getType()) {
+      return mc.getSource();
+    }
+  }
+  return nullptr;
+}
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 2f55296f424cd..456ea0610940d 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -66,6 +66,7 @@
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Polygeist/IR/Polygeist.h"
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -140,6 +141,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
                   omp::OpenMPDialect,
                   pdl::PDLDialect,
                   pdl_interp::PDLInterpDialect,
+                  polygeist::PolygeistDialect,
                   ptr::PtrDialect,
                   quant::QuantDialect,
                   ROCDL::ROCDLDialect,
diff --git a/mlir/test/Dialect/Polygeist/canonicalize-memref2pointer.mlir b/mlir/test/Dialect/Polygeist/canonicalize-memref2pointer.mlir
new file mode 100644
index 0000000000000..9e86d78341879
--- /dev/null
+++ b/mlir/test/Dialect/Polygeist/canonicalize-memref2pointer.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt --canonicalize -split-input-file %s | FileCheck %s
+
+// CH...
[truncated]

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 21, 2026

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/Dialect/Polygeist/IR/Polygeist.h mlir/lib/Dialect/Polygeist/IR/PolygeistDialect.cpp mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp mlir/lib/RegisterAllDialects.cpp --diff_from_common_commit

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp b/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp
index 137ae4b65..79af8cda8 100644
--- a/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp
+++ b/mlir/lib/Dialect/Polygeist/IR/PolygeistOps.cpp
@@ -144,7 +144,7 @@ public:
 
       if (!idx.getType().isIndex()) {
         idx = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
-                                                  idx);
+                                         idx);
       }
 
       unsigned gcd = std::gcd(gepElemSize, elementSize);
@@ -160,9 +160,9 @@ public:
 
       Value elemOffset =
           (scaledElement != 1)
-              ? arith::DivSIOp::create(
-                    rewriter, loc, scaledIdx,
-                    arith::ConstantIndexOp::create(rewriter, loc, scaledElement))
+              ? arith::DivSIOp::create(rewriter, loc, scaledIdx,
+                                       arith::ConstantIndexOp::create(
+                                           rewriter, loc, scaledElement))
               : scaledIdx;
 
       if (finalIndex)

@wsmoses wsmoses force-pushed the users/wsmoses/polygeist branch 4 times, most recently from d335342 to 6278e64 Compare May 21, 2026 23:56
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 22, 2026

🐧 Linux x64 Test Results

  • 8085 tests passed
  • 618 tests skipped

✅ The build succeeded and all tests passed.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 22, 2026

🪟 Windows x64 Test Results

  • 3774 tests passed
  • 429 tests skipped

✅ The build succeeded and all tests passed.

@wsmoses wsmoses force-pushed the users/wsmoses/polygeist branch from 6278e64 to 46069c5 Compare May 22, 2026 00:37
@wsmoses wsmoses requested a review from ivanradanov May 22, 2026 00:45
@wsmoses wsmoses force-pushed the users/wsmoses/polygeist branch from 46069c5 to b34d81e Compare May 22, 2026 00:47
@wsmoses wsmoses force-pushed the users/wsmoses/polygeist branch from b34d81e to d4179f6 Compare May 22, 2026 00:50
@wsmoses wsmoses requested a review from chelini May 22, 2026 00:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bazel "Peripheral" support tier build system: utils/bazel mlir polygeist

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant