Skip to content

Commit

Permalink
TTNN->EmitC transition existing ops to new converter (#2345)
Browse files Browse the repository at this point in the history
### Ticket
#2343

### Problem description
Part of the effort to onboard all TTNN ops to conversion infrastructure
introduced in #2331

### What's changed
Transitioned ops with existing conversion to conversion with the new
converter. Few changes in the Emitter class, mainly in the way how
operands are treated, to accommodate more complex cases like `Variadic`
of tensors that maps to `std::vector<ttnn::Tensor>`.
  • Loading branch information
azecevicTT authored and odjuricicTT committed Mar 8, 2025
1 parent c4fd52d commit 017cf6d
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 440 deletions.
114 changes: 90 additions & 24 deletions include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef TTMLIR_CONVERSION_TTNNTOEMITC_EMITCCONVERSION_H
#define TTMLIR_CONVERSION_TTNNTOEMITC_EMITCCONVERSION_H

#include "ttmlir/Conversion/TTNNToEmitC/Utils.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

Expand All @@ -30,6 +31,10 @@ template <typename T>
struct SmallVector {
using value_type = T;
};

struct IDevice;

struct Tensor;
} // namespace ttnn

namespace mlir {
Expand Down Expand Up @@ -106,6 +111,16 @@ struct TypeName<::ttnn::SmallVector<T>> {
"::ttnn::SmallVector<" + TypeNameV<T> + ">";
};

template <>
struct TypeName<::ttnn::IDevice> {
inline static const std::string value = "::ttnn::IDevice";
};

template <>
struct TypeName<::ttnn::Tensor> {
inline static const std::string value = "::ttnn::Tensor";
};

template <typename T, typename Enable = void>
struct EmitCTypeConverter;

Expand Down Expand Up @@ -570,6 +585,15 @@ inline std::string convert(ttnn::MemoryConfigAttr attr) {
return buf;
}

template <typename T>
struct IsMLIRType {
static constexpr bool value = std::is_convertible_v<T, mlir::Attribute> ||
std::is_convertible_v<T, mlir::Value>;
};

template <typename T>
static constexpr bool IsMLIRTypeV = IsMLIRType<T>::value;

template <typename TTNNOp>
class EmitCTTNNEmitter {
public:
Expand Down Expand Up @@ -610,7 +634,7 @@ class EmitCTTNNEmitter {
template <typename TargetTy = void, typename SourceTy>
mlir::Attribute emit(std::optional<SourceTy> attr) {
if (!attr) {
return rewriter.getType<emitc::OpaqueAttr>(TypeNameV<std::nullopt_t>);
return emit(std::nullopt);
}

if constexpr (std::is_void_v<TargetTy>) {
Expand All @@ -624,20 +648,47 @@ class EmitCTTNNEmitter {
return rewriter.getType<emitc::OpaqueAttr>(TypeNameV<std::nullopt_t>);
}

mlir::Attribute emit(Value val) {
mlir::Attribute emit(mlir::Value val) {
if (!val) {
return rewriter.getType<emitc::OpaqueAttr>(TypeNameV<std::nullopt_t>);
return emit(std::nullopt);
}

auto operand =
llvm::find_if(op->getOpOperands(), [&](OpOperand &opOperand) {
return opOperand.get() == val;
});
if (operand == op->getOpOperands().end()) {
llvm_unreachable("Unknown operand");
mlir::OpOperand *opOperand =
std::find_if(std::next(op->getOpOperands().begin(), operands.size()),
op->getOpOperands().end(),
[&](OpOperand &operand) { return operand.get() == val; });

unsigned index = opOperand->getOperandNumber();
operands.push_back(adaptor.getOperands()[index]);
return rewriter.getIndexAttr(index);
}

mlir::Attribute emit(mlir::Operation::operand_range operands) {
for (mlir::OpOperand &opOperand : op->getOpOperands()) {
auto begin =
std::next(op->getOperands().begin(), opOperand.getOperandNumber());
if (mlir::Operation::operand_range(
begin, std::next(begin, operands.size())) != operands) {
continue;
}
unsigned index = opOperand.getOperandNumber();
llvm::SmallVector<mlir::Value> values(
adaptor.getOperands().begin() + index,
adaptor.getOperands().begin() + index + operands.size());
this->operands.push_back(createVector(values));
return rewriter.getIndexAttr(index);
}
llvm_unreachable("Invalid operand range");
}

return rewriter.getIndexAttr(operand->getOperandNumber());
template <typename TargetTy = void>
mlir::Attribute emit(std::nullptr_t) {
if constexpr (std::is_void_v<TargetTy>) {
return rewriter.getType<emitc::OpaqueAttr>("nullptr");
} else {
return rewriter.getType<emitc::OpaqueAttr>(
"static_cast<" + TypeNameV<TargetTy> + " *>(nullptr)");
}
}

// Handles the case when source type is convertible to mlir::Attribute type
Expand All @@ -651,7 +702,7 @@ class EmitCTTNNEmitter {
if (auto convertedValue = EmitCTypeConverter<TargetTy>::convert(attr)) {
return rewriter.getType<emitc::OpaqueAttr>(*convertedValue);
}
return rewriter.getType<emitc::OpaqueAttr>(TypeNameV<std::nullopt_t>);
return emit(std::nullopt);
}

// Handles the case when source type is a non mlir::Attribute convertible type
Expand All @@ -661,10 +712,9 @@ class EmitCTTNNEmitter {
// appropriate C++ type.
// TODO (azecevic): See if we can simplify the condition for this overload
// instantiation.
template <typename SourceTy, typename TargetTy = SourceTy>
std::enable_if_t<!std::is_convertible_v<SourceTy, mlir::Attribute> &&
!std::is_convertible_v<SourceTy, mlir::Value>,
mlir::Attribute>
template <typename SourceTy, typename TargetTy = std::remove_reference_t<
std::remove_cv_t<SourceTy>>>
std::enable_if_t<!IsMLIRTypeV<SourceTy>, mlir::Attribute>
emit(SourceTy &&attr) {
auto result =
EmitCTypeConverter<TargetTy>::convert(std::forward<SourceTy>(attr));
Expand All @@ -681,35 +731,51 @@ class EmitCTTNNEmitter {
template <typename OpConversionPatternTy>
emitc::CallOpaqueOp replaceOp(OpConversionPatternTy &&opConversionPattern,
llvm::ArrayRef<mlir::Attribute> args) {
auto resultTypes = llvm::to_vector(
llvm::map_range(op->getResultTypes(), [&](Type type) -> Type {
return opConversionPattern.getTypeConverter()->convertType(type);
}));
return rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
op,
opConversionPattern.getTypeConverter()->convertType(
op->getResult(0).getType()),
opConversionPattern.convertOpName(op), rewriter.getArrayAttr(args),
nullptr, adaptor.getOperands());
op, resultTypes, opConversionPattern.convertOpName(op),
rewriter.getArrayAttr(args), nullptr, operands);
}

private:
mlir::Value createVector(ValueRange operands) {
tt::ttnn_to_emitc::utils::insertVecCreateFnIfNotExists(rewriter, op);

return rewriter
.create<emitc::CallOpaqueOp>(
op.getLoc(),
emitc::OpaqueType::get(rewriter.getContext(),
TypeNameV<std::vector<::ttnn::Tensor>>),
tt::ttnn_to_emitc::utils::kCreateVectorFunctionName, nullptr,
nullptr, operands)
->getResult(0);
}

TTNNOp op;
OpAdaptor adaptor;
ConversionPatternRewriter &rewriter;
llvm::SmallVector<mlir::Value> operands;
};

} // namespace ttnn_to_emitc
} // namespace tt

// Helper function that serves as an alternative to the
// `emit<std::variant<...>>` member function of the `EmitCTTNNEmitter` class.
// For example, instead of calling `emit<std::variant<int32_t, float>>(attr)`,
// one can call `emit<int32_t>(attr) | emit<float>(attr)`.
inline mlir::Attribute operator|(mlir::Attribute lhs, mlir::Attribute rhs) {
static const mlir::Attribute nulloptAttr =
emitc::OpaqueAttr::get(lhs.getContext(), TypeNameV<std::nullopt_t>);
static const mlir::Attribute nulloptAttr = emitc::OpaqueAttr::get(
lhs.getContext(), tt::ttnn_to_emitc::TypeNameV<std::nullopt_t>);
if (!lhs || lhs == nulloptAttr) {
return rhs;
}
return lhs;
}

} // namespace ttnn_to_emitc
} // namespace tt
} // namespace mlir

#endif // TTMLIR_CONVERSION_TTNNTOEMITC_EMITCCONVERSION_H
Loading

0 comments on commit 017cf6d

Please sign in to comment.