Skip to content

Commit

Permalink
Move Python Bindings from being defined with Pybind -> Nanobind (#2379)
Browse files Browse the repository at this point in the history
### Ticket
- PR closes #1774 

### Problem description
- `nanobind` is now supported through MLIR. It is intended to be leaner,
faster, and better supported than `pybind11`.
- We should switch our Python bindings over to `nanobind` as soon as
possible to avoid further divergence from upstream MLIR.

### What's changed
- A few functional changes to the bindings to support nanobind's new
Python object storage methods.
- Switched all of the bindings and build system from `pybind11` ->
`nanobind`.
  • Loading branch information
vprajapati-tt authored Mar 7, 2025
1 parent b005303 commit 191d311
Show file tree
Hide file tree
Showing 14 changed files with 309 additions and 334 deletions.
38 changes: 20 additions & 18 deletions include/ttmlir/Bindings/Python/TTMLIRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#define TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H

#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
Expand All @@ -21,48 +22,49 @@
#include "ttmlir/RegisterAll.h"
#include "llvm/Support/CommandLine.h"

#include <nanobind/stl/variant.h>
#include <variant>

namespace py = pybind11;
namespace nb = nanobind;

namespace mlir::ttmlir::python {

template <typename T>
py::class_<T> tt_attribute_class(py::module &m, const char *class_name) {
py::class_<T> cls(m, class_name);
nb::class_<T> tt_attribute_class(nb::module_ &m, const char *class_name) {
nb::class_<T> cls(m, class_name);
cls.def_static("maybe_downcast",
[](MlirAttribute attr) -> std::variant<T, py::object> {
[](MlirAttribute attr) -> std::variant<T, nb::object> {
auto res = mlir::dyn_cast<T>(unwrap(attr));
if (res) {
return res;
}
return py::none();
return nb::none();
});
return cls;
}

template <typename T>
py::class_<T> tt_type_class(py::module &m, const char *class_name) {
py::class_<T> cls(m, class_name);
nb::class_<T> tt_type_class(nb::module_ &m, const char *class_name) {
nb::class_<T> cls(m, class_name);
cls.def_static("maybe_downcast",
[](MlirType type) -> std::variant<T, py::object> {
[](MlirType type) -> std::variant<T, nb::object> {
auto res = mlir::dyn_cast<T>(unwrap(type));
if (res) {
return res;
}
return py::none();
return nb::none();
});
return cls;
}

void populateTTModule(py::module &m);
void populateTTIRModule(py::module &m);
void populateTTKernelModule(py::module &m);
void populateTTNNModule(py::module &m);
void populateOverridesModule(py::module &m);
void populateOptimizerOverridesModule(py::module &m);
void populatePassesModule(py::module &m);
void populateUtilModule(py::module &m);
void populateTTModule(nb::module_ &m);
void populateTTIRModule(nb::module_ &m);
void populateTTKernelModule(nb::module_ &m);
void populateTTNNModule(nb::module_ &m);
void populateOverridesModule(nb::module_ &m);
void populateOptimizerOverridesModule(nb::module_ &m);
void populatePassesModule(nb::module_ &m);
void populateUtilModule(nb::module_ &m);
} // namespace mlir::ttmlir::python

#endif // TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H
11 changes: 6 additions & 5 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,15 @@ class OptimizerOverridesHandler {

// Wrapper methods we use to expose the adders to the python bindings
std::unordered_map<std::string, InputLayoutOverrideParams>
getInputLayoutOverridesPybindWrapper() const;
getInputLayoutOverridesNanobindWrapper() const;
std::unordered_map<std::string, OutputLayoutOverrideParams>
getOutputLayoutOverridesPybindWrapper() const;
getOutputLayoutOverridesNanobindWrapper() const;

// Wrapper methods we use to expose the adders to the python bindings
void addInputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &);
void addOutputLayoutOverridePybindWrapper(std::string,
OutputLayoutOverrideParams);
void addInputLayoutOverrideNanobindWrapper(std::string,
std::vector<int64_t> &);
void addOutputLayoutOverrideNanobindWrapper(std::string,
OutputLayoutOverrideParams);

private:
// Flags for enabling/disabling the optimizer passes
Expand Down
3 changes: 3 additions & 0 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ struct GoldenTensor {
std::vector<std::uint8_t> &&_data)
: name(name), shape(shape), strides(strides), dtype(dtype),
data(std::move(_data)) {}

// Create an explicit empty constructor
GoldenTensor() = default;
};

inline ::tt::target::OOBVal toFlatbuffer(FlatbufferObjectCache &,
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ OptimizerOverridesHandler::getOutputLayoutOverrides() const {
}

std::unordered_map<std::string, InputLayoutOverrideParams>
OptimizerOverridesHandler::getInputLayoutOverridesPybindWrapper() const {
OptimizerOverridesHandler::getInputLayoutOverridesNanobindWrapper() const {
std::unordered_map<std::string, InputLayoutOverrideParams>
inputLayoutOverridesWrapper;
for (auto &entry : inputLayoutOverrides) {
Expand All @@ -93,7 +93,7 @@ OptimizerOverridesHandler::getInputLayoutOverridesPybindWrapper() const {
}

std::unordered_map<std::string, OutputLayoutOverrideParams>
OptimizerOverridesHandler::getOutputLayoutOverridesPybindWrapper() const {
OptimizerOverridesHandler::getOutputLayoutOverridesNanobindWrapper() const {
std::unordered_map<std::string, OutputLayoutOverrideParams>
outputLayoutOverridesWrapper;
for (auto &entry : outputLayoutOverrides) {
Expand Down Expand Up @@ -190,15 +190,15 @@ void OptimizerOverridesHandler::addOutputLayoutOverride(
std::move(grid), bufferType, tensorMemoryLayout, memoryLayout, dataType};
}

void OptimizerOverridesHandler::addInputLayoutOverridePybindWrapper(
void OptimizerOverridesHandler::addInputLayoutOverrideNanobindWrapper(
std::string opName, std::vector<int64_t> &operandIdxes) {
StringRef opNameStringRef(opName);
SmallVector<int64_t> operandIdxesSmallVector(operandIdxes.begin(),
operandIdxes.end());
addInputLayoutOverride(opNameStringRef, operandIdxesSmallVector);
}

void OptimizerOverridesHandler::addOutputLayoutOverridePybindWrapper(
void OptimizerOverridesHandler::addOutputLayoutOverrideNanobindWrapper(
std::string opName, OutputLayoutOverrideParams overrideParams) {
StringRef opNameStringRef(opName);
addOutputLayoutOverride(opNameStringRef, overrideParams);
Expand Down
1 change: 1 addition & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main
MLIRTestToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
MLIRX86VectorToLLVMIRTranslation
PYTHON_BINDINGS_LIBRARY nanobind
)

set(TTMLIR_PYTHON_SOURCES
Expand Down
69 changes: 35 additions & 34 deletions python/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

namespace mlir::ttmlir::python {

void populateOptimizerOverridesModule(py::module &m) {
void populateOptimizerOverridesModule(nb::module_ &m) {

py::class_<tt::ttnn::OptimizerOverridesHandler>(m,
nb::class_<tt::ttnn::OptimizerOverridesHandler>(m,
"OptimizerOverridesHandler")
.def(py::init<>())
.def(nb::init<>())

.def("set_enable_optimizer",
&tt::ttnn::OptimizerOverridesHandler::setEnableOptimizer)
Expand Down Expand Up @@ -56,47 +56,48 @@ void populateOptimizerOverridesModule(py::module &m) {

.def("get_input_layout_overrides",
&tt::ttnn::OptimizerOverridesHandler::
getInputLayoutOverridesPybindWrapper)
getInputLayoutOverridesNanobindWrapper)
.def("get_output_layout_overrides",
&tt::ttnn::OptimizerOverridesHandler::
getOutputLayoutOverridesPybindWrapper)
getOutputLayoutOverridesNanobindWrapper)

.def("add_input_layout_override", &tt::ttnn::OptimizerOverridesHandler::
addInputLayoutOverridePybindWrapper)
.def("add_input_layout_override",
&tt::ttnn::OptimizerOverridesHandler::
addInputLayoutOverrideNanobindWrapper)
.def("add_output_layout_override",
&tt::ttnn::OptimizerOverridesHandler::
addOutputLayoutOverridePybindWrapper)
addOutputLayoutOverrideNanobindWrapper)

.def("to_string", &tt::ttnn::OptimizerOverridesHandler::toString);

py::enum_<mlir::tt::MemoryLayoutAnalysisPolicyType>(
nb::enum_<mlir::tt::MemoryLayoutAnalysisPolicyType>(
m, "MemoryLayoutAnalysisPolicyType")
.value("DFSharding", mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding)
.value("GreedyL1Interleaved",
mlir::tt::MemoryLayoutAnalysisPolicyType::GreedyL1Interleaved)
.value("BFInterleaved",
mlir::tt::MemoryLayoutAnalysisPolicyType::BFInterleaved);

py::enum_<mlir::tt::ttnn::BufferType>(m, "BufferType")
nb::enum_<mlir::tt::ttnn::BufferType>(m, "BufferType")
.value("DRAM", mlir::tt::ttnn::BufferType::DRAM)
.value("L1", mlir::tt::ttnn::BufferType::L1)
.value("SystemMemory", mlir::tt::ttnn::BufferType::SystemMemory)
.value("L1Small", mlir::tt::ttnn::BufferType::L1Small)
.value("Trace", mlir::tt::ttnn::BufferType::Trace);

py::enum_<mlir::tt::ttnn::Layout>(m, "Layout")
nb::enum_<mlir::tt::ttnn::Layout>(m, "Layout")
.value("RowMajor", mlir::tt::ttnn::Layout::RowMajor)
.value("Tile", mlir::tt::ttnn::Layout::Tile)
.value("Invalid", mlir::tt::ttnn::Layout::Invalid);

py::enum_<mlir::tt::ttnn::TensorMemoryLayout>(m, "TensorMemoryLayout")
nb::enum_<mlir::tt::ttnn::TensorMemoryLayout>(m, "TensorMemoryLayout")
.value("Interleaved", mlir::tt::ttnn::TensorMemoryLayout::Interleaved)
.value("SingleBank", mlir::tt::ttnn::TensorMemoryLayout::SingleBank)
.value("HeightSharded", mlir::tt::ttnn::TensorMemoryLayout::HeightSharded)
.value("WidthSharded", mlir::tt::ttnn::TensorMemoryLayout::WidthSharded)
.value("BlockSharded", mlir::tt::ttnn::TensorMemoryLayout::BlockSharded);

py::enum_<mlir::tt::DataType>(m, "DataType")
nb::enum_<mlir::tt::DataType>(m, "DataType")
.value("Float32", mlir::tt::DataType::Float32)
.value("Float16", mlir::tt::DataType::Float16)
.value("BFloat16", mlir::tt::DataType::BFloat16)
Expand All @@ -111,10 +112,10 @@ void populateOptimizerOverridesModule(py::module &m) {
.value("UInt8", mlir::tt::DataType::UInt8)
.value("Int32", mlir::tt::DataType::Int32);

py::class_<mlir::tt::ttnn::InputLayoutOverrideParams>(
nb::class_<mlir::tt::ttnn::InputLayoutOverrideParams>(
m, "InputLayoutOverrideParams")
.def(py::init<>())
.def_property(
.def(nb::init<>())
.def_prop_rw(
"operand_idxes",
[](const mlir::tt::ttnn::InputLayoutOverrideParams &obj) {
// Getter: Convert SmallVector to std::vector
Expand All @@ -128,10 +129,10 @@ void populateOptimizerOverridesModule(py::module &m) {
obj.operandIdxes.append(input.begin(), input.end());
});

py::class_<mlir::tt::ttnn::OutputLayoutOverrideParams>(
nb::class_<mlir::tt::ttnn::OutputLayoutOverrideParams>(
m, "OutputLayoutOverrideParams")
.def(py::init<>())
.def_property(
.def(nb::init<>())
.def_prop_rw(
"grid",
[](const mlir::tt::ttnn::OutputLayoutOverrideParams &obj) {
// Getter: Convert SmallVector to std::vector
Expand All @@ -151,20 +152,20 @@ void populateOptimizerOverridesModule(py::module &m) {
}
obj.grid->append(input.begin(), input.end());
})
.def_readwrite("buffer_type",
&mlir::tt::ttnn::OutputLayoutOverrideParams::bufferType)
.def_readwrite(
"tensor_memory_layout",
&mlir::tt::ttnn::OutputLayoutOverrideParams::tensorMemoryLayout)
.def_readwrite("memory_layout",
&mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout)
.def_readwrite("data_type",
&mlir::tt::ttnn::OutputLayoutOverrideParams::dataType)
.def_rw("buffer_type",
&mlir::tt::ttnn::OutputLayoutOverrideParams::bufferType)
.def_rw("tensor_memory_layout",
&mlir::tt::ttnn::OutputLayoutOverrideParams::tensorMemoryLayout)
.def_rw("memory_layout",
&mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout)
.def_rw("data_type",
&mlir::tt::ttnn::OutputLayoutOverrideParams::dataType)
.def("set_buffer_type_from_str",
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::string &value) {
if (auto bufferType = mlir::tt::ttnn::symbolizeBufferType(value)) {
obj.bufferType = bufferType;
if (auto bufferType_ =
mlir::tt::ttnn::symbolizeBufferType(value)) {
obj.bufferType = bufferType_;
} else {
throw std::invalid_argument("Invalid buffer type: " + value);
}
Expand All @@ -183,17 +184,17 @@ void populateOptimizerOverridesModule(py::module &m) {
.def("set_memory_layout_from_str",
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::string &value) {
if (auto memoryLayout = mlir::tt::ttnn::symbolizeLayout(value)) {
obj.memoryLayout = memoryLayout;
if (auto memoryLayout_ = mlir::tt::ttnn::symbolizeLayout(value)) {
obj.memoryLayout = memoryLayout_;
} else {
throw std::invalid_argument("Invalid memory layout: " + value);
}
})
.def("set_data_type_from_str",
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::string &value) {
if (auto dataType = mlir::tt::DataTypeStringToEnum(value)) {
obj.dataType = dataType;
if (auto dataType_ = mlir::tt::DataTypeStringToEnum(value)) {
obj.dataType = dataType_;
} else {
throw std::invalid_argument("Invalid data type: " + value);
}
Expand Down
4 changes: 2 additions & 2 deletions python/Overrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

namespace mlir::ttmlir::python {

void populateOverridesModule(py::module &m) {
void populateOverridesModule(nb::module_ &m) {

m.def(
"get_ptr", [](void *op) { return reinterpret_cast<uintptr_t>(op); },
py::arg("op").noconvert());
nb::arg("op").noconvert());
}

} // namespace mlir::ttmlir::python
Loading

0 comments on commit 191d311

Please sign in to comment.