Skip to content

Commit

Permalink
move getDpsOutputs() helper into tt::ttir and a dialect-specific Utils.h
Browse files Browse the repository at this point in the history
  • Loading branch information
vroubtsovTT committed Mar 5, 2025
1 parent ffdd627 commit 46612ac
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 70 deletions.
23 changes: 11 additions & 12 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h"
#include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"

#include "ttmlir/Utils.h"

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "ttmlir/Dialect/TTIR/IR/Utils.h"

#include <mlir/Bytecode/BytecodeOpInterface.h>
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/DestinationStyleOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>

namespace mlir::tt::ttir {

Expand Down
12 changes: 6 additions & 6 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TTIR_DPSOp<string mnemonic, list<Trait> traits = []> :
TTIR_Op<mnemonic, !listconcat([TTIROpInterface, DestinationStyleOpInterface], traits)> {
let extraClassDeclaration = [{
// base implementation that will detect getOutputMutable() or getOutputsMutable()
MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
}];
}

Expand Down Expand Up @@ -112,7 +112,7 @@ def TTIR_GenericOp : TTIR_BufferizableOp<"generic", [AttrSizedOperandSegments, N
let hasVerifier = 1;

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
}];
}

Expand All @@ -138,7 +138,7 @@ def TTIR_ToLayoutOp : TTIR_BufferizableOp<"to_layout"> {
let results = (outs Variadic<AnyRankedTensor>:$results);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }

struct CompoundComponents {
bool isLayoutChange = false;
Expand Down Expand Up @@ -781,7 +781,7 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

// Returns the indexing maps and iterator types for the reduction op.
Expand Down Expand Up @@ -1826,7 +1826,7 @@ class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseUnaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
Expand Down Expand Up @@ -1856,7 +1856,7 @@ class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []>
TTIR_ElementwiseBinaryOp<mnemonic, !listconcat([TTIR_GenericRegionOpInterface], traits)> {

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
Expand Down
59 changes: 59 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/Utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-FileCopyrightText: (c) 20245 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTIR_IR_UTILS_H
#define TTMLIR_DIALECT_TTIR_IR_UTILS_H

#include <mlir/IR/ValueRange.h>

#include <type_traits>

namespace mlir::tt::ttir {

// detect the presence of 'getOutputsMutable()' in 'Op':
template <typename Op, typename = void>
inline constexpr bool has_variadic_outputs = false;

template <typename Op>
inline constexpr bool has_variadic_outputs<
Op, std::void_t<decltype(std::declval<Op>().getOutputsMutable())>> = true;

namespace impl {

template <typename Op, typename = void>
struct getDpsOutputs {
static mlir::MutableOperandRange evaluate(Op *op) {
return op->getOutputMutable();
}
};

template <typename Op>
struct getDpsOutputs<Op, std::enable_if_t<has_variadic_outputs<Op>>> {
static mlir::MutableOperandRange evaluate(Op *op) {
return op->getOutputsMutable();
}
};

} // namespace impl

// A helper for simplifying DPS tablegen derivations with 'arguments' of any
// form in {AnyRankedTensor:$output, Variadic<AnyRankedTensor>:$outputs}.
//
// If a base tablegen 'class' adds this extra class declaration, derived 'def's
// don't need to overrride it just to switch from single to variadic type of
// '$outputs' (or vice versa):
// ...
// clang-format off
// let extraClassDeclaration = [{
// MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
// }]
// clang-format on
template <typename Op>
mlir::MutableOperandRange getDpsOutputs(Op *op) {
return impl::getDpsOutputs<Op>::evaluate(op);
}

} // namespace mlir::tt::ttir

#endif // TTMLIR_DIALECT_TTIR_IR_UTILS_H
62 changes: 10 additions & 52 deletions include/ttmlir/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
#ifndef TTMLIR_UTILS_H
#define TTMLIR_UTILS_H

#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Error.h"
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringExtras.h>
#include <llvm/Support/Error.h>
#include <mlir-c/IR.h>
#include <mlir/CAPI/IR.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/IR/AffineMap.h>
#include <mlir/IR/BuiltinAttributes.h>

#include <cstdint>
#include <type_traits>

namespace ttmlir::utils {

template <typename T>
T alignUp(T ptr, T alignment) {
return (ptr + alignment - 1) & ~(alignment - 1);
Expand Down Expand Up @@ -459,49 +460,6 @@ OpTy replaceOpWithNewDPSOp(mlir::PatternRewriter &rewriter, mlir::Operation *op,
return newOp;
}

// detect the presence of 'getOutputsMutable()' in 'Op':
template <typename Op, typename = void>
inline constexpr bool has_variadic_outputs = false;

template <typename Op>
inline constexpr bool has_variadic_outputs<
Op, std::void_t<decltype(std::declval<Op>().getOutputsMutable())>> = true;

namespace impl {

template <typename Op, typename _ = void>
struct getDpsOutputs {
static mlir::MutableOperandRange evaluate(Op *op) {
return op->getOutputMutable();
}
};

template <typename Op>
struct getDpsOutputs<Op, std::enable_if_t<has_variadic_outputs<Op>>> {
static mlir::MutableOperandRange evaluate(Op *op) {
return op->getOutputsMutable();
}
};

} // namespace impl

// A helper for simplifying DPS tablegen derivations with 'arguments' of any
// form in {AnyRankedTensor:$output, Variadic<AnyRankedTensor>:$outputs}.
//
// If a base tablegen 'class' adds this extra class declaration, derived 'def's
// don't need to overrride it just to switch from single to variadic type of
// '$outputs' (or vice versa):
// ...
// clang-format off
// let extraClassDeclaration = [{
// MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
// }]
// clang-format on
template <typename Op>
mlir::MutableOperandRange getDpsOutputs(Op *op) {
return impl::getDpsOutputs<Op>::evaluate(op);
}

} // namespace ttmlir::utils

#endif
#endif // TTMLIR_UTILS_H

0 comments on commit 46612ac

Please sign in to comment.