From d0078a593080b84682609bab16716006a16616f6 Mon Sep 17 00:00:00 2001 From: Aleksandar Zecevic Date: Thu, 6 Mar 2025 15:04:20 +0100 Subject: [PATCH] TTNN->EmitC transition existing ops to new converter (#2345) ### Ticket https://github.com/tenstorrent/tt-mlir/issues/2343 ### Problem description Part of the effort to onboard all TTNN ops to conversion infrastructure introduced in https://github.com/tenstorrent/tt-mlir/pull/2331 ### What's changed Transitioned ops with existing conversion to conversion with the new converter. Few changes in the Emitter class, mainly in the way how operands are treated, to accommodate more complex cases like `Variadic` of tensors that maps to `std::vector`. --- .../Conversion/TTNNToEmitC/EmitCConversion.h | 114 +++- lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 607 ++++++------------ 2 files changed, 281 insertions(+), 440 deletions(-) diff --git a/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h b/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h index f41a7126f0..55d022e07f 100644 --- a/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h +++ b/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h @@ -5,6 +5,7 @@ #ifndef TTMLIR_CONVERSION_TTNNTOEMITC_EMITCCONVERSION_H #define TTMLIR_CONVERSION_TTNNTOEMITC_EMITCCONVERSION_H +#include "ttmlir/Conversion/TTNNToEmitC/Utils.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" @@ -30,6 +31,10 @@ template struct SmallVector { using value_type = T; }; + +struct IDevice; + +struct Tensor; } // namespace ttnn namespace mlir { @@ -106,6 +111,16 @@ struct TypeName<::ttnn::SmallVector> { "::ttnn::SmallVector<" + TypeNameV + ">"; }; +template <> +struct TypeName<::ttnn::IDevice> { + inline static const std::string value = "::ttnn::IDevice"; +}; + +template <> +struct TypeName<::ttnn::Tensor> { + inline static const std::string value = "::ttnn::Tensor"; +}; + template struct EmitCTypeConverter; @@ -570,6 +585,15 @@ inline std::string convert(ttnn::MemoryConfigAttr attr) { return buf; } +template +struct IsMLIRType { + static constexpr bool value = std::is_convertible_v || + std::is_convertible_v; +}; + +template +static constexpr bool IsMLIRTypeV = IsMLIRType::value; + template class EmitCTTNNEmitter { public: @@ -610,7 +634,7 @@ class EmitCTTNNEmitter { template mlir::Attribute emit(std::optional attr) { if (!attr) { - return rewriter.getType(TypeNameV); + return emit(std::nullopt); } if constexpr (std::is_void_v) { @@ -624,20 +648,47 @@ class EmitCTTNNEmitter { return rewriter.getType(TypeNameV); } - mlir::Attribute emit(Value val) { + mlir::Attribute emit(mlir::Value val) { if (!val) { - return rewriter.getType(TypeNameV); + return emit(std::nullopt); } - auto operand = - llvm::find_if(op->getOpOperands(), [&](OpOperand &opOperand) { - return opOperand.get() == val; - }); - if (operand == op->getOpOperands().end()) { - llvm_unreachable("Unknown operand"); + mlir::OpOperand *opOperand = + std::find_if(std::next(op->getOpOperands().begin(), operands.size()), + op->getOpOperands().end(), + [&](OpOperand &operand) { return operand.get() == val; }); + + unsigned index = opOperand->getOperandNumber(); + operands.push_back(adaptor.getOperands()[index]); + return rewriter.getIndexAttr(index); + } + + mlir::Attribute emit(mlir::Operation::operand_range operands) { + for (mlir::OpOperand &opOperand : op->getOpOperands()) { + auto begin = + std::next(op->getOperands().begin(), opOperand.getOperandNumber()); + if (mlir::Operation::operand_range( + begin, std::next(begin, operands.size())) != operands) { + continue; + } + unsigned index = opOperand.getOperandNumber(); + llvm::SmallVector values( + adaptor.getOperands().begin() + index, + adaptor.getOperands().begin() + index + operands.size()); + this->operands.push_back(createVector(values)); + return rewriter.getIndexAttr(index); } + llvm_unreachable("Invalid operand range"); + } - return rewriter.getIndexAttr(operand->getOperandNumber()); + template + mlir::Attribute emit(std::nullptr_t) { + if constexpr (std::is_void_v) { + return rewriter.getType("nullptr"); + } else { + return rewriter.getType( + "static_cast<" + TypeNameV + " *>(nullptr)"); + } } // Handles the case when source type is convertible to mlir::Attribute type @@ -651,7 +702,7 @@ class EmitCTTNNEmitter { if (auto convertedValue = EmitCTypeConverter::convert(attr)) { return rewriter.getType(*convertedValue); } - return rewriter.getType(TypeNameV); + return emit(std::nullopt); } // Handles the case when source type is a non mlir::Attribute convertible type @@ -661,10 +712,9 @@ class EmitCTTNNEmitter { // appropriate C++ type. // TODO (azecevic): See if we can simplify the condition for this overload // instantiation. - template - std::enable_if_t && - !std::is_convertible_v, - mlir::Attribute> + template >> + std::enable_if_t, mlir::Attribute> emit(SourceTy &&attr) { auto result = EmitCTypeConverter::convert(std::forward(attr)); @@ -681,35 +731,51 @@ class EmitCTTNNEmitter { template emitc::CallOpaqueOp replaceOp(OpConversionPatternTy &&opConversionPattern, llvm::ArrayRef args) { + auto resultTypes = llvm::to_vector( + llvm::map_range(op->getResultTypes(), [&](Type type) -> Type { + return opConversionPattern.getTypeConverter()->convertType(type); + })); return rewriter.replaceOpWithNewOp( - op, - opConversionPattern.getTypeConverter()->convertType( - op->getResult(0).getType()), - opConversionPattern.convertOpName(op), rewriter.getArrayAttr(args), - nullptr, adaptor.getOperands()); + op, resultTypes, opConversionPattern.convertOpName(op), + rewriter.getArrayAttr(args), nullptr, operands); } private: + mlir::Value createVector(ValueRange operands) { + tt::ttnn_to_emitc::utils::insertVecCreateFnIfNotExists(rewriter, op); + + return rewriter + .create( + op.getLoc(), + emitc::OpaqueType::get(rewriter.getContext(), + TypeNameV>), + tt::ttnn_to_emitc::utils::kCreateVectorFunctionName, nullptr, + nullptr, operands) + ->getResult(0); + } + TTNNOp op; OpAdaptor adaptor; ConversionPatternRewriter &rewriter; + llvm::SmallVector operands; }; +} // namespace ttnn_to_emitc +} // namespace tt + // Helper function that serves as an alternative to the // `emit>` member function of the `EmitCTTNNEmitter` class. // For example, instead of calling `emit>(attr)`, // one can call `emit(attr) | emit(attr)`. inline mlir::Attribute operator|(mlir::Attribute lhs, mlir::Attribute rhs) { - static const mlir::Attribute nulloptAttr = - emitc::OpaqueAttr::get(lhs.getContext(), TypeNameV); + static const mlir::Attribute nulloptAttr = emitc::OpaqueAttr::get( + lhs.getContext(), tt::ttnn_to_emitc::TypeNameV); if (!lhs || lhs == nulloptAttr) { return rhs; } return lhs; } -} // namespace ttnn_to_emitc -} // namespace tt } // namespace mlir #endif // TTMLIR_CONVERSION_TTNNTOEMITC_EMITCCONVERSION_H diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index 1e25aa425b..af23a6a7b2 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -133,15 +133,16 @@ class EltwiseUnaryOpConversionPattern LogicalResult matchAndRewrite(SourceOp srcOp, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, rewriter); + llvm::SmallVector args{ emitter.emit(srcOp.getInputs()[0]), emitter.emit(std::nullopt), }; + emitter.replaceOp(*this, args); + return success(); } }; @@ -164,18 +165,16 @@ class EltwiseUnaryWithFastAndApproximateModeOpConversionPattern LogicalResult matchAndRewrite(SourceOp srcOp, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - tt::ttnn_to_emitc::utils::convertBoolAttr( - rewriter, BoolAttr::get(rewriter.getContext(), false)), - ttnn_to_emitc::utils::createStdNullopt(rewriter)}); - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType(0)), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, rewriter); + + llvm::SmallVector args{ + emitter.emit(srcOp.getInputs()[0]), + /*parameter=*/emitter.emit(false), + /*memory_config=*/emitter.emit(std::nullopt), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -198,16 +197,15 @@ class EltwiseUnaryCompositeOpConversionPattern LogicalResult matchAndRewrite(SourceOp srcOp, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - tt::ttnn_to_emitc::utils::createStdNullopt(rewriter)}); - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType(0)), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, rewriter); + + llvm::SmallVector args{ + emitter.emit(srcOp.getInputs()[0]), + /*memory_config=*/emitter.emit(std::nullopt), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -232,20 +230,16 @@ class EltwiseBinaryOpConversionPattern matchAndRewrite(SourceOp srcOp, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this - // - llvm::SmallVector attrs; - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0)); - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1)); - attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter)); - attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter)); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, rewriter); - ArrayAttr arrayAttrs = ArrayAttr::get(srcOp->getContext(), attrs); + llvm::SmallVector args{ + emitter.emit(srcOp.getInputs()[0]), + emitter.emit(srcOp.getInputs()[1]), + /*dtype=*/emitter.emit(std::nullopt), + /*memory_config=*/emitter.emit(std::nullopt), + }; - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType(0)), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + emitter.replaceOp(*this, args); return success(); } @@ -262,34 +256,22 @@ class LinearOpConversionPattern tt::ttnn::LinearOp>::TTNNToEmitCBaseOpConversionPattern; LogicalResult - matchAndRewrite(tt::ttnn::LinearOp linearOp, - tt::ttnn::LinearOp::Adaptor adaptor, + matchAndRewrite(tt::ttnn::LinearOp srcOp, tt::ttnn::LinearOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1), - rewriter.getIndexAttr(2), - tt::ttnn_to_emitc::utils::convertBoolAttr( - rewriter, linearOp.getTransposeAAttr()), - tt::ttnn_to_emitc::utils::convertBoolAttr( - rewriter, linearOp.getTransposeBAttr()), - /*memory_config=*/tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*dtype=*/tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*program_config=*/ - tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*activation=*/tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*compute_kernel_config=*/ - ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*core_grid=*/ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*output_tile=*/ttnn_to_emitc::utils::createStdNullopt(rewriter)}); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); - rewriter.replaceOpWithNewOp( - linearOp, this->getTypeConverter()->convertType(linearOp.getType()), - this->convertOpName(linearOp), arrayAttrs, nullptr, - adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getA()), + emitter.emit(srcOp.getB()), + emitter.emit(srcOp.getBias()), + emitter.emit(srcOp.getTransposeA()), + emitter.emit(srcOp.getTransposeB()), + /*memory_config=*/emitter.emit(std::nullopt), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -307,35 +289,23 @@ class MatmulOpConversionPattern tt::ttnn::MatmulOp>::TTNNToEmitCBaseOpConversionPattern; LogicalResult - matchAndRewrite(tt::ttnn::MatmulOp matmulOp, - tt::ttnn::MatmulOp::Adaptor adaptor, + matchAndRewrite(tt::ttnn::MatmulOp srcOp, tt::ttnn::MatmulOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); + // ANCHOR: adding_an_op_matmul_ttnn_to_emitc_array_attrs - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {rewriter.getIndexAttr(0), rewriter.getIndexAttr(1), - tt::ttnn_to_emitc::utils::convertBoolAttr( - rewriter, matmulOp.getTransposeAAttr()), - tt::ttnn_to_emitc::utils::convertBoolAttr( - rewriter, matmulOp.getTransposeBAttr()), - /*memory_config=*/tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*dtype=*/tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*program_config=*/ - tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*activation=*/tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*compute_kernel_config=*/ - ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*core_grid=*/ttnn_to_emitc::utils::createStdNullopt(rewriter), - /*output_tile=*/ttnn_to_emitc::utils::createStdNullopt(rewriter)}); + llvm::SmallVector args{ + emitter.emit(srcOp.getA()), + emitter.emit(srcOp.getB()), + emitter.emit(srcOp.getTransposeA()), + emitter.emit(srcOp.getTransposeB()), + /*memory_config=*/emitter.emit(std::nullopt), + }; // ANCHOR_END: adding_an_op_matmul_ttnn_to_emitc_array_attrs - rewriter.replaceOpWithNewOp( - matmulOp, this->getTypeConverter()->convertType(matmulOp.getType()), - this->convertOpName(matmulOp), arrayAttrs, nullptr, - adaptor.getOperands()); + emitter.replaceOp(*this, args); return success(); } @@ -353,22 +323,20 @@ class SoftmaxOpConversionPattern tt::ttnn::SoftmaxOp>::TTNNToEmitCBaseOpConversionPattern; LogicalResult - matchAndRewrite(tt::ttnn::SoftmaxOp softmaxOp, + matchAndRewrite(tt::ttnn::SoftmaxOp srcOp, tt::ttnn::SoftmaxOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr({ - mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - softmaxOp.getDimensionAttr(), - }); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); - rewriter.replaceOpWithNewOp( - softmaxOp, this->getTypeConverter()->convertType(softmaxOp.getType()), - this->convertOpName(softmaxOp), arrayAttrs, nullptr, - adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getDimension()), + /*memory_config=*/emitter.emit(std::nullopt), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -388,19 +356,15 @@ class EmbeddingOpConversionPattern tt::ttnn::EmbeddingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr({ - mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - mlir::IntegerAttr::get(rewriter.getIndexType(), 1), - }); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + embeddingOp, adaptor, rewriter); - rewriter.replaceOpWithNewOp( - embeddingOp, - this->getTypeConverter()->convertType(embeddingOp.getType()), - this->convertOpName(embeddingOp), arrayAttrs, nullptr, - adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(embeddingOp.getInput()), + emitter.emit(embeddingOp.getWeight()), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -420,21 +384,18 @@ class MorehCumSumOpConversionPattern tt::ttnn::MorehCumSumOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr({ - mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - srcOp.getDimAttr(), - tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - }); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getDim()), + /*output=*/emitter.emit(std::nullopt), + emitter.emit(srcOp.getMemoryConfig()), + /*compute_kernel_config=*/emitter.emit(std::nullopt), + }; + emitter.replaceOp(*this, args); return success(); } }; @@ -459,6 +420,7 @@ class MeanOpConversionPattern emitter.emit(srcOp.getInput()), emitter.emit<::ttnn::SmallVector>(srcOp.getDimArg()), emitter.emit(srcOp.getKeepDim()), + /*memory_config=*/emitter.emit(std::nullopt), }; emitter.replaceOp(*this, args); @@ -480,20 +442,17 @@ class ArgMaxOpConversionPattern matchAndRewrite(tt::ttnn::ArgMaxOp srcOp, tt::ttnn::ArgMaxOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr({ - rewriter.getIndexAttr(0), - srcOp.getDimAttr(), - tt::ttnn_to_emitc::utils::convertBoolAttr(rewriter, - srcOp.getUseMulticoreAttr()), - tt::ttnn_to_emitc::utils::createStdNullopt(rewriter), - }); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getDim()), + emitter.emit(srcOp.getUseMulticore()), + emitter.emit(srcOp.getMemoryConfig()), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -518,7 +477,8 @@ class ReshapeOpConversionPattern llvm::SmallVector args{ emitter.emit(srcOp.getInput()), emitter.emit>(srcOp.getShape()), - emitter.emit(srcOp.getMemoryConfig())}; + emitter.emit(srcOp.getMemoryConfig()), + }; emitter.replaceOp(*this, args); @@ -539,15 +499,17 @@ class TransposeOpConversionPattern matchAndRewrite(tt::ttnn::TransposeOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so - // an ArrayAttr object holding IndexTypes is created to denote this. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {rewriter.getIndexAttr(0), srcOp.getDim0Attr(), srcOp.getDim1Attr()}); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getDim0()), + emitter.emit(srcOp.getDim1()), + /*memory_config=*/emitter.emit(std::nullopt), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -566,47 +528,16 @@ class ConcatOpConversionPattern matchAndRewrite(tt::ttnn::ConcatOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // tt::ttnn::concat op requires a `std::vector<>` of `Tensor` objects, but - // we can't really create a `std::vector<>` with `Value` objects without - // introducing an EmitC op that takes in these `Value` objects. We do this - // by creating a utility function within the IR that converts a list of - // `Tensor` objects into a `std::vector`. - - tt::ttnn_to_emitc::utils::insertVecCreateFnIfNotExists(rewriter, srcOp); - - // TODO (azecevic): Investigate if this op is the special case that needs to - // use this fallback, or if it can be handled in a more general way with - // TTNNToEmitCEmitter. - mlir::emitc::CallOpaqueOp vectorOp = rewriter.create( - srcOp.getLoc(), - emitc::OpaqueType::get(rewriter.getContext(), - "std::vector"), - tt::ttnn_to_emitc::utils::kCreateVectorFunctionName, nullptr, nullptr, - adaptor.getInputs()); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); - // Create operands vector - // - llvm::SmallVector operands{ - vectorOp->getResult(0), // Input vector of tensors + llvm::SmallVector args{ + emitter.emit(srcOp.getInputs()), + emitter.emit(srcOp.getDim()), + /*memory_config=*/emitter.emit(std::nullopt), }; - ArrayAttr arrayAttrs = rewriter.getArrayAttr({ - mlir::IntegerAttr::get(rewriter.getIndexType(), - 0), // Input vector of tensors - srcOp.getDimAttr(), // Concat dimension - srcOp.getMemoryConfig() - ? (operands.append( - 1, ttnn_to_emitc::utils::createMemoryConfigOp( - rewriter, srcOp.getMemoryConfigAttr(), srcOp.getLoc()) - ->getResult(0)), - mlir::cast(rewriter.getIndexAttr(1))) - : ttnn_to_emitc::utils::createStdNullopt( - rewriter) // ttnn::MemoryConfig - }); - - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, operands); + emitter.replaceOp(*this, args); return success(); } @@ -621,16 +552,15 @@ class RepeatOpConversionPattern tt::ttnn::RepeatOp>::TTNNToEmitCBaseOpConversionPattern; LogicalResult - matchAndRewrite(tt::ttnn::RepeatOp repeatOp, - tt::ttnn::RepeatOp::Adaptor adaptor, + matchAndRewrite(tt::ttnn::RepeatOp srcOp, tt::ttnn::RepeatOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ttnn_to_emitc::EmitCTTNNEmitter emitter( - repeatOp, adaptor, rewriter); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); llvm::SmallVector args{ - emitter.emit(repeatOp.getInput()), - emitter.emit(repeatOp.getRepeatDims()), + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getRepeatDims()), }; emitter.replaceOp(*this, args); @@ -690,9 +620,10 @@ class GetDeviceOpConversionPattern matchAndRewrite(tt::ttnn::GetDeviceOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), nullptr, nullptr, adaptor.getOperands()); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); + + emitter.replaceOp(*this, {}); return success(); } @@ -713,43 +644,16 @@ class ToDeviceOpConversionPattern matchAndRewrite(tt::ttnn::ToDeviceOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - llvm::SmallVector attrs; - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0)); - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1)); - llvm::SmallVector operands(adaptor.getOperands()); - - if (srcOp.getMemoryConfig()) { - // Create ArrayAttr object holding MemoryConfig attributes. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {tt::ttnn_to_emitc::utils::convertTensorMemoryLayout( - rewriter, srcOp.getMemoryConfig()->getTensorMemoryLayout()), - tt::ttnn_to_emitc::utils::convertBufferType( - rewriter, srcOp.getMemoryConfig()->getBufferType())}); - - // Create MemoryConfig object first, then pass it to the op. - // - emitc::CallOpaqueOp memCfgOp = rewriter.create( - srcOp->getLoc(), - emitc::OpaqueType::get(rewriter.getContext(), "ttnn::MemoryConfig"), - "ttnn::MemoryConfig", arrayAttrs, nullptr, ValueRange()); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); - // Concat operands and MemoryConfig object. - // - operands.append(1, memCfgOp.getResult(0)); - - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 2)); - } else { - attrs.push_back(tt::ttnn_to_emitc::utils::createStdNullopt(rewriter)); - } - - ArrayAttr finalAttrs = ArrayAttr::get(srcOp->getContext(), attrs); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getDevice()), + emitter.emit(srcOp.getMemoryConfig()), + }; - // Convert ToDeviceOp - // - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), finalAttrs, nullptr, operands); + emitter.replaceOp(*this, args); return success(); } @@ -770,9 +674,14 @@ class FromDeviceOpConversionPattern matchAndRewrite(tt::ttnn::FromDeviceOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), nullptr, nullptr, adaptor.getOperands()); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); + + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -793,14 +702,16 @@ class TypecastOpConversionPattern matchAndRewrite(tt::ttnn::TypecastOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - tt::ttnn_to_emitc::utils::convertDType(rewriter, - srcOp.getDtypeAttr())}); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getDtype()), + /*memory_config=*/emitter.emit(std::nullopt), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -821,14 +732,15 @@ class ToDTypeOpConversionPattern matchAndRewrite(tt::ttnn::ToDTypeOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {mlir::IntegerAttr::get(rewriter.getIndexType(), 0), - tt::ttnn_to_emitc::utils::convertDType(rewriter, - srcOp.getDtypeAttr())}); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getDtype()), + }; + + emitter.replaceOp(*this, args); return success(); } @@ -848,36 +760,16 @@ class ToMemoryConfigOpConversionPattern LogicalResult matchAndRewrite(tt::ttnn::ToMemoryConfigOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Create ArrayAttr object holding MemoryConfig attributes. - // - ArrayAttr arrayAttrs = rewriter.getArrayAttr( - {tt::ttnn_to_emitc::utils::convertTensorMemoryLayout( - rewriter, srcOp.getMemoryConfig().getTensorMemoryLayout()), - tt::ttnn_to_emitc::utils::convertBufferType( - rewriter, srcOp.getMemoryConfig().getBufferType())}); - - // Create MemoryConfig object first, then pass it to the op. - // - emitc::CallOpaqueOp memCfgOp = rewriter.create( - srcOp->getLoc(), - emitc::OpaqueType::get(rewriter.getContext(), "ttnn::MemoryConfig"), - "ttnn::MemoryConfig", arrayAttrs, nullptr, ValueRange()); - - // Concat operands and MemoryConfig object. - // - llvm::SmallVector operands(adaptor.getOperands()); - operands.append(1, memCfgOp.getResult(0)); - llvm::SmallVector attrs; - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0)); - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1)); - attrs.push_back(tt::ttnn_to_emitc::utils::createStdNullopt(rewriter)); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); - ArrayAttr finalAttrs = ArrayAttr::get(srcOp->getContext(), attrs); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getMemoryConfig()), + }; - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), finalAttrs, nullptr, operands); + emitter.replaceOp(*this, args); return success(); } @@ -898,19 +790,16 @@ class ToLayoutOpConversionPattern matchAndRewrite(tt::ttnn::ToLayoutOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - llvm::SmallVector attrs; - attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0)); - attrs.push_back(tt::ttnn_to_emitc::utils::convertLayoutAttr( - rewriter, srcOp.getLayoutAttr())); - attrs.push_back(tt::ttnn_to_emitc::utils::createStdNullopt(rewriter)); - attrs.push_back(tt::ttnn_to_emitc::utils::createStdNullopt(rewriter)); - attrs.push_back(createNullDevicePointer(rewriter)); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); - ArrayAttr arrayAttrs = ArrayAttr::get(srcOp->getContext(), attrs); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), emitter.emit(srcOp.getLayout()), + emitter.emit(srcOp.getDtype()), emitter.emit(srcOp.getMemoryConfig()), + emitter.emit(srcOp.getDevice()) | + emitter.emit<::ttnn::IDevice>(nullptr)}; - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + emitter.replaceOp(*this, args); return success(); } @@ -960,73 +849,16 @@ class ZerosOpConversionPattern matchAndRewrite(tt::ttnn::ZerosOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // tt::ttnn:ZerosOp has 5 input params: - // - // let arguments = (ins TTNN_ShapeAttr:$shape, - // OptionalAttr:$dtype, - // OptionalAttr:$layout, - // Optional:$device, - // OptionalAttr:$memory_config); - // - // Some of them are Attrs, some are Values. ShapeAttr is required, while - // others are optional. Additionally, in the context of C++, some of the - // Attrs (like shape) need to be instantiated into objects before being - // passed to the op. Therefore: - // - // We first create a tt::ttnn::SimpleShape object (SSA) by calling - // createShapeOp() and add it to the operands vector, but also add an - // IndexAttr in ArrayAttr to reference it (this is an EmitC mechanism that - // allows for combining Attrs and Values when calling an OpaqueOp). All the - // other input params are optional, so we create them on-the-fly into the - // ArrayAttr, whether they are an actual Attr, or a Value pointed to by - // IndexAttr. If they are present, we create the object and pass it to the - // op. If not, we pass std::nullopt. - - // Create tt::ttnn::SimpleShape() call - // - emitc::CallOpaqueOp shapeOp = tt::ttnn_to_emitc::utils::createShapeOp( - rewriter, srcOp.getShapeAttr(), srcOp.getLoc()); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); - llvm::SmallVector operands{ - shapeOp->getResult(0), + llvm::SmallVector args{ + emitter.emit(srcOp.getShape()), emitter.emit(srcOp.getDtype()), + emitter.emit(srcOp.getLayout()), emitter.emit(srcOp.getDevice()), + emitter.emit(srcOp.getMemoryConfig()), }; - // Create ArrayAttr object holding attributes and pointers to operands - // - // Params that are Values are added to the operands vector on-the-fly, and - // a corresponding IndexAttr is added to the ArrayAttr to reference them. - // - size_t operandIndex = 0; - ArrayAttr arrayAttr = rewriter.getArrayAttr({ - rewriter.getIndexAttr(operandIndex++), // tt::ttnn::SimpleShape - srcOp.getDtype().has_value() - ? tt::ttnn_to_emitc::utils::convertDType(rewriter, - srcOp.getDtypeAttr()) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::DataType - srcOp.getLayout().has_value() - ? tt::ttnn_to_emitc::utils::convertLayoutAttr(rewriter, - srcOp.getLayoutAttr()) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::Layout - adaptor.getDevice() - ? (operands.append(1, adaptor.getDevice()), - mlir::cast(rewriter.getIndexAttr(operandIndex++))) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::Device - srcOp.getMemoryConfig().has_value() - ? (operands.append( - 1, tt::ttnn_to_emitc::utils::createMemoryConfigOp( - rewriter, srcOp.getMemoryConfigAttr(), srcOp.getLoc()) - ->getResult(0)), - mlir::cast(rewriter.getIndexAttr(operandIndex++))) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::MemoryConfig - }); - - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttr, nullptr, operands); + emitter.replaceOp(*this, args); return success(); } @@ -1046,73 +878,16 @@ class OnesOpConversionPattern matchAndRewrite(tt::ttnn::OnesOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // tt::ttnn:OnesOp has 5 input params: - // - // let arguments = (ins TTNN_ShapeAttr:$shape, - // OptionalAttr:$dtype, - // OptionalAttr:$layout, - // Optional:$device, - // OptionalAttr:$memory_config); - // - // Some of them are Attrs, some are Values. ShapeAttr is required, while - // others are optional. Additionally, in the context of C++, some of the - // Attrs (like shape) need to be instantiated into objects before being - // passed to the op. Therefore: - // - // We first create a tt::ttnn::Shape object (SA) by calling - // createShapeOp() and add it to the operands vector, but also add an - // IndexAttr in ArrayAttr to reference it (this is an EmitC mechanism that - // allows for combining Attrs and Values when calling an OpaqueOp). All the - // other input params are optional, so we create them on-the-fly into the - // ArrayAttr, whether they are an actual Attr, or a Value pointed to by - // IndexAttr. If they are present, we create the object and pass it to the - // op. If not, we pass std::nullopt. - - // Create tt::ttnn::Shape() call - // - emitc::CallOpaqueOp shapeOp = tt::ttnn_to_emitc::utils::createShapeOp( - rewriter, srcOp.getShapeAttr(), srcOp.getLoc()); + ttnn_to_emitc::EmitCTTNNEmitter emitter(srcOp, adaptor, + rewriter); - llvm::SmallVector operands{ - shapeOp->getResult(0), + llvm::SmallVector args{ + emitter.emit(srcOp.getShape()), emitter.emit(srcOp.getDtype()), + emitter.emit(srcOp.getLayout()), emitter.emit(srcOp.getDevice()), + emitter.emit(srcOp.getMemoryConfig()), }; - // Create ArrayAttr object holding attributes and pointers to operands - // - // Params that are Values are added to the operands vector on-the-fly, and - // a corresponding IndexAttr is added to the ArrayAttr to reference them. - // - size_t operandIndex = 0; - ArrayAttr arrayAttr = rewriter.getArrayAttr({ - rewriter.getIndexAttr(operandIndex++), // tt::ttnn::Shape - srcOp.getDtype().has_value() - ? tt::ttnn_to_emitc::utils::convertDType(rewriter, - srcOp.getDtypeAttr()) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::DataType - srcOp.getLayout().has_value() - ? tt::ttnn_to_emitc::utils::convertLayoutAttr(rewriter, - srcOp.getLayoutAttr()) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::Layout - adaptor.getDevice() - ? (operands.append(1, adaptor.getDevice()), - mlir::cast(rewriter.getIndexAttr(operandIndex++))) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::Device - srcOp.getMemoryConfig().has_value() - ? (operands.append( - 1, tt::ttnn_to_emitc::utils::createMemoryConfigOp( - rewriter, srcOp.getMemoryConfigAttr(), srcOp.getLoc()) - ->getResult(0)), - mlir::cast(rewriter.getIndexAttr(operandIndex++))) - : tt::ttnn_to_emitc::utils::createStdNullopt( - rewriter), // tt::ttnn::MemoryConfig - }); - - rewriter.replaceOpWithNewOp( - srcOp, this->getTypeConverter()->convertType(srcOp.getType()), - this->convertOpName(srcOp), arrayAttr, nullptr, operands); + emitter.replaceOp(*this, args); return success(); } @@ -1133,15 +908,15 @@ class DeallocateOpConversionPattern matchAndRewrite(tt::ttnn::DeallocateOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ArrayAttr arrayAttr = rewriter.getArrayAttr({ - rewriter.getIndexAttr(0), - tt::ttnn_to_emitc::utils::convertBoolAttr(rewriter, - srcOp.getForceAttr()), - }); + ttnn_to_emitc::EmitCTTNNEmitter emitter( + srcOp, adaptor, rewriter); - rewriter.replaceOpWithNewOp( - srcOp, srcOp->getResultTypes(), this->convertOpName(srcOp), arrayAttr, - nullptr, adaptor.getOperands()); + llvm::SmallVector args{ + emitter.emit(srcOp.getInput()), + emitter.emit(srcOp.getForce()), + }; + + emitter.replaceOp(*this, args); return success(); }