Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Optimizer] Refactor legal grid analysis #1363

Merged
merged 14 commits into from
Dec 10, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
Expand All @@ -12,46 +12,49 @@

namespace mlir::tt::ttnn {

struct LegalGridAnalysisInput {
struct LegalLayoutAnalysisInput {
ChipDescAttr chipDesc;
GridAttr maxGrid;
RankedTensorType tensorType;
int64_t maxShardedGrids;
llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides;
bool rowMajorEnabled;

LegalGridAnalysisInput()
LegalLayoutAnalysisInput()
: chipDesc(nullptr), maxGrid(nullptr), tensorType(nullptr),
outputLayoutOverrides(nullptr) {}

LegalGridAnalysisInput(
LegalLayoutAnalysisInput(
ChipDescAttr chipDesc, GridAttr maxGrid, RankedTensorType tensorType,
int64_t maxShardedGrids,
llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides)
llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides,
bool rowMajorEnabled)
: chipDesc(chipDesc), maxGrid(maxGrid), tensorType(tensorType),
maxShardedGrids(maxShardedGrids),
outputLayoutOverrides(outputLayoutOverrides) {}
outputLayoutOverrides(outputLayoutOverrides),
rowMajorEnabled(rowMajorEnabled) {}

bool operator==(const LegalGridAnalysisInput &rhs) const {
bool operator==(const LegalLayoutAnalysisInput &rhs) const {
return chipDesc == rhs.chipDesc && maxGrid == rhs.maxGrid &&
tensorType == rhs.tensorType &&
outputLayoutOverrides == rhs.outputLayoutOverrides;
}

bool operator!=(const LegalGridAnalysisInput &rhs) const {
bool operator!=(const LegalLayoutAnalysisInput &rhs) const {
return !(*this == rhs);
}
};

// Analysis that computes the set of legal TTNN layouts
// (std::vector<TTNNLayoutAttr>) for an operation's output tensor, honoring
// any user-provided output-layout overrides.
class LegalLayoutAnalysis : public TTNNAnalysis<LegalLayoutAnalysisInput,
                                                std::vector<TTNNLayoutAttr>> {
private:
  // Enumerates candidate layouts for the op; defined in
  // LegalLayoutAnalysis.cpp.
  void analysisImplementation() override;
  // Applies outputLayoutOverrides; returns true if overrides fully determine
  // the result — NOTE(review): exact return contract lives in the .cpp,
  // confirm there.
  bool applyOverrides() override;

public:
  LegalLayoutAnalysis(Operation *op) : TTNNAnalysis(op) {}
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H
14 changes: 7 additions & 7 deletions include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@
namespace mlir::tt::ttnn {

struct OpConfigAnalysisInput {
llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalGrids;
llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;

OpConfigAnalysisInput() : legalGrids() {}
OpConfigAnalysisInput() : legalLayouts() {}

OpConfigAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&&legalGrids)
: legalGrids(std::move(legalGrids)) {}
&&legalLayouts)
: legalLayouts(std::move(legalLayouts)) {}

OpConfigAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalGrids)
: legalGrids(legalGrids) {}
&legalLayouts)
: legalLayouts(legalLayouts) {}

bool operator==(const OpConfigAnalysisInput &rhs) const {
return legalGrids == rhs.legalGrids;
return legalLayouts == rhs.legalLayouts;
}

bool operator!=(const OpConfigAnalysisInput &rhs) const {
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
bool hasInterleavedL1TensorMemoryLayout() const;
bool isTiled() const;
Type getElementType() const;
Type getScalarElementType() const;
DataType getDataTypeFromMemRef() const;
uint64_t getElementSizeBytes() const;
int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
Expand Down
13 changes: 10 additions & 3 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,23 @@ struct TTIRToTTNNBackendPipelineOptions
"Pass in a system descriptor flatbuffer to compile against."),
llvm::cl::init("")};

// Option to override maximum number of legal layouts for grid analysis
// Option to override maximum number of sharded layouts to be generated in
// legal layout analysis.
//
Option<int64_t> maxLegalLayouts{
*this, "max-legal-layouts",
llvm::cl::desc(
"Override maximum number of legal layouts for grid analysis."),
llvm::cl::desc("Override maximum number of sharded layouts for legal "
"layout analysis."),
llvm::cl::init(64)};

ListOption<int64_t> meshShape{
*this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")};

Option<bool> rowMajorEnabled{
*this, "row-major-enabled",
llvm::cl::desc(
"Enable row major layout generation in legal layout analysis."),
llvm::cl::init(false)};
};

// TTIR to EmitC pipeline options.
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct TTNNOptimizerOptions {
MemoryLayoutAnalysisPolicyType::DFSharding;
bool memReconfigEnabled = false;
int64_t maxLegalLayouts = 64;
bool rowMajorEnabled = false;
};

std::unique_ptr<::mlir::Pass> createTTNNOptimizer();
Expand Down
16 changes: 11 additions & 5 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
namespace mlir::tt::ttnn {

// Per-op output layout override supplied by the user. Every field is
// optional so a user may pin only a subset of layout properties and let the
// analysis choose the rest.
struct OutputLayoutOverrideParams {
  std::optional<SmallVector<int64_t, 2>> grid;
  std::optional<BufferType> bufferType;
  std::optional<TensorMemoryLayout>
      tensorMemoryLayout;                 // INTERLEAVED / SHARDED etc...
  std::optional<Layout> memoryLayout;     // ROW_MAJOR / TILE
  std::optional<tt::DataType> dataType;

  // True when the override pins every layout-defining property, so the
  // analysis can skip layout enumeration entirely.
  // NOTE(review): dataType is not required here — presumably a data-type
  // override alone does not constitute a full layout; confirm with callers.
  bool fullLayoutOverride() const {
    return grid.has_value() && bufferType.has_value() &&
           tensorMemoryLayout.has_value() && memoryLayout.has_value();
  }
};

struct InputLayoutOverrideParams {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIRTTNNAnalysis
LegalGridAnalysis.cpp
LegalLayoutAnalysis.cpp
OpConfigAnalysis.cpp
MemoryLayoutAnalysis.cpp
L1ChainConfig.cpp
Expand Down
217 changes: 0 additions & 217 deletions lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp

This file was deleted.

Loading
Loading