Add constraints and runtime API to transpose (#2322)

### Ticket
#2314 

### Problem description
The optimizer needs more ops with constraints and runtime support to be
able to ingest real models. Transpose is useful for both ResNet (#2277)
and Llama (#2084).

### What's changed
Added constraints and runtime APIs to `TransposeOp`, and added unit
tests to exercise the new API.
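
As a rough sketch (not code from this commit), a pass could now query the op like this, where `transpose` is an existing `mlir::tt::ttnn::TransposeOp` and `inLayout`/`outLayout` are candidate `TTNNLayoutAttr` values under evaluation:

```cpp
// Sketch only: `transpose`, `inLayout`, `outLayout`, and `useCandidate`
// are hypothetical names, not part of this commit.
std::vector<mlir::tt::ttnn::TTNNLayoutAttr> inputs = {inLayout};

// Returns (cb_size, peak_size, output_size), in the unit tests' naming.
auto constraintsExp = transpose.getOpConstraints(inputs, outLayout);
if (!constraintsExp) {
  // This (input, output) layout combination was rejected.
  llvm::consumeError(constraintsExp.takeError());
  return;
}
auto [cbSize, peakSize, outputSize] = constraintsExp.get();

auto runtimeExp = transpose.getOpRuntime(inputs, outLayout);
if (runtimeExp) {
  // Score this layout candidate from its L1 footprint and runtime.
  useCandidate(cbSize, peakSize, outputSize, runtimeExp.get());
} else {
  llvm::consumeError(runtimeExp.takeError());
}
```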

Closes #2314 

### Checklist
- [X] New/Existing tests provide coverage for changes
arminaleTT authored Mar 4, 2025
1 parent de45959 commit be0809e
Showing 7 changed files with 210 additions and 14 deletions.
4 changes: 3 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -889,7 +889,9 @@ def TTNN_SoftmaxOp : TTNN_Op<"softmax",
let hasVerifier = 1;
}

def TTNN_TransposeOp : TTNN_Op<"transpose"> {
def TTNN_TransposeOp : TTNN_Op<"transpose",
[DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
> {
let summary = "Transpose op.";
let description = [{
Transpose tensor along two given dimensions.
17 changes: 17 additions & 0 deletions include/ttmlir/OpModel/TTNN/TTNNOpModel.h
@@ -120,6 +120,23 @@ getOpRuntime(llvm::ArrayRef<int64_t> inputShape,

}; // namespace ReshapeOpInterface

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

namespace TransposeOpInterface {
llvm::Expected<std::tuple<size_t, size_t, size_t>>
getOpConstraints(llvm::ArrayRef<int64_t> inputShape,
mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0,
const int dim1, mlir::tt::ttnn::TTNNLayoutAttr outputLayout);

llvm::Expected<size_t>
getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0,
const int dim1, mlir::tt::ttnn::TTNNLayoutAttr outputLayout);

}; // namespace TransposeOpInterface
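
For reference, a minimal sketch of driving these free functions directly, mirroring the unit tests later in this diff; `dramLayout` and `l1Layout` stand in for prebuilt `TTNNLayoutAttr` values, and the field names follow the tests (`cb_size`, `peak_size`, `output_size`):

```cpp
// Sketch under the stated assumptions; see TestOpModelLib.cpp below for
// the real usage. Queries a transpose of dims 0 and 1 on a 64x1024 tensor.
llvm::SmallVector<int64_t> shape = {64, 1024};

auto constraintsExp = op_model::ttnn::TransposeOpInterface::getOpConstraints(
    shape, dramLayout, /*dim0=*/0, /*dim1=*/1, l1Layout);
if (constraintsExp) {
  auto [cbSize, peakSize, outputSize] = constraintsExp.get();
  (void)cbSize; (void)peakSize; (void)outputSize;
} else {
  // The layout combination is illegal or the query failed.
  llvm::consumeError(constraintsExp.takeError());
}

auto runtimeExp = op_model::ttnn::TransposeOpInterface::getOpRuntime(
    shape, dramLayout, /*dim0=*/0, /*dim1=*/1, l1Layout);
if (runtimeExp) {
  size_t runtime = runtimeExp.get();
  (void)runtime;
} else {
  llvm::consumeError(runtimeExp.takeError());
}
```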

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
49 changes: 37 additions & 12 deletions lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp
@@ -135,8 +135,7 @@ SoftmaxOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape =
mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
const auto inputShape = getInput().getType().getShape();

const auto outputShape =
mlir::cast<RankedTensorType>(getResult().getType()).getShape();
@@ -155,8 +154,7 @@ SoftmaxOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape =
mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
const auto inputShape = getInput().getType().getShape();

const auto outputShape =
mlir::cast<RankedTensorType>(getResult().getType()).getShape();
@@ -174,8 +172,7 @@ MeanOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape =
mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
const auto inputShape = getInput().getType().getShape();

llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
if (!check) {
@@ -192,8 +189,7 @@ MeanOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape =
mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
const auto inputShape = getInput().getType().getShape();

return op_model::ttnn::MeanOpInterface::getOpRuntime(
inputShape, inputs[0], detail::convertReductionArg(getDimArg()),
@@ -209,8 +205,7 @@ ReshapeOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape =
mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
const auto inputShape = getInput().getType().getShape();

const auto outputShape =
mlir::cast<RankedTensorType>(getResult().getType()).getShape();
@@ -229,15 +224,45 @@ ReshapeOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape =
mlir::cast<RankedTensorType>(getOperand().getType()).getShape();
const auto inputShape = getInput().getType().getShape();
const auto outputShape =
mlir::cast<RankedTensorType>(getResult().getType()).getShape();

return op_model::ttnn::ReshapeOpInterface::getOpRuntime(inputShape, inputs[0],
outputShape, output);
}

//===----------------------------------------------------------------------===//
// TransposeOp - TTNN Op Model Interface
//===----------------------------------------------------------------------===//

llvm::Expected<std::tuple<size_t, size_t, size_t>>
TransposeOp::getOpConstraints(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape = getInput().getType().getShape();

llvm::Expected<bool> check = detail::checkDeviceWorkerGrid(getOperation());
if (!check) {
return check.takeError();
}

return op_model::ttnn::TransposeOpInterface::getOpConstraints(
inputShape, inputs[0], getDim0(), getDim1(), output);
}

llvm::Expected<size_t>
TransposeOp::getOpRuntime(const std::vector<TTNNLayoutAttr> &inputs,
const TTNNLayoutAttr &output) {
assert(inputs.size() == 1);

const auto inputShape = getInput().getType().getShape();

return op_model::ttnn::TransposeOpInterface::getOpRuntime(
inputShape, inputs[0], getDim0(), getDim1(), output);
}

//===----------------------------------------------------------------------===//
// MatmulOp - TTNN Op Model Interface
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions lib/OpModel/TTNN/MetalHeaders.h
@@ -62,6 +62,7 @@
#include "ttnn/graph/graph_query_op_runtime.hpp"
#include "ttnn/graph/graph_trace_utils.hpp"
#include "ttnn/operations/data_movement/reshape_view/reshape.hpp"
#include "ttnn/operations/data_movement/transpose/transpose.hpp"
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
65 changes: 65 additions & 0 deletions lib/OpModel/TTNN/TTNNOpModelLib.cpp
@@ -566,6 +566,71 @@ ReshapeOpInterface::getOpRuntime(llvm::ArrayRef<int64_t> inputShape,
#endif // TTMLIR_ENABLE_OPMODEL
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
llvm::Expected<std::tuple<size_t, size_t, size_t>>
TransposeOpInterface::getOpConstraints(
llvm::ArrayRef<int64_t> inputShape,
mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0, const int dim1,
mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
#ifdef TTMLIR_ENABLE_OPMODEL
auto transposeOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
const int dim0, const int dim1,
mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
// open device; will close it at the end of the function
::tt::tt_metal::v0::IDevice *device =
SingletonDeviceContext::getInstance().getDevice();

// prepare io specs
const auto [inputSpec] = detail::convertToTensorSpec(
device, std::make_tuple(inputShape, inputLayout));

// run op constraint query
return ::ttnn::graph::query_op_constraints(
::ttnn::transpose, device, inputSpec, dim0, dim1,
conversion::getMemoryConfig(outputLayout));
};

return operation::getOpConstraints("TransposeOpInterface", transposeOpQuery,
inputShape, inputLayout, dim0, dim1,
outputLayout);
#else
return std::make_tuple(0, 0, 0);
#endif // TTMLIR_ENABLE_OPMODEL
}

llvm::Expected<size_t> TransposeOpInterface::getOpRuntime(
llvm::ArrayRef<int64_t> inputShape,
mlir::tt::ttnn::TTNNLayoutAttr inputLayout, const int dim0, const int dim1,
mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
#ifdef TTMLIR_ENABLE_OPMODEL
auto transposeOpQuery = [](llvm::ArrayRef<int64_t> inputShape,
mlir::tt::ttnn::TTNNLayoutAttr inputLayout,
const int dim0, const int dim1,
mlir::tt::ttnn::TTNNLayoutAttr outputLayout) {
// open device; will close it at the end of the function
::tt::tt_metal::v0::IDevice *device =
SingletonDeviceContext::getInstance().getDevice();

// prepare io specs
const auto [inputSpec] = detail::convertToTensorSpec(
device, std::make_tuple(inputShape, inputLayout));

return ::ttnn::graph::query_op_runtime(
::ttnn::transpose, device, inputSpec, dim0, dim1,
conversion::getMemoryConfig(outputLayout));
};

return operation::getOpRuntime("TransposeOpInterface", transposeOpQuery,
inputShape, inputLayout, dim0, dim1,
outputLayout);
#else
return llvm::createStringError("Not Implemented");
#endif // TTMLIR_ENABLE_OPMODEL
}

//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//
53 changes: 53 additions & 0 deletions test/unittests/OpModel/TTNN/Lib/TestOpModelLib.cpp
@@ -324,6 +324,59 @@ TEST_F(OpModelTest, Reshape) {
EXPECT_TRUE(runtimeExp.get() > 0);
}

TEST_F(OpModelTest, Transpose) {
const llvm::SmallVector<int64_t> tensorShape = {workerCoresN300, 1024};
const auto workerGrid = CreateWorkerGrid(gridShapeHwN300);
const mlir::tt::ttnn::TTNNLayoutAttr layoutDRAM =
CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::DRAM,
mlir::tt::ttnn::TensorMemoryLayout::Interleaved);
const mlir::tt::ttnn::TTNNLayoutAttr layoutL1Interleaved =
CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1,
mlir::tt::ttnn::TensorMemoryLayout::Interleaved);
const mlir::tt::ttnn::TTNNLayoutAttr layoutL1WSharded =
CreateTiledLayout(tensorShape, mlir::tt::ttnn::BufferType::L1,
mlir::tt::ttnn::TensorMemoryLayout::WidthSharded);

auto legalExp = Device::getDeviceConstraints(workerGrid);
EXPECT_TRUE(static_cast<bool>(legalExp));

auto constraintsExp = TransposeOpInterface::getOpConstraints(
tensorShape, layoutDRAM, 0, 1, layoutDRAM);
EXPECT_TRUE(static_cast<bool>(constraintsExp));
auto [cb_size, peak_size, output_size] = constraintsExp.get();
EXPECT_EQ(cb_size, 8192);
EXPECT_EQ(output_size, 0);
EXPECT_EQ(peak_size, 0);

auto runtimeExp = TransposeOpInterface::getOpRuntime(tensorShape, layoutDRAM,
0, 1, layoutDRAM);
EXPECT_TRUE(static_cast<bool>(runtimeExp));
EXPECT_TRUE(runtimeExp.get() > 0);

constraintsExp = TransposeOpInterface::getOpConstraints(
tensorShape, layoutDRAM, 0, 1, layoutL1Interleaved);
EXPECT_TRUE(static_cast<bool>(constraintsExp));
std::tie(cb_size, peak_size, output_size) = constraintsExp.get();
EXPECT_EQ(cb_size, 8192);
EXPECT_EQ(output_size, 2048);
EXPECT_EQ(peak_size, 2048);

runtimeExp = TransposeOpInterface::getOpRuntime(tensorShape, layoutDRAM, 0, 1,
layoutL1Interleaved);
EXPECT_TRUE(static_cast<bool>(runtimeExp));
EXPECT_TRUE(runtimeExp.get() > 0);

constraintsExp = TransposeOpInterface::getOpConstraints(
tensorShape, layoutL1Interleaved, 0, 1, layoutL1WSharded);
EXPECT_TRUE(!static_cast<bool>(constraintsExp));
llvm::consumeError(constraintsExp.takeError());

runtimeExp = TransposeOpInterface::getOpRuntime(
tensorShape, layoutL1Interleaved, 0, 1, layoutL1WSharded);
EXPECT_TRUE(!static_cast<bool>(runtimeExp));
llvm::consumeError(runtimeExp.takeError());
}

TEST_F(OpModelTest, SoftmaxSharded) {
const llvm::SmallVector<int64_t> tensorShape = {16 * workerCoresN300 * 32,
32};
35 changes: 34 additions & 1 deletion test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp
@@ -271,7 +271,7 @@ TEST_F(OpModelBase, reshapeOp) {
reshape.setShapeAttr(builder.getArrayAttr(llvm::SmallVector<mlir::Attribute>{
builder.getI64IntegerAttr(64 * 4), builder.getI64IntegerAttr(1024 / 4)}));

// test mean Op interface
// test reshape Op interface
auto constraintsExp = getOpConstraints(reshape.getOperation());
if (constraintsExp) {
auto l1 = constraintsExp.get();
Expand All @@ -292,4 +292,37 @@ TEST_F(OpModelBase, reshapeOp) {
}
}

TEST_F(OpModelBase, transposeOp) {
// create TransposeOp
llvm::SmallVector<int64_t> tensorShapeA = {64, 1024};
llvm::SmallVector<int64_t> tensorShapeO = {1024, 64};

auto input = createEmptyTensor(tensorShapeA);
auto output = createEmptyTensor(tensorShapeO);

auto transpose = builder.create<TransposeOp>(builder.getUnknownLoc(),
output.getType(), input, 0, 1);
transpose->setAttr(DeviceAttr::name, getFakeDeviceAttr());

// test transpose Op interface
auto constraintsExp = getOpConstraints(transpose.getOperation());
if (constraintsExp) {
auto l1 = constraintsExp.get();
const auto &[cb_size, peak_size, output_size] = l1;
EXPECT_EQ(cb_size, 8192);
EXPECT_EQ(peak_size, 2048);
EXPECT_EQ(output_size, 2048);
} else {
FAIL() << "Missing L1 constraints; Error="
<< llvm::toString(constraintsExp.takeError()) << std::endl;
}

auto runtimeExp = getOpRuntime(transpose.getOperation());
if (runtimeExp) {
EXPECT_TRUE(runtimeExp.get() > 0);
} else {
FAIL() << llvm::toString(runtimeExp.takeError());
}
}

} // namespace mlir::tt::ttnn
