[Optimizer] Refactor legal grid analysis #1363

Merged: 14 commits, Dec 10, 2024
include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h → include/ttmlir/Dialect/TTNN/Analysis/LegalLayoutAnalysis.h (renamed)
@@ -2,8 +2,8 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
-#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
-#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
+#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H
+#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H
 
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
@@ -12,46 +12,49 @@
 
 namespace mlir::tt::ttnn {
 
-struct LegalGridAnalysisInput {
+struct LegalLayoutAnalysisInput {
   ChipDescAttr chipDesc;
   GridAttr maxGrid;
   RankedTensorType tensorType;
   int64_t maxShardedGrids;
   llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides;
+  bool rowMajorEnabled;
 
-  LegalGridAnalysisInput()
+  LegalLayoutAnalysisInput()
       : chipDesc(nullptr), maxGrid(nullptr), tensorType(nullptr),
         outputLayoutOverrides(nullptr) {}
 
-  LegalGridAnalysisInput(
+  LegalLayoutAnalysisInput(
       ChipDescAttr chipDesc, GridAttr maxGrid, RankedTensorType tensorType,
       int64_t maxShardedGrids,
-      llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides)
+      llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides,
+      bool rowMajorEnabled)
       : chipDesc(chipDesc), maxGrid(maxGrid), tensorType(tensorType),
         maxShardedGrids(maxShardedGrids),
-        outputLayoutOverrides(outputLayoutOverrides) {}
+        outputLayoutOverrides(outputLayoutOverrides),
+        rowMajorEnabled(rowMajorEnabled) {}
 
-  bool operator==(const LegalGridAnalysisInput &rhs) const {
+  bool operator==(const LegalLayoutAnalysisInput &rhs) const {
     return chipDesc == rhs.chipDesc && maxGrid == rhs.maxGrid &&
            tensorType == rhs.tensorType &&
           outputLayoutOverrides == rhs.outputLayoutOverrides;
   }
 
-  bool operator!=(const LegalGridAnalysisInput &rhs) const {
+  bool operator!=(const LegalLayoutAnalysisInput &rhs) const {
     return !(*this == rhs);
   }
 };
 
-class LegalGridAnalysis
-    : public TTNNAnalysis<LegalGridAnalysisInput, std::vector<TTNNLayoutAttr>> {
+class LegalLayoutAnalysis : public TTNNAnalysis<LegalLayoutAnalysisInput,
+                                                std::vector<TTNNLayoutAttr>> {
 private:
   void analysisImplementation() override;
   bool applyOverrides() override;
 
 public:
-  LegalGridAnalysis(Operation *op) : TTNNAnalysis(op) {}
+  LegalLayoutAnalysis(Operation *op) : TTNNAnalysis(op) {}
 };
 
 } // namespace mlir::tt::ttnn
 
-#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H
+#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H
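For orientation, here is a minimal sketch of how a caller could drive the renamed analysis. It assumes the TTNNAnalysis base class exposes init() and getResult(), which this diff does not show; the op, attributes, and override map are placeholders.

// Hypothetical driver, assuming TTNNAnalysis provides init()/getResult().
void collectLegalLayouts(
    Operation *op, ChipDescAttr chipDesc, GridAttr maxGrid,
    RankedTensorType tensorType,
    llvm::StringMap<OutputLayoutOverrideParams> &overrides) {
  LegalLayoutAnalysis analysis(op);
  // rowMajorEnabled = true asks the analysis to generate row-major
  // candidates in addition to tiled ones.
  analysis.init(LegalLayoutAnalysisInput(chipDesc, maxGrid, tensorType,
                                         /*maxShardedGrids=*/64, &overrides,
                                         /*rowMajorEnabled=*/true));
  std::vector<TTNNLayoutAttr> layouts = analysis.getResult();
  (void)layouts; // would feed into OpConfigAnalysisInput, see below
}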
14 changes: 7 additions & 7 deletions include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h
@@ -11,22 +11,22 @@
 namespace mlir::tt::ttnn {
 
 struct OpConfigAnalysisInput {
-  llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalGrids;
+  llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;
 
-  OpConfigAnalysisInput() : legalGrids() {}
+  OpConfigAnalysisInput() : legalLayouts() {}
 
   OpConfigAnalysisInput(
       const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
-          &&legalGrids)
-      : legalGrids(std::move(legalGrids)) {}
+          &&legalLayouts)
+      : legalLayouts(std::move(legalLayouts)) {}
 
   OpConfigAnalysisInput(
       const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
-          &legalGrids)
-      : legalGrids(legalGrids) {}
+          &legalLayouts)
+      : legalLayouts(legalLayouts) {}
 
   bool operator==(const OpConfigAnalysisInput &rhs) const {
-    return legalGrids == rhs.legalGrids;
+    return legalLayouts == rhs.legalLayouts;
   }
 
   bool operator!=(const OpConfigAnalysisInput &rhs) const {
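A side note on the move constructor carried over by this rename: std::move on a const rvalue reference yields a const xvalue, which binds to the copy constructor, so no move actually happens. A conventional alternative, shown purely as an illustration and not part of this PR, is to take the map by value:

// Illustrative alternative: pass by value so callers may copy or move,
// and the member initialization performs a genuine move.
OpConfigAnalysisInput(
    llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts)
    : legalLayouts(std::move(legalLayouts)) {}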
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -158,6 +158,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
     Layout getLayout() const;
     std::optional<TensorMemoryLayout> getMemLayoutOpt() const;
     Type getElementType() const;
+    Type getScalarElementType() const;
     uint64_t getShardSizeInBytes() const;
     BufferType getBufferType() const;
     DataType getDataType() const;
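The new getScalarElementType() accessor complements getElementType(): for tiled layouts the element type is a tile wrapper, and a scalar accessor presumably lets callers reach the underlying scalar type uniformly. A hypothetical usage, with the concrete behavior assumed rather than shown in this diff:

// Hypothetical: for a tiled bf16 layout, getElementType() would return
// the tile type, while getScalarElementType() returns plain bf16.
Type element = layoutAttr.getElementType();
Type scalar = layoutAttr.getScalarElementType();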
31 changes: 22 additions & 9 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -24,7 +24,7 @@ struct TTIRToTTNNBackendPipelineOptions
       llvm::cl::desc("Determine and set max valid grid for Op execution."),
       llvm::cl::init(false)};
 
-  // Option to manually insert TTIR_ToLayoutOp for specific op's operand.
+  // Option to manually insert TTNN_ToLayoutOp for specific op's operand.
   // The format is a comma separated list of op names and operand index
   // separated by ':' separator.
   //
@@ -43,9 +43,13 @@
           "Manually insert memory reconfig op for specific op's operand."),
       llvm::cl::init(llvm::StringMap<InputLayoutOverrideParams>())};
 
-  // Option to override output layout for specific ops.
-  // The format is a comma separated list of op names equal to the output layout
-  // params separated by ":"
+  // Option to override output layout for specific operations. You can
+  // override any number or combination of layout parameters. If not all are
+  // overridden, the remaining ones will be inferred, with all possible
+  // combinations generated in LegalLayoutAnalysis. The format is a
+  // comma-separated list of operation names followed by the output layout
+  // parameters, separated by ':'. The order of parameters does not matter; the
+  // parser will deduce which one is being overridden based on its value.
   //
   // op_name=grid_size:memory_space:tensor_memory_layout:memory_layout:data_type
   //
@@ -58,7 +62,9 @@
   // bfp_bf2, u32, u16, u8
   //
   // Full Example:
-  // "op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16"
+  // "op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:f16"
+  // Partial Example:
+  // "op1=2x2:block_sharded"
   //
   //
   // Note: This option is only valid if optimizerPassEnabled is true.
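To make the partial form concrete, the override string "op1=2x2:block_sharded" would populate only two of the std::optional fields introduced in PassOverrides.h further down. A sketch, with the enum spelling assumed:

// Sketch: what "op1=2x2:block_sharded" plausibly parses into.
OutputLayoutOverrideParams params;
params.grid = SmallVector<int64_t, 2>{2, 2};
params.tensorMemoryLayout = TensorMemoryLayout::BlockSharded;
// bufferType, memoryLayout, and dataType remain std::nullopt, so
// LegalLayoutAnalysis will enumerate all legal combinations for them.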
@@ -101,19 +107,26 @@
           "Pass in a system descriptor flatbuffer to compile against."),
       llvm::cl::init("")};
 
-  // Option to override maximum number of legal layouts for grid analysis
+  // Option to override maximum number of sharded layouts to be generated in
+  // legal layout analysis.
   //
   Option<int64_t> maxLegalLayouts{
       *this, OptionNames::maxLegalLayouts,
-      llvm::cl::desc(
-          "Override maximum number of legal layouts for grid analysis."),
+      llvm::cl::desc("Override maximum number of sharded layouts for legal "
+                     "layout analysis."),
       llvm::cl::init(64)};
 
   ListOption<int64_t> meshShape{
       *this, OptionNames::meshShape,
       llvm::cl::desc("Set the multi-device mesh shape.")};
 
-  // Options to enable/disable the workaround pass.
+  Option<bool> rowMajorEnabled{
+      *this, "row-major-enabled",
+      llvm::cl::desc(
+          "Enable row major layout generation in legal layout analysis."),
+      llvm::cl::init(false)};
+
+  // Option to enable/disable the workaround pass.
   //
   Option<bool> layouotWorkaroundsEnabled{
       *this, "enable-layout-workaround-pass",
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -23,6 +23,7 @@ struct TTNNOptimizerOptions {
       MemoryLayoutAnalysisPolicyType::DFSharding;
   bool memReconfigEnabled = false;
   int64_t maxLegalLayouts = 64;
+  bool rowMajorEnabled = false;
 };
 
 std::unique_ptr<::mlir::Pass> createTTNNOptimizer();
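Configuring the optimizer programmatically then looks roughly like the following; whether an overload of createTTNNOptimizer that accepts these options exists is not shown in this excerpt:

// Sketch of configuring the new flag alongside the existing cap.
TTNNOptimizerOptions options;
options.maxLegalLayouts = 32;   // cap on generated sharded layouts per op
options.rowMajorEnabled = true; // also generate row-major candidates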
79 changes: 69 additions & 10 deletions include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h
@@ -31,18 +31,77 @@ struct OptionNames {
 };
 
 struct OutputLayoutOverrideParams {
 
-  SmallVector<int64_t, 2> grid;
-  BufferType bufferType;
-  TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
-  Layout memoryLayout;                   // ROW_MAJOR / TILE
-  mlir::tt::DataType dataType;
+  std::optional<SmallVector<int64_t, 2>> grid;
+  std::optional<BufferType> bufferType;
+  std::optional<TensorMemoryLayout>
+      tensorMemoryLayout;                 // INTERLEAVED / SHARDED etc...
+  std::optional<Layout> memoryLayout;     // ROW_MAJOR / TILE
+  std::optional<tt::DataType> dataType;
+
+  // Check if all layout parameters that are generated in LegalLayoutAnalysis
+  // are overridden. DataType is the only one that is not.
+  bool fullLayoutOverride() const {
+    return grid.has_value() && bufferType.has_value() &&
+           tensorMemoryLayout.has_value() && memoryLayout.has_value();
+  }
 
   bool operator==(const OutputLayoutOverrideParams rhs) const {
-    return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] &&
-           bufferType == rhs.bufferType &&
-           tensorMemoryLayout == rhs.tensorMemoryLayout &&
-           memoryLayout == rhs.memoryLayout && dataType == rhs.dataType;
+    if (grid.has_value() != rhs.grid.has_value()) {
+      return false;
+    }
+
+    if (grid.has_value() && rhs.grid.has_value()) {
+      if (grid.value().size() != rhs.grid.value().size()) {
+        return false;
+      }
+      for (std::size_t i = 0; i < grid.value().size(); i++) {
+        if (grid.value()[i] != rhs.grid.value()[i]) {
+          return false;
+        }
+      }
+    }
+
+    if (bufferType.has_value() != rhs.bufferType.has_value()) {
+      return false;
+    }
+
+    if (bufferType.has_value() && rhs.bufferType.has_value()) {
+      if (bufferType.value() != rhs.bufferType.value()) {
+        return false;
+      }
+    }
+
+    if (tensorMemoryLayout.has_value() != rhs.tensorMemoryLayout.has_value()) {
+      return false;
+    }
+
+    if (tensorMemoryLayout.has_value() && rhs.tensorMemoryLayout.has_value()) {
+      if (tensorMemoryLayout.value() != rhs.tensorMemoryLayout.value()) {
+        return false;
+      }
+    }
+
+    if (memoryLayout.has_value() != rhs.memoryLayout.has_value()) {
+      return false;
+    }
+
+    if (memoryLayout.has_value() && rhs.memoryLayout.has_value()) {
+      if (memoryLayout.value() != rhs.memoryLayout.value()) {
+        return false;
+      }
+    }
+
+    if (dataType.has_value() != rhs.dataType.has_value()) {
+      return false;
+    }
+
+    if (dataType.has_value() && rhs.dataType.has_value()) {
+      if (dataType.value() != rhs.dataType.value()) {
+        return false;
+      }
+    }
+
+    return true;
   }
 
   bool operator!=(const OutputLayoutOverrideParams &rhs) const {
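Since std::optional defines an operator== that already handles the engaged/empty mismatch, and llvm::SmallVector is equality-comparable, the body above is behaviorally equivalent to a single chained comparison. A possible follow-up simplification, not part of this PR:

// Equivalent formulation relying on std::optional's operator==.
bool operator==(const OutputLayoutOverrideParams &rhs) const {
  return grid == rhs.grid && bufferType == rhs.bufferType &&
         tensorMemoryLayout == rhs.tensorMemoryLayout &&
         memoryLayout == rhs.memoryLayout && dataType == rhs.dataType;
}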
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRTTNNAnalysis
-  LegalGridAnalysis.cpp
+  LegalLayoutAnalysis.cpp
   OpConfigAnalysis.cpp
   MemoryLayoutAnalysis.cpp
   L1ChainConfig.cpp
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -309,7 +309,7 @@ void L1InterleavedPolicy::run() {
 }
 
 bool L1InterleavedPolicy::isAnalyzable(Operation *op) {
-  // Skip operations that are not analyzed by the LegalGridAnalysis.
+  // Skip operations that are not analyzed by the LegalLayoutAnalysis.
   //
   if (legalLayouts.count(op) > 0) {
     // Skip operations that are filtered out by the MemoryLayoutAnalysis.