Adding memory config to a reshape op (#2275)
### Ticket
The following GH issue contains all the ops that require memory_config:
#1637

The reshape op is one of the ops on that list.

### Problem description
The TTNN reshape op requires a memory config attribute.

### What's changed
I added the memory config trait to the reshape op and implemented the memory
config support end-to-end, from the compiler through flatbuffer serialization
to the runtime.

### Checklist
- [x] New/Existing tests provide coverage for changes
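
Concretely, every call site that builds a `ttnn::ReshapeOp` now passes one extra
trailing argument. A minimal sketch of the updated builder call, assuming a
`rewriter`, an `input` value, and an already-computed `outputType`/`newShapeI32`
as in `lib/Conversion/TTIRToTTNN/Utils.cpp` below:

```cpp
// Sketch of the updated builder call. Pass a TTNN_MemoryConfigAttr to pin
// the output memory layout, or nullptr to leave it unset; the runtime then
// falls back to the memory config of the output tensor.
ttnn::ReshapeOp reshapeOp = rewriter.create<ttnn::ReshapeOp>(
    input.getLoc(), outputType, input,
    rewriter.getI32ArrayAttr(newShapeI32),
    /* memory_config */ nullptr);
```
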
sdjordjevicTT authored Mar 3, 2025
1 parent 7cb48a2 commit 071ce6f
Showing 18 changed files with 270 additions and 95 deletions.
7 changes: 4 additions & 3 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -953,16 +953,17 @@ def TTNN_ConcatOp : TTNN_Op<"concat", [HasMemoryConfigTrait]> {
let hasVerifier = 1;
}

-def TTNN_ReshapeOp : TTNN_Op<"reshape",
-    [DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
+def TTNN_ReshapeOp : TTNN_Op<"reshape", [HasMemoryConfigTrait,
+    DeclareOpInterfaceMethods<TTNN_OpModelInterface, ["getOpConstraints", "getOpRuntime"]>]
> {
let summary = "Reshape op.";
let description = [{
Reshape tensor.
}];

let arguments = (ins AnyRankedTensor:$input,
-                       I32ArrayAttr:$shape);
+                       I32ArrayAttr:$shape,
+                       OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);

let results = (outs AnyRankedTensor:$result);

@@ -95,7 +95,7 @@ class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> {
llvm::SmallVector<int32_t>(outputType.getShape()));

rewriter.replaceOpWithNewOp<mlir::tt::ttnn::ReshapeOp>(
-        srcOp, outputType, newReduceOp, shapeAttr);
+        srcOp, outputType, newReduceOp, shapeAttr, /* memory_config */ nullptr);
}

// Determine if the workaround is required.
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -267,6 +267,7 @@ table ReshapeOp {
in: tt.target.ttnn.TensorRef;
out: tt.target.ttnn.TensorRef;
shape: [int32];
+  memory_config: tt.target.ttnn.MemoryConfig;
}

table RepeatOp {
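
A note on the schema change: `memory_config` is a non-scalar flatbuffer table
field, so it is optional by construction — when the serializer leaves it unset
(offset 0), the generated accessor returns a null pointer. The runtime hunks
below rely on exactly this. A minimal read-side sketch, assuming the generated
target headers:

```cpp
// Sketch (not commit code): distinguishing "memory_config was serialized"
// from "field left unset" on the runtime side.
bool hasExplicitMemoryConfig(const ::tt::target::ttnn::ReshapeOp *op) {
  // Unset non-scalar flatbuffer fields read back as nullptr.
  return op->memory_config() != nullptr;
}
```
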
6 changes: 3 additions & 3 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -671,7 +671,7 @@ class ReshapeOpConversionPattern : public OpConversionPattern<ttir::ReshapeOp> {
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getShape());
+        adaptor.getInput(), adaptor.getShape(), /* memory_config */ nullptr);
return success();
}
};
@@ -731,7 +731,7 @@ class SqueezeOpConversionPattern : public OpConversionPattern<ttir::SqueezeOp> {
// Replace the SqueezeOp with a ReshapeOp
rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), shapeAttr);
+        adaptor.getInput(), shapeAttr, /* memory_config */ nullptr);

return success();
}
@@ -854,7 +854,7 @@ class UnsqueezeOpConversionPattern
// Replace the UnsqueezeOp with a ReshapeOp
rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), shapeAttr);
+        adaptor.getInput(), shapeAttr, /* memory_config */ nullptr);

return success();
}
5 changes: 3 additions & 2 deletions lib/Conversion/TTIRToTTNN/Utils.cpp
@@ -27,8 +27,9 @@ ttnn::ReshapeOp generateReshape(mlir::TypedValue<mlir::RankedTensorType> input,
newShape, inputType.getElementType(), outputLayoutAttr);

llvm::SmallVector<int32_t> newShapeI32(newShape.begin(), newShape.end());
-  return rewriter.create<ttnn::ReshapeOp>(
-      input.getLoc(), outputType, input, rewriter.getI32ArrayAttr(newShapeI32));
+  return rewriter.create<ttnn::ReshapeOp>(input.getLoc(), outputType, input,
+                                          rewriter.getI32ArrayAttr(newShapeI32),
+                                          /* memory_config */ nullptr);
}

ttnn::ReshapeOp
27 changes: 21 additions & 6 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -519,7 +519,7 @@ class ReshapeOpConversionPattern
llvm::SmallVector<mlir::Attribute> args{
emitter.emit(srcOp.getInput()),
emitter.emit<std::vector<int32_t>>(srcOp.getShape()),
-    };
+        emitter.emit(srcOp.getMemoryConfig())};

emitter.replaceOp(*this, args);

@@ -585,14 +585,29 @@ class ConcatOpConversionPattern
tt::ttnn_to_emitc::utils::kCreateVectorFunctionName, nullptr, nullptr,
adaptor.getInputs());

-    ArrayAttr arrayAttrs = rewriter.getArrayAttr(
-        {mlir::IntegerAttr::get(rewriter.getIndexType(), 0),
-         srcOp.getDimAttr()});
+    // Create operands vector
+    //
+    llvm::SmallVector<Value, 2> operands{
+        vectorOp->getResult(0), // Input vector of tensors
+    };
+
+    ArrayAttr arrayAttrs = rewriter.getArrayAttr({
+        mlir::IntegerAttr::get(rewriter.getIndexType(),
+                               0), // Input vector of tensors
+        srcOp.getDimAttr(),        // Concat dimension
+        srcOp.getMemoryConfig()
+            ? (operands.append(
+                   1, ttnn_to_emitc::utils::createMemoryConfigOp(
+                          rewriter, srcOp.getMemoryConfigAttr(), srcOp.getLoc())
+                          ->getResult(0)),
+               mlir::cast<Attribute>(rewriter.getIndexAttr(1)))
+            : ttnn_to_emitc::utils::createStdNullopt(
+                  rewriter) // ttnn::MemoryConfig
+    });

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
srcOp, this->getTypeConverter()->convertType(srcOp.getType()),
-        this->convertOpName(srcOp), arrayAttrs, nullptr,
-        ValueRange(vectorOp->getResults()));
+        this->convertOpName(srcOp), arrayAttrs, nullptr, operands);

return success();
}
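
The ternary inside `getArrayAttr` above is dense: when a memory config is
present, the true branch uses the comma operator to append the materialized
`createMemoryConfigOp` result as a second operand and then yields the index
attribute `1` that points at it; otherwise the false branch yields an inline
`createStdNullopt` expression. A standalone sketch of that idiom (plain C++,
no MLIR dependencies):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Illustration of the ternary-plus-comma-operator idiom used above: the
// true branch performs a side effect (appending an operand) before yielding
// the value of the whole conditional expression.
int main() {
  std::vector<std::string> operands{"input_vector"};
  bool hasMemoryConfig = true;

  std::string arg =
      hasMemoryConfig
          ? (operands.push_back("memory_config_op"), // side effect: append
             std::string("operand index 1"))         // value: refer to it
          : std::string("std::nullopt");             // value: emit nullopt

  std::cout << arg << ", operand count = " << operands.size() << "\n";
  return 0; // prints: operand index 1, operand count = 2
}
```
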
63 changes: 30 additions & 33 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -691,61 +691,58 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
auto shape = getShape();
-  int64_t shape_size = static_cast<int64_t>(shape.size());
+  int64_t shapeSize = static_cast<int64_t>(shape.size());

-  // Check that the shape size matches the rank of the output tensor
-  if (shape_size != static_cast<int64_t>(outputType.getRank())) {
-    return emitOpError("Shape attribute size must match output tensor rank");
+  // Check that the shape attribute is non-empty.
+  if (shapeSize == 0) {
+    return emitOpError("Shape attribute must be non-empty");
}

-  // Check that the shape attribute is non-empty
-  if (shape_size == 0) {
-    return emitOpError("Shape attribute must be non-empty");
+  // Check that the shape size matches the rank of the output tensor.
+  if (shapeSize != static_cast<int64_t>(outputType.getRank())) {
+    return emitOpError() << "Shape attribute size " << shapeSize
+                         << " must match output tensor rank "
+                         << outputType.getRank();
}

-  // Cardinality of the input and output tensors must be the same
+  // Cardinality of the input and output tensors must be the same.
if (inputType.getNumElements() != outputType.getNumElements()) {
-    return emitOpError(
-        "Input and output tensors must have the same number of elements");
+    return emitOpError() << "Input tensor number of elements "
+                         << inputType.getNumElements()
+                         << " and output tensor number of elements "
+                         << outputType.getNumElements() << " must be the same";
}

-  bool has_negative = false;
-  int64_t known_dim_product = 1;
+  bool hasNegative = false;
auto outputShape = outputType.getShape();

-  // Check that all dimensions are positive except for at most one -1
-  // Check that the non-negative dimensions match the output tensor shape
-  // Calculate the product of the known dimensions
-  for (int64_t i = 0; i < shape_size; i++) {
-    int64_t dim_value = mlir::cast<IntegerAttr>(shape[i]).getInt();
+  // Check that all dimensions are positive except for at most one -1.
+  // Check that the non-negative dimensions match the output tensor shape.
+  // Calculate the product of the known dimensions.
+  for (int64_t i = 0; i < shapeSize; i++) {
+    int64_t dimValue = mlir::cast<IntegerAttr>(shape[i]).getInt();

-    if (dim_value == -1) {
-      if (has_negative) {
+    if (dimValue == -1) {
+      if (hasNegative) {
return emitOpError("Shape attribute must have at most one -1 element");
}
-      has_negative = true;
+      hasNegative = true;
} else {
-      if (dim_value <= 0) {
+      if (dimValue <= 0) {
return emitOpError(
"All dimensions must be positive except the one with -1");
}

-      // Ensure that the non-negative dimensions match the output tensor shape
-      if (dim_value != outputShape[i]) {
-        return emitOpError("Shape attribute must match the output tensor shape "
-                           "for dimensions that are not -1");
+      // Ensure that the non-negative dimensions match the output tensor shape.
+      if (dimValue != outputShape[i]) {
+        return emitOpError()
+               << "Shape attribute " << dimValue
+               << " must match the output tensor shape " << outputShape[i]
+               << " at index " << i << " for dimension that is not -1";
}

-      known_dim_product *= dim_value;
}
}

-  // If there's a -1, ensure that it can be inferred correctly
-  if (has_negative && inputType.getNumElements() % known_dim_product != 0) {
-    return emitOpError("Invalid shape: the dimensions do not multiply to the "
-                       "total number of elements in the tensor");
-  }

return success();
}

50 changes: 24 additions & 26 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -486,58 +486,56 @@ ::mlir::LogicalResult mlir::tt::ttnn::ReshapeOp::verify() {
auto shape = getShape();
int64_t shapeSize = static_cast<int64_t>(shape.size());

-  // Check that the shape size matches the rank of the output tensor
-  if (shapeSize != static_cast<int64_t>(outputType.getRank())) {
-    return emitOpError("Shape attribute size must match output tensor rank");
-  }
-  // Check that the shape attribute is non-empty
+  // Check that the shape attribute is non-empty.
if (shapeSize == 0) {
return emitOpError("Shape attribute must be non-empty");
}

-  // Cardinality of the input and output tensors must be the same
+  // Check that the shape size matches the rank of the output tensor.
+  if (shapeSize != static_cast<int64_t>(outputType.getRank())) {
+    return emitOpError() << "Shape attribute size " << shapeSize
+                         << " must match output tensor rank "
+                         << outputType.getRank();
+  }
+
+  // Cardinality of the input and output tensors must be the same.
if (inputType.getNumElements() != outputType.getNumElements()) {
-    return emitOpError(
-        "Input and output tensors must have the same number of elements");
+    return emitOpError() << "Input tensor number of elements "
+                         << inputType.getNumElements()
+                         << " and output tensor number of elements "
+                         << outputType.getNumElements() << " must be the same";
}

-  bool has_negative = false;
-  int64_t known_dim_product = 1;
+  bool hasNegative = false;
auto outputShape = outputType.getShape();

// Check that all dimensions are positive except for at most one -1
// Check that the non-negative dimensions match the output tensor shape
// Calculate the product of the known dimensions
for (int64_t i = 0; i < shapeSize; i++) {
-    int64_t dim_value = mlir::cast<IntegerAttr>(shape[i]).getInt();
+    int64_t dimValue = mlir::cast<IntegerAttr>(shape[i]).getInt();

-    if (dim_value == -1) {
-      if (has_negative) {
+    if (dimValue == -1) {
+      if (hasNegative) {
return emitOpError("Shape attribute must have at most one -1 element");
}
-      has_negative = true;
+      hasNegative = true;
} else {
-      if (dim_value <= 0) {
+      if (dimValue <= 0) {
return emitOpError(
"All dimensions must be positive except the one with -1");
}

// Ensure that the non-negative dimensions match the output tensor shape
-      if (dim_value != outputShape[i]) {
-        return emitOpError("Shape attribute must match the output tensor shape "
-                           "for dimensions that are not -1");
+      if (dimValue != outputShape[i]) {
+        return emitOpError()
+               << "Shape attribute " << dimValue
+               << " must match the output tensor shape " << outputShape[i]
+               << " at index " << i << " for dimension that is not -1";
}

-      known_dim_product *= dim_value;
}
}

-  // If there's a -1, ensure that it can be inferred correctly
-  if (has_negative && inputType.getNumElements() % known_dim_product != 0) {
-    return emitOpError("Invalid shape: the dimensions do not multiply to the "
-                       "total number of elements in the tensor");
-  }

return success();
}

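
The TTIR and TTNN verifiers above now enforce the same rules: a non-empty shape
attribute, attribute size equal to the output rank, matching element counts, at
most one -1, and every explicit dimension equal to the corresponding output
dimension. This also shows why the old `known_dim_product` divisibility check
could be dropped: with the other checks in place, a single -1 is already fully
determined by the output shape. A standalone mirror of the logic (illustration
only, not repository code):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Mirrors the ReshapeOp verifier rules above on plain vectors.
std::string verifyReshape(const std::vector<int64_t> &shapeAttr,
                          const std::vector<int64_t> &outputShape,
                          int64_t inputNumElements,
                          int64_t outputNumElements) {
  if (shapeAttr.empty())
    return "error: shape attribute must be non-empty";
  if (shapeAttr.size() != outputShape.size())
    return "error: shape attribute size must match output tensor rank";
  if (inputNumElements != outputNumElements)
    return "error: input and output element counts must match";

  bool hasNegative = false;
  for (size_t i = 0; i < shapeAttr.size(); i++) {
    int64_t dimValue = shapeAttr[i];
    if (dimValue == -1) {
      if (hasNegative)
        return "error: at most one -1 element allowed";
      hasNegative = true; // value is pinned down by outputShape[i]
    } else {
      if (dimValue <= 0)
        return "error: dimensions must be positive except a single -1";
      if (dimValue != outputShape[i])
        return "error: explicit dimensions must match the output shape";
    }
  }
  return "ok";
}

int main() {
  // Reshape a 2x3x4 tensor (24 elements) to 6x4, writing the first
  // dimension as -1: accepted, because the -1 can only mean 6.
  std::cout << verifyReshape({-1, 4}, {6, 4}, 24, 24) << "\n";  // ok
  std::cout << verifyReshape({-1, -1}, {6, 4}, 24, 24) << "\n"; // error
  return 0;
}
```
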
9 changes: 5 additions & 4 deletions lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -386,7 +386,8 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern<ttnn::AllReduceOp> {

// Create a new reshape op.
ttnn::ReshapeOp preReshapeOp = rewriter.create<ttnn::ReshapeOp>(
-        loc, Type(reshapedInputType), op.getInput(), reshapedInputShapeAttr);
+        loc, Type(reshapedInputType), op.getInput(), reshapedInputShapeAttr,
+        /* memory_config */ nullptr);

// Determine new dimension since entire tensor shape got shifted.
dimension = dimension + requiredOnesInput;
@@ -424,9 +425,9 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern<ttnn::AllReduceOp> {
loc, Type(reshapedOutputType), reduceScatterOp.getResult(),
deviceValue, dimension, clusterAxis);

-    rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(op, Type(outputType),
-                                                 allGatherOp.getResult(),
-                                                 reshapedOutputShapeAttr);
+    rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
+        op, Type(outputType), allGatherOp.getResult(),
+        reshapedOutputShapeAttr, /* memory_config */ nullptr);
} else {
// TODO(wooseoklee): Once ttnn supports all_reduce op
// (https://github.com/tenstorrent/tt-metal/issues/13835), we can convert
11 changes: 10 additions & 1 deletion lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -1262,7 +1262,16 @@ createReshapeOp(FlatbufferObjectCache &cache, ReshapeOp op) {
auto out = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedSize);

-  return ::tt::target::ttnn::CreateReshapeOp(*cache.fbb, in, out, shape);
+  std::optional<mlir::tt::ttnn::MemoryConfigAttr> memoryConfig =
+      op.getMemoryConfig();
+  auto tileShape = getTensorValueTileShape(op.getResult());
+  auto coreRangeSet = getTensorValueCoreRangeSet(cache, op.getResult());
+
+  return ::tt::target::ttnn::CreateReshapeOp(
+      *cache.fbb, in, out, shape,
+      memoryConfig ? memoryConfigToFlatbuffer(cache, memoryConfig.value(),
+                                              tileShape, coreRangeSet)
+                   : 0);
}

template <typename RepeatOp>
7 changes: 5 additions & 2 deletions runtime/lib/ttnn/operations/data_movement/concat.cpp
@@ -19,8 +19,11 @@ void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context) {
}
int32_t dim = op->dim();
std::optional<::ttnn::MemoryConfig> memoryConfig =
-      ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
-          op->memory_config());
+      op->memory_config() == 0
+          ? ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
+                ::tt::runtime::ttnn::utils::getTensorRefMemoryConfig(op->out()))
+          : ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
+                op->memory_config());
::ttnn::Tensor out = ::ttnn::concat(inputs, dim, memoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
9 changes: 8 additions & 1 deletion runtime/lib/ttnn/operations/data_movement/reshape.cpp
@@ -5,6 +5,7 @@
#include "operations/data_movement/reshape.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/utils.h"

namespace tt::runtime::ttnn::operations::data_movement {
void run(const ::tt::target::ttnn::ReshapeOp *op, ProgramContext &context) {
@@ -13,7 +14,13 @@ void run(const ::tt::target::ttnn::ReshapeOp *op, ProgramContext &context) {
DEBUG_ASSERT(in.is_allocated());
const auto *fbShape = op->shape();
std::vector<int32_t> shape(fbShape->begin(), fbShape->end());
-  ::ttnn::Tensor out = ::ttnn::reshape(in, shape);
+  std::optional<::ttnn::MemoryConfig> memoryConfig =
+      op->memory_config() == 0
+          ? ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
+                ::tt::runtime::ttnn::utils::getTensorRefMemoryConfig(op->out()))
+          : ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
+                op->memory_config());
+  ::ttnn::Tensor out = ::ttnn::reshape(in, shape, memoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::data_movement
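
The concat and reshape runtime hunks share the same fallback rule. A possible
shared helper, sketched as a suggestion rather than part of this commit, using
only the utilities that already appear in the diff (assuming
`getTensorRefMemoryConfig` returns the same flatbuffer `MemoryConfig` pointer
type that `op->memory_config()` does):

```cpp
#include <optional>

#include "tt/runtime/ttnn/utils.h"

// Sketch (not in this commit): an unset flatbuffer field reads back as
// nullptr; in that case, derive the memory config from the output TensorRef.
static std::optional<::ttnn::MemoryConfig>
resolveMemoryConfig(const ::tt::target::ttnn::MemoryConfig *opMemCfg,
                    const ::tt::target::ttnn::TensorRef *out) {
  return opMemCfg == nullptr
             ? ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
                   ::tt::runtime::ttnn::utils::getTensorRefMemoryConfig(out))
             : ::tt::runtime::ttnn::utils::createMemoryConfigIfNeeded(
                   opMemCfg);
}
```
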