diff --git a/include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/LegalLayoutAnalysis.h similarity index 61% rename from include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h rename to include/ttmlir/Dialect/TTNN/Analysis/LegalLayoutAnalysis.h index d26db60f41..dc8c6054c6 100644 --- a/include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/LegalLayoutAnalysis.h @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H -#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H +#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H +#define TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h" @@ -12,46 +12,49 @@ namespace mlir::tt::ttnn { -struct LegalGridAnalysisInput { +struct LegalLayoutAnalysisInput { ChipDescAttr chipDesc; GridAttr maxGrid; RankedTensorType tensorType; int64_t maxShardedGrids; llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides; + bool rowMajorEnabled; - LegalGridAnalysisInput() + LegalLayoutAnalysisInput() : chipDesc(nullptr), maxGrid(nullptr), tensorType(nullptr), outputLayoutOverrides(nullptr) {} - LegalGridAnalysisInput( + LegalLayoutAnalysisInput( ChipDescAttr chipDesc, GridAttr maxGrid, RankedTensorType tensorType, int64_t maxShardedGrids, - llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides) + llvm::StringMap<OutputLayoutOverrideParams> *outputLayoutOverrides, + bool rowMajorEnabled) : chipDesc(chipDesc), maxGrid(maxGrid), tensorType(tensorType), maxShardedGrids(maxShardedGrids), - outputLayoutOverrides(outputLayoutOverrides) {} + outputLayoutOverrides(outputLayoutOverrides), + rowMajorEnabled(rowMajorEnabled) {} - bool operator==(const LegalGridAnalysisInput &rhs) const { + bool operator==(const LegalLayoutAnalysisInput &rhs) const { return chipDesc == rhs.chipDesc && maxGrid == rhs.maxGrid && tensorType == rhs.tensorType && outputLayoutOverrides == rhs.outputLayoutOverrides; } - bool operator!=(const LegalGridAnalysisInput &rhs) const { + bool operator!=(const LegalLayoutAnalysisInput &rhs) const { return !(*this == rhs); } }; -class LegalGridAnalysis - : public TTNNAnalysis<LegalGridAnalysisInput, std::vector<TTNNLayoutAttr>> { +class LegalLayoutAnalysis : public TTNNAnalysis<LegalLayoutAnalysisInput, std::vector<TTNNLayoutAttr>> { private: void analysisImplementation() override; bool applyOverrides() override; public: - LegalGridAnalysis(Operation *op) : TTNNAnalysis(op) {} + LegalLayoutAnalysis(Operation *op) : TTNNAnalysis(op) {} }; } // namespace mlir::tt::ttnn -#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALGRIDANALYSIS_H +#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_LEGALLAYOUTANALYSIS_H diff --git a/include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h b/include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h index a9df15aa05..948c66f178 100644 --- a/include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h +++ b/include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h @@ -11,22 +11,22 @@ namespace mlir::tt::ttnn { struct OpConfigAnalysisInput { - llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalGrids; + llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts; - OpConfigAnalysisInput() : legalGrids() {} + OpConfigAnalysisInput() : legalLayouts() {} OpConfigAnalysisInput( const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> - &&legalGrids) - : legalGrids(std::move(legalGrids)) {} + &&legalLayouts) + : legalLayouts(std::move(legalLayouts)) {} OpConfigAnalysisInput( const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> - &legalGrids) - : legalGrids(legalGrids) {} + &legalLayouts) + : legalLayouts(legalLayouts) {} bool operator==(const OpConfigAnalysisInput &rhs) const { - return legalGrids == 
rhs.legalGrids; + return legalLayouts == rhs.legalLayouts; } bool operator!=(const OpConfigAnalysisInput &rhs) const { diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td index ba8b7a724e..e483b07bf2 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td @@ -158,6 +158,7 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> { Layout getLayout() const; std::optional<TensorMemoryLayout> getMemLayoutOpt() const; Type getElementType() const; + Type getScalarElementType() const; uint64_t getShardSizeInBytes() const; BufferType getBufferType() const; DataType getDataType() const; diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index 248e68a04c..d27c488eda 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -24,7 +24,7 @@ struct TTIRToTTNNBackendPipelineOptions llvm::cl::desc("Determine and set max valid grid for Op execution."), llvm::cl::init(false)}; - // Option to manually insert TTIR_ToLayoutOp for specific op's operand. + // Option to manually insert TTNN_ToLayoutOp for specific op's operand. // The format is a comma separated list of op names and operand index // separated by ':' separator. // @@ -43,9 +43,13 @@ struct TTIRToTTNNBackendPipelineOptions "Manually insert memory reconfig op for specific op's operand."), llvm::cl::init(llvm::StringMap())}; - // Option to override output layout for specific ops. - // The format is a comma separated list of op names equal to the output layout - // params separated by ":" + // Option to override output layout for specific operations. You can + // override any number or combination of layout parameters. If not all are + // overridden, the remaining ones will be inferred with all possible + // combinations generated in LegalLayoutAnalysis. The format is a + // comma-separated list of operation names followed by the output layout + // parameters, separated by :. The order of parameters does not matter; the + // parser will deduce which one is being overridden based on its value. // // op_name=grid_size:memory_space:tensor_memory_layout:memory_layout:data_type // // bfp_bf2, u32, u16, u8 // // Full Example: - // "op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16" + // "op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:f16" + // Partial Example: + // "op1=2x2:block_sharded" // // // Note: This option is only valid if optimizerPassEnabled is true. @@ -101,19 +107,26 @@ struct TTIRToTTNNBackendPipelineOptions "Pass in a system descriptor flatbuffer to compile against."), llvm::cl::init("")}; - // Option to override maximum number of legal layouts for grid analysis + // Option to override maximum number of sharded layouts to be generated in + // legal layout analysis. // Option<int64_t> maxLegalLayouts{ *this, OptionNames::maxLegalLayouts, llvm::cl::desc( - "Override maximum number of legal layouts for grid analysis."), + llvm::cl::desc("Override maximum number of sharded layouts for legal " + "layout analysis."), llvm::cl::init(64)}; ListOption<int64_t> meshShape{ *this, OptionNames::meshShape, llvm::cl::desc("Set the multi-device mesh shape.")}; - // Options to enable/disable the workaround pass.
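
For context, the partial example above ("op1=2x2:block_sharded") can also be built programmatically. A minimal sketch, assuming only the structs in this patch; the op name "matmul_1" and the helper function are hypothetical:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"

using namespace mlir::tt::ttnn;

// Build the equivalent of "matmul_1=4x4:l1": grid and buffer type are pinned,
// while tensorMemoryLayout, memoryLayout and dataType stay unset so
// LegalLayoutAnalysis still enumerates every legal combination for them.
llvm::StringMap<OutputLayoutOverrideParams> makePartialOverride() {
  OutputLayoutOverrideParams params;
  params.grid = llvm::SmallVector<int64_t, 2>{4, 4};
  params.bufferType = BufferType::L1;
  llvm::StringMap<OutputLayoutOverrideParams> overrides;
  overrides["matmul_1"] = params;
  return overrides;
}
```

Anything left unset is filled in by the analysis, which generates all possible combinations for the remaining parameters.
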
+ Option<bool> rowMajorEnabled{ + *this, "row-major-enabled", + llvm::cl::desc( + "Enable row major layout generation in legal layout analysis."), + llvm::cl::init(false)}; + + // Option to enable/disable the workaround pass. // Option<bool> layouotWorkaroundsEnabled{ *this, "enable-layout-workaround-pass", diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h b/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h index 494e0ff1b8..cdc03b5d72 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h +++ b/include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h @@ -23,6 +23,7 @@ struct TTNNOptimizerOptions { MemoryLayoutAnalysisPolicyType::DFSharding; bool memReconfigEnabled = false; int64_t maxLegalLayouts = 64; + bool rowMajorEnabled = false; }; std::unique_ptr<::mlir::Pass> createTTNNOptimizer(); diff --git a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h index 35f93062e8..cd2d3585f8 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h @@ -31,18 +31,77 @@ struct OptionNames { }; struct OutputLayoutOverrideParams { - - SmallVector<int64_t, 2> grid; - BufferType bufferType; - TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... - Layout memoryLayout; // ROW_MAJOR / TILE - mlir::tt::DataType dataType; + std::optional<SmallVector<int64_t, 2>> grid; + std::optional<BufferType> bufferType; + std::optional<TensorMemoryLayout> + tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + std::optional<Layout> memoryLayout; // ROW_MAJOR / TILE + std::optional<mlir::tt::DataType> dataType; + + // Check if all layout parameters that are generated in LegalLayoutAnalysis + // are overridden. DataType is the only one that is not. + bool fullLayoutOverride() const { + return grid.has_value() && bufferType.has_value() && + tensorMemoryLayout.has_value() && memoryLayout.has_value(); + } bool operator==(const OutputLayoutOverrideParams &rhs) const { - return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] && - bufferType == rhs.bufferType && - tensorMemoryLayout == rhs.tensorMemoryLayout && - memoryLayout == rhs.memoryLayout && dataType == rhs.dataType; + // std::optional already compares engaged/empty state and values, and + // llvm::SmallVector compares element-wise, so member-wise comparison + // suffices. + return grid == rhs.grid && bufferType == rhs.bufferType && + tensorMemoryLayout == rhs.tensorMemoryLayout && + memoryLayout == rhs.memoryLayout && dataType == rhs.dataType; } bool operator!=(const OutputLayoutOverrideParams &rhs) const { diff --git a/lib/Dialect/TTNN/Analysis/CMakeLists.txt
b/lib/Dialect/TTNN/Analysis/CMakeLists.txt index 640702f71c..4db2d78b9c 100644 --- a/lib/Dialect/TTNN/Analysis/CMakeLists.txt +++ b/lib/Dialect/TTNN/Analysis/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIRTTNNAnalysis - LegalGridAnalysis.cpp + LegalLayoutAnalysis.cpp OpConfigAnalysis.cpp MemoryLayoutAnalysis.cpp L1ChainConfig.cpp diff --git a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp index 23c1b306ab..69a07af168 100644 --- a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp +++ b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp @@ -309,7 +309,7 @@ void L1InterleavedPolicy::run() { } bool L1InterleavedPolicy::isAnalyzable(Operation *op) { - // Skip operations that are not analyzed by the LegalGridAnalysis. + // Skip operations that are not analyzed by the LegalLayoutAnalysis. // if (legalLayouts.count(op) > 0) { // Skip operations that are filterd out by the MemoryLayoutAnalysis. diff --git a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp deleted file mode 100644 index 9bbbccf5ea..0000000000 --- a/lib/Dialect/TTNN/Analysis/LegalGridAnalysis.cpp +++ /dev/null @@ -1,225 +0,0 @@ -// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h" -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTNN/IR/TTNN.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" -#include "ttmlir/Dialect/TTNN/Utils/Utils.h" - -namespace mlir::tt::ttnn { - -bool mock_is_output_tensor_legal_for_op(Operation *op, TTNNLayoutAttr layout) { - // Placeholder, needs to be replaced with a call the the TTNN op interface. - return true; -} - -bool tensor_shape_compatible_with_shard(Operation *op, TTNNLayoutAttr layout) { - // These constraints are implemented seperatelly in every TTNN op. - // Almost nothing seems to be shared between EVERY op, so is hard to have any - // logic here without the risk of discarding a valid configuraiton or modeling - // the constraint for each op. This logic may be offloaded to the TTNN op - // interface. - - // For now we will check if the tilised tensor dims are divisible by the grid - // dims. This will definitly discard possible valid configurations, but is a - // start. - RankedTensorType tensorType = - mlir::cast(op->getResult(0).getType()); - llvm::ArrayRef tensorShape = tensorType.getShape(); - - int64_t MTiles = 1; - if (tensorType.getRank() >= 2) { - MTiles = (tensorShape.rbegin()[1] + 31) / 32; - } - - int64_t KTIles = (tensorShape.back() + 31) / 32; - - int64_t gridR = layout.getGrid().getShape()[0]; - int64_t gridC = layout.getGrid().getShape()[1]; - - return (MTiles % gridR == 0) && (KTIles % gridC == 0); -} - -bool cantChangeOutputLayout(Operation *op) { - // Check if OP belongs to TTNN dialect. - // - if (!isa(op->getDialect())) { - return true; - } - - if (llvm::isa(op)) { - return true; - } - - if (llvm::isa(op)) { - return true; - } - - return false; -} - -bool LegalGridAnalysis::applyOverrides() { - // Lookup layout overrides based on location information for current - // operation. 
- // - - if (not analysisInput.outputLayoutOverrides) { - return false; - } - - if (not isa(op->getLoc())) { - return false; - } - - StringRef opLocName = mlir::cast(op->getLoc()).getName(); - auto gridOverride = analysisInput.outputLayoutOverrides->find(opLocName); - - if (gridOverride == analysisInput.outputLayoutOverrides->end()) { - return false; - } - - OutputLayoutOverrideParams override = gridOverride->getValue(); - RankedTensorType tensorType = - mlir::cast(op->getResult(0).getType()); - TTNNLayoutAttr layout = mlir::cast(tensorType.getEncoding()); - - GridAttr grid = - GridAttr::get(op->getContext(), ArrayRef(override.grid)); - - // Create element type for the new layout. - Type elementType = - utils::createRowMajorTypeFromDtype(op->getContext(), override.dataType); - if (override.memoryLayout == Layout::Tile) { - elementType = TileType::get(op->getContext(), elementType); - } - - analysisResult.push_back( - layout.withGrid(op->getContext(), tensorType, grid) - .withBufferType(op->getContext(), override.bufferType) - .withMemoryLayout(op->getContext(), override.tensorMemoryLayout) - .withElementType(op->getContext(), elementType)); - - return true; -} - -void LegalGridAnalysis::analysisImplementation() { - // A first incomplete implementation of the LegalGridAnalysis. - // This implementation is a placeholder and is meant to just enable testing of - // other components. - - // Skip operations that don't have output tensors. - if (op->getNumResults() == 0) { - return; - } - - if (!isa(op->getResult(0).getType())) { - return; - } - - if (llvm::isa(op)) { - return; - } - - // Get output tensor type. - RankedTensorType tensorType = - mlir::cast(op->getResult(0).getType()); - TTNNLayoutAttr layout = mlir::cast(tensorType.getEncoding()); - - // Return existing layout if it is not possible to change it. - if (cantChangeOutputLayout(op)) { - analysisResult.push_back(layout); - return; - } - - // DRAM - // No grid is set since the tensor is not sharded. - // TODO(odjuricic): We need to set grid here since it will be used as the - // compute gird. (not implemented in runtime atm) - TTNNLayoutAttr dram = - layout.withBufferType(op->getContext(), BufferType::DRAM) - .withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved) - .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), - analysisInput.maxGrid.getShape())); - if (mock_is_output_tensor_legal_for_op(op, dram)) { - analysisResult.push_back(dram); - } - - // L1 Interleaved (same as above). 
- TTNNLayoutAttr l1Interleaved = - layout.withBufferType(op->getContext(), BufferType::L1) - .withMemoryLayout(op->getContext(), TensorMemoryLayout::Interleaved) - .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), - analysisInput.maxGrid.getShape())); - if (mock_is_output_tensor_legal_for_op(op, l1Interleaved)) { - analysisResult.push_back(l1Interleaved); - } - - // L1 Sharded - TTNNLayoutAttr shardedBase = - layout.withBufferType(op->getContext(), BufferType::L1); - std::vector<TTNNLayoutAttr> shardedResults; - - // Block Sharded - for (auto width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) { - for (auto height = 1; height <= analysisInput.maxGrid.getShape()[1]; - ++height) { - shardedResults.push_back( - shardedBase - .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), {width, height})) - .withMemoryLayout(op->getContext(), - TensorMemoryLayout::BlockSharded)); - } - } - - auto numCores = - analysisInput.maxGrid.getShape()[0] * analysisInput.maxGrid.getShape()[1]; - // Height Sharded - // TODO(odjuricic): Missing affine mapping to actual grid. Need to check with - // runtime implementation on what to produce here. - for (auto height = 1; height <= numCores; ++height) { - shardedResults.push_back( - shardedBase - .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), {height, 1})) - .withMemoryLayout(op->getContext(), - TensorMemoryLayout::HeightSharded)); - } - - // Width Sharded - for (auto width = 1; width <= numCores; ++width) { - shardedResults.push_back( - shardedBase - .withGrid(op->getContext(), tensorType, - GridAttr::get(op->getContext(), {1, width})) - .withMemoryLayout(op->getContext(), - TensorMemoryLayout::WidthSharded)); - } - - // Filter layouts based on output tensor legality for current op. - shardedResults.erase( - std::remove_if(shardedResults.begin(), shardedResults.end(), - [this](TTNNLayoutAttr layout) { - return !tensor_shape_compatible_with_shard(op, layout) || - !mock_is_output_tensor_legal_for_op(op, layout); - }), - shardedResults.end()); - - // Pick top largest sharded grids. - std::sort(shardedResults.begin(), shardedResults.end(), - [](TTNNLayoutAttr a, TTNNLayoutAttr b) { - return a.getGrid().getGridVolume() > b.getGrid().getGridVolume(); - }); - - analysisResult.insert( - analysisResult.end(), shardedResults.begin(), - shardedResults.begin() + - std::min(analysisInput.maxShardedGrids, - static_cast<int64_t>(shardedResults.size()))); -} -} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp new file mode 100644 index 0000000000..3f4ef25ab2 --- /dev/null +++ b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp @@ -0,0 +1,321 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Analysis/LegalLayoutAnalysis.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNN.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" + +namespace mlir::tt::ttnn { + +bool mockIsOutputTensorLegalForOp(Operation *op, TTNNLayoutAttr layout) { + // Placeholder, needs to be replaced with a call to the TTNN op interface.
+ return true; +} + +bool tensorShapeCompatibleWithShard(Operation *op, TTNNLayoutAttr layout) { + // These constraints are implemented separately in every TTNN op. + // Almost nothing seems to be shared between EVERY op, so it is hard to have + // any logic here without the risk of discarding a valid configuration or + // modeling the constraint for each op. + + // For now just check if we have enough tiles to shard the tensor to the + // desired grid. This is a safe heuristic that should be valid for all ops. + + if (not layout.hasShardedTensorMemoryLayout()) { + return true; + } + + if (layout.isTiled()) { + RankedTensorType tensorType = + mlir::cast<RankedTensorType>(op->getResult(0).getType()); + llvm::ArrayRef<int64_t> tensorShape = tensorType.getShape(); + llvm::SmallVector<int64_t> tiledShape = + layout.getTiledShape(tensorShape); + llvm::ArrayRef<int64_t> gridShape = layout.getGrid().getShape(); + + assert(tiledShape.size() == gridShape.size() && + "Tiled tensor shape and grid shape must have the same rank"); + + for (size_t i = 0; i < tiledShape.size(); i++) { + // We need to have at least as many tiles as the grid size. + // Could also experiment with tiledShape[i] % gridShape[i] == 0, but need + // more context. + if (tiledShape[i] < gridShape[i]) { + return false; + } + } + return true; + } + + // TODO(odjuricic): For row major there are no constraints on how the tensor + // can be sharded. We need some kind of a heuristic to reduce the search + // space. + return true; +} + +bool cantChangeOutputLayout(Operation *op) { + // Check if OP belongs to TTNN dialect. + // + if (!isa(op->getDialect())) { + return true; + } + + if (llvm::isa(op)) { + return true; + } + + if (llvm::isa(op)) { + return true; + } + + return false; +} + +bool LegalLayoutAnalysis::applyOverrides() { + // Lookup layout overrides based on location information for current + // operation. + // + + if (not analysisInput.outputLayoutOverrides) { + return false; + } + + if (not isa<NameLoc>(op->getLoc())) { + return false; + } + + StringRef opLocName = mlir::cast<NameLoc>(op->getLoc()).getName(); + auto overrideIt = analysisInput.outputLayoutOverrides->find(opLocName); + + if (overrideIt == analysisInput.outputLayoutOverrides->end()) { + return false; + } + + OutputLayoutOverrideParams layoutOverride = overrideIt->getValue(); + + // If all layout parameters are set (except data type), we can skip analysis + // and create the overridden layout. Otherwise, we need to perform analysis + // and apply partial overrides. + if (not layoutOverride.fullLayoutOverride()) { + return false; + } + + RankedTensorType tensorType = + mlir::cast<RankedTensorType>(op->getResult(0).getType()); + TTNNLayoutAttr layout = mlir::cast<TTNNLayoutAttr>(tensorType.getEncoding()); + llvm::ArrayRef<int64_t> tensorShape = tensorType.getShape(); + + GridAttr grid = GridAttr::get(op->getContext(), + ArrayRef<int64_t>(layoutOverride.grid.value())); + + // Create element type for the new layout.
+ Type elementType = layout.getScalarElementType(); + if (layoutOverride.dataType.has_value()) { + elementType = utils::createRowMajorTypeFromDtype( + op->getContext(), layoutOverride.dataType.value()); + } + + if (layoutOverride.memoryLayout == Layout::Tile) { + elementType = TileType::get(op->getContext(), elementType); + } + + analysisResult.push_back(TTNNLayoutAttr::get( + op->getContext(), tensorShape, elementType, + layoutOverride.bufferType.value(), grid, + TensorMemoryLayoutAttr::get(op->getContext(), + layoutOverride.tensorMemoryLayout.value()))); + + return true; +} + +bool incompatibleWithOverride( + const TTNNLayoutAttr &layout, + const std::optional<OutputLayoutOverrideParams> &layoutOverride) { + if (not layoutOverride.has_value()) { + return false; + } + + if (layoutOverride->grid.has_value()) { + if (layout.getGrid().getShape()[0] != layoutOverride->grid.value()[0] || + layout.getGrid().getShape()[1] != layoutOverride->grid.value()[1]) { + return true; + } + } + if (layoutOverride->bufferType.has_value() && + layout.getBufferType() != layoutOverride->bufferType.value()) { + return true; + } + if (layoutOverride->tensorMemoryLayout.has_value() && + layout.getMemLayout().getValue() != + layoutOverride->tensorMemoryLayout.value()) { + return true; + } + if (layoutOverride->memoryLayout.has_value() && + layout.isTiled() != + (layoutOverride->memoryLayout.value() == Layout::Tile)) { + return true; + } + return false; +} + +void LegalLayoutAnalysis::analysisImplementation() { + // Skip operations that don't have output tensors. + if (op->getNumResults() == 0) { + return; + } + + if (!isa<RankedTensorType>(op->getResult(0).getType())) { + return; + } + + if (llvm::isa(op)) { + return; + } + + // Get output tensor type. + RankedTensorType tensorType = + mlir::cast<RankedTensorType>(op->getResult(0).getType()); + llvm::ArrayRef<int64_t> tensorShape = tensorType.getShape(); + TTNNLayoutAttr layout = mlir::cast<TTNNLayoutAttr>(tensorType.getEncoding()); + + // Return existing layout if it is not possible to change it. + if (cantChangeOutputLayout(op)) { + analysisResult.push_back(layout); + return; + } + + Type scalarElementType = layout.getScalarElementType(); + + std::optional<OutputLayoutOverrideParams> override; + + // Check if we have an override for this op. + if (isa<NameLoc>(op->getLoc())) { + StringRef opLocName = mlir::cast<NameLoc>(op->getLoc()).getName(); + if (auto overrideIt = analysisInput.outputLayoutOverrides->find(opLocName); + overrideIt != analysisInput.outputLayoutOverrides->end()) { + override = overrideIt->getValue(); + if (override->dataType.has_value()) { + scalarElementType = {utils::createRowMajorTypeFromDtype( + op->getContext(), override->dataType.value())}; + } + } + } + + Type tileElementType = TileType::get(op->getContext(), scalarElementType); + std::vector<TTNNLayoutAttr> shardedResults; + + bool rowMajorAllowed = analysisInput.rowMajorEnabled; + if (override.has_value() && override->memoryLayout.has_value() && + override->memoryLayout.value() == Layout::RowMajor) { + // Force allow row major if override is set. + rowMajorAllowed = true; + } + + // Generate both TILE and ROW_MAJOR layouts. + for (Type elementType : {scalarElementType, tileElementType}) { + if (not rowMajorAllowed && elementType == scalarElementType) { + continue; + } + // DRAM + analysisResult.push_back(TTNNLayoutAttr::get( + op->getContext(), tensorShape, elementType, BufferType::DRAM, + analysisInput.maxGrid, + TensorMemoryLayoutAttr::get(op->getContext(), + TensorMemoryLayout::Interleaved))); + + // L1 Interleaved (same as above).
+ analysisResult.push_back(TTNNLayoutAttr::get( + op->getContext(), tensorShape, elementType, BufferType::L1, + analysisInput.maxGrid, + TensorMemoryLayoutAttr::get(op->getContext(), + TensorMemoryLayout::Interleaved))); + + // L1 Sharded + TTNNLayoutAttr shardedBase = + layout.withBufferType(op->getContext(), BufferType::L1) + .withElementType(op->getContext(), elementType); + + assert(analysisInput.maxGrid.getShape().size() == 2 && + "Max device grid is expected to be 2D."); + // Block Sharded + for (int width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) { + for (int height = 1; height <= analysisInput.maxGrid.getShape()[1]; + ++height) { + shardedResults.push_back( + shardedBase + .withGrid(op->getContext(), tensorType, + GridAttr::get(op->getContext(), {width, height})) + .withMemoryLayout(op->getContext(), + TensorMemoryLayout::BlockSharded)); + } + } + + int64_t numCores = analysisInput.maxGrid.getGridVolume(); + // Height Sharded + // TODO(odjuricic): Missing affine mapping to actual grid. Need to check + // with runtime implementation on what to produce here. + for (int height = 1; height <= numCores; ++height) { + shardedResults.push_back( + shardedBase + .withGrid(op->getContext(), tensorType, + GridAttr::get(op->getContext(), {height, 1})) + .withMemoryLayout(op->getContext(), + TensorMemoryLayout::HeightSharded)); + } + + // Width Sharded + for (int width = 1; width <= numCores; ++width) { + shardedResults.push_back( + shardedBase + .withGrid(op->getContext(), tensorType, + GridAttr::get(op->getContext(), {1, width})) + .withMemoryLayout(op->getContext(), + TensorMemoryLayout::WidthSharded)); + } + } + + // Filter layouts based on output tensor legality for current op. + shardedResults.erase( + std::remove_if(shardedResults.begin(), shardedResults.end(), + [this](TTNNLayoutAttr layout) { + return !tensorShapeCompatibleWithShard(op, layout) || + !mockIsOutputTensorLegalForOp(op, layout); + }), + shardedResults.end()); + + // Pick the largest sharded grids. + // This becomes a problem when we introduce row_major since an 8x8 tensor can + // be sharded onto an 8x8 grid. + std::sort(shardedResults.begin(), shardedResults.end(), + [](TTNNLayoutAttr a, TTNNLayoutAttr b) { + return a.getGrid().getGridVolume() > b.getGrid().getGridVolume(); + }); + + analysisResult.insert( + analysisResult.end(), shardedResults.begin(), + shardedResults.begin() + + std::min(analysisInput.maxShardedGrids, + static_cast<int64_t>(shardedResults.size()))); + + // Apply partial layout overrides. Remove layouts that conflict with at + // least one overridden param. + if (override.has_value()) { + auto shouldRemoveLayout = + std::bind(incompatibleWithOverride, std::placeholders::_1, override); + analysisResult.erase(std::remove_if(analysisResult.begin(), + analysisResult.end(), + shouldRemoveLayout), + analysisResult.end()); + } + + if (analysisResult.empty()) { + op->emitError("No legal layout found for the operation."); + assert(false && "At least one legal layout must be found."); + } +} +} // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Analysis/OpConfigAnalysis.cpp b/lib/Dialect/TTNN/Analysis/OpConfigAnalysis.cpp index d4a79d64e9..f10964f475 100644 --- a/lib/Dialect/TTNN/Analysis/OpConfigAnalysis.cpp +++ b/lib/Dialect/TTNN/Analysis/OpConfigAnalysis.cpp @@ -16,12 +16,10 @@ bool OpConfigAnalysis::applyOverrides() { void OpConfigAnalysis::analysisImplementation() { // Future entrypoint for picking optimal op config. - // Placeholder: pick the first legal grid.
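
The tile-count check in tensorShapeCompatibleWithShard above reduces to a ceiling division per dimension. A self-contained sketch, assuming the 32x32 tile size used elsewhere in the patch; ceilDiv and fitsGrid are illustrative helpers, not part of the change:

```cpp
#include <array>
#include <cstdint>

constexpr int64_t kTileDim = 32;

constexpr int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// A sharded layout survives filtering only if every tiled dimension offers at
// least as many tiles as the corresponding grid dimension.
constexpr bool fitsGrid(std::array<int64_t, 2> shape,
                        std::array<int64_t, 2> grid) {
  return ceilDiv(shape[0], kTileDim) >= grid[0] &&
         ceilDiv(shape[1], kTileDim) >= grid[1];
}

static_assert(fitsGrid({64, 64}, {2, 2}), "64x64 is 2x2 tiles");
static_assert(!fitsGrid({64, 64}, {3, 3}), "too few tiles for a 3x3 grid");
```
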
+ // Placeholder: pick the first legal layout. // - for (auto opGrids : analysisInput.legalGrids) { - if (not opGrids.second.empty()) { - analysisResult[opGrids.first] = opGrids.second[0]; - } + for (auto opLayouts : analysisInput.legalLayouts) { + analysisResult[opLayouts.first] = opLayouts.second[0]; } } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp index 3f6c88e2b1..43c5984ed9 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp @@ -132,6 +132,16 @@ mlir::Type TTNNLayoutAttr::getElementType() const { return getMemref().getElementType(); } +// If the element type is TileType, return the nested element type, i.e. +// FloatType/IntegerType. +mlir::Type TTNNLayoutAttr::getScalarElementType() const { + Type elementType = getElementType(); + if (mlir::isa<TileType>(elementType)) { + return mlir::cast<TileType>(elementType).getElementType(); + } + return elementType; +} + // Get scalar element type. // Example: memref<2x2xf32> -> f32 // Example: memref<2x2x!tt.tile<32x32xf32>> -> f32 diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index a7e25f0288..b2f257a7e4 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -48,6 +48,7 @@ void createTTNNPipelineAnalysisPasses( optimizerOptions.memoryLayoutAnalysisPolicy = options.memoryLayoutAnalysisPolicy; optimizerOptions.maxLegalLayouts = options.maxLegalLayouts; + optimizerOptions.rowMajorEnabled = options.rowMajorEnabled; pm.addPass(mlir::tt::ttnn::createTTNNOptimizer(optimizerOptions)); } } diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp index 51f731841a..9ada2dbb5d 100644 --- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp @@ -5,7 +5,7 @@ #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" -#include "ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h" +#include "ttmlir/Dialect/TTNN/Analysis/LegalLayoutAnalysis.h" #include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h" #include "ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" @@ -79,6 +79,7 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> { memReconfigEnabled = std::move(options.memReconfigEnabled); memoryLayoutAnalysisPolicy = std::move(options.memoryLayoutAnalysisPolicy); maxLegalLayouts = std::move(options.maxLegalLayouts); + rowMajorEnabled = std::move(options.rowMajorEnabled); } protected: @@ -111,9 +112,14 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> { llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)}; ::mlir::Pass::Option<int64_t> maxLegalLayouts{ *this, "max-legal-layouts", - ::llvm::cl::desc( - "Override maximum number of legal layouts for grid analysis."), + ::llvm::cl::desc("Override maximum number of sharded layouts for legal " + "layout analysis."), ::llvm::cl::init(64)}; + ::mlir::Pass::Option<bool> rowMajorEnabled{ + *this, "row-major-enabled", + ::llvm::cl::desc( + "Enable row major layout generation in legal layout analysis."), + ::llvm::cl::init(false)}; private: friend std::unique_ptr<::mlir::Pass> createTTNNOptimizer() { @@ -176,12 +182,12 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> { RankedTensorType tensorType = mlir::cast<RankedTensorType>(op->getResult(0).getType()); - LegalGridAnalysis legalGridAnalysis = -
getChildAnalysis<LegalGridAnalysis>(op); - legalGridAnalysis.init(LegalGridAnalysisInput(chipDesc, max_grid, - tensorType, maxLegalLayouts, - &overrideOutputLayout)); - legalLayouts[op] = legalGridAnalysis.getResult(); + LegalLayoutAnalysis legalLayoutAnalysis = + getChildAnalysis<LegalLayoutAnalysis>(op); + legalLayoutAnalysis.init(LegalLayoutAnalysisInput( + chipDesc, max_grid, tensorType, maxLegalLayouts, + &overrideOutputLayout, rowMajorEnabled)); + legalLayouts[op] = legalLayoutAnalysis.getResult(); }); llvm::DenseMap> opSchedule; diff --git a/lib/Dialect/TTNN/Utils/PassOverrides.cpp b/lib/Dialect/TTNN/Utils/PassOverrides.cpp index 9c8ef2be1f..ad59ea91cb 100644 --- a/lib/Dialect/TTNN/Utils/PassOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/PassOverrides.cpp @@ -6,23 +6,34 @@ namespace mlir::tt::ttnn { +namespace { +std::optional<SmallVector<int64_t, 2>> +parseGrid(StringRef param, char gridSeparator, llvm::cl::Option &opt) { + SmallVector<StringRef> gridParts; + param.split(gridParts, gridSeparator); + if (gridParts.size() == 2) { + int64_t gridX, gridY; + if (gridParts[0].getAsInteger(10, gridX) || + gridParts[1].getAsInteger(10, gridY)) { + opt.error("Invalid grid size: " + param); + return std::nullopt; + } + return SmallVector<int64_t, 2>{gridX, gridY}; + } + return std::nullopt; +} +} // namespace + bool OutputLayoutOverrideParser::parse( llvm::cl::Option &opt, StringRef argName, StringRef arg, llvm::StringMap<OutputLayoutOverrideParams> &value) { SmallVector<StringRef> opOverrideList; - constexpr size_t kMaxGridSize = 2; constexpr size_t kvPairSize = 2; - constexpr size_t kMaxLayoutOverrideParams = 5; constexpr size_t iOpName = 0; constexpr size_t iLayoutOverrideParams = 1; - constexpr size_t iGrid = 0; - constexpr size_t iMemorySpace = 1; - constexpr size_t iTensorMemoryLayout = 2; - constexpr size_t iMemoryLayout = 3; - constexpr size_t iDataType = 4; constexpr char opSeparator = ','; constexpr char opNameSeparator = '='; - constexpr char paramSepataor = ':'; + constexpr char paramSeparator = ':'; constexpr char gridSeparator = 'x'; arg.split(opOverrideList, opSeparator); @@ -34,66 +45,51 @@ bool OutputLayoutOverrideParser::parse( return true; } - SmallVector<StringRef> layoutParamParts; - // Split into layout parameters. + SmallVector<StringRef> layoutParamParts; opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, - paramSepataor); - if (layoutParamParts.size() != kMaxLayoutOverrideParams) { - opt.error("Invalid number of layout parameters: " + - std::to_string(layoutParamParts.size())); - return true; - } - - // Parse grid.
- SmallVector grid; - SmallVector gridParts; - layoutParamParts[iGrid].split(gridParts, gridSeparator); - for (const StringRef gridPart : gridParts) { - int64_t gridValue; - if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { - opt.error("Invalid grid size: " + gridPart); + paramSeparator); + + OutputLayoutOverrideParams params; + + for (const StringRef ¶m : layoutParamParts) { + if (auto grid = parseGrid(param, gridSeparator, opt)) { + if (params.grid.has_value()) { + opt.error("Multiple grid parameters provided: " + param); + return true; + } + params.grid = grid; + } else if (auto bufferType = symbolizeBufferType(param)) { + if (params.bufferType.has_value()) { + opt.error("Multiple buffer type parameters provided: " + param); + return true; + } + params.bufferType = bufferType; + } else if (auto tensorMemoryLayout = symbolizeTensorMemoryLayout(param)) { + if (params.tensorMemoryLayout.has_value()) { + opt.error("Multiple tensor memory layout parameters provided: " + + param); + return true; + } + params.tensorMemoryLayout = tensorMemoryLayout; + } else if (auto memoryLayout = mlir::tt::ttnn::symbolizeLayout(param)) { + if (params.memoryLayout.has_value()) { + opt.error("Multiple memory layout parameters provided: " + param); + return true; + } + params.memoryLayout = memoryLayout; + } else if (auto dataType = mlir::tt::DataTypeStringToEnum(param)) { + if (params.dataType.has_value()) { + opt.error("Multiple data type parameters provided: " + param); + return true; + } + params.dataType = dataType; + } else { + opt.error("Invalid layout parameter: " + param); return true; } - grid.push_back(gridValue); - } - - // Parse memory space. - std::optional bufferType = - symbolizeBufferType(layoutParamParts[iMemorySpace]); - if (!bufferType.has_value()) { - opt.error("Invalid memory space: " + layoutParamParts[iMemorySpace]); - return true; - } - - // Parse tensor memory layout. - std::optional tensorMemoryLayout = - symbolizeTensorMemoryLayout(layoutParamParts[iTensorMemoryLayout]); - if (!tensorMemoryLayout.has_value()) { - opt.error("Invalid tensor memory layout: " + - layoutParamParts[iTensorMemoryLayout]); - return true; } - // Parse memory layout. - std::optional memoryLayout = - mlir::tt::ttnn::symbolizeLayout(layoutParamParts[iMemoryLayout]); - if (!memoryLayout.has_value()) { - opt.error("Invalid memory layout: " + layoutParamParts[iMemoryLayout]); - return true; - } - - // Parse data type. - std::optional dataType = - mlir::tt::DataTypeStringToEnum(layoutParamParts[iDataType]); - if (!dataType.has_value()) { - opt.error("Invalid data type: " + layoutParamParts[iDataType]); - return true; - } - - // Set parsed op overrides. 
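
The order-independent parsing above hinges on classifying each ':'-separated token by value. A standalone sketch that mirrors, but does not reuse, the parseGrid helper; it assumes LLVM's StringRef API, and tryParseGrid is a hypothetical name:

```cpp
#include <cstdint>
#include <optional>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

// A token parses as a grid iff it is two base-10 integers joined by 'x'.
std::optional<llvm::SmallVector<int64_t, 2>> tryParseGrid(llvm::StringRef tok) {
  auto [lhs, rhs] = tok.split('x'); // "4x4" -> ("4", "4"); "l1" -> ("l1", "")
  int64_t x = 0, y = 0;
  // StringRef::getAsInteger returns true on failure.
  if (rhs.empty() || lhs.getAsInteger(10, x) || rhs.getAsInteger(10, y))
    return std::nullopt;
  return llvm::SmallVector<int64_t, 2>{x, y};
}
// tryParseGrid("4x4") yields {4, 4}; tryParseGrid("l1") yields nullopt, so the
// real parser would next try symbolizeBufferType("l1") and the other
// symbolizers in turn.
```
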
- value[opOverrideParts[iOpName]] = OutputLayoutOverrideParams{ - std::move(grid), bufferType.value(), tensorMemoryLayout.value(), - memoryLayout.value(), dataType.value()}; + value[opOverrideParts[iOpName]] = params; } return false; } @@ -105,21 +101,33 @@ std::string OutputLayoutOverrideParser::toString( for (const auto &entry : value) { res += std::string(entry.getKey()) + "="; const OutputLayoutOverrideParams &params = entry.getValue(); + // Print grid values - for (size_t i = 0; i < params.grid.size(); ++i) { - res += std::to_string(params.grid[i]); - if (i < params.grid.size() - 1) { - res += "x"; + if (params.grid.has_value()) { + for (size_t i = 0; i < params.grid.value().size(); ++i) { + res += std::to_string(params.grid.value()[i]); + if (i < params.grid.value().size() - 1) { + res += "x"; + } } } // Print memory space and memory layout - res += ":" + - std::string(mlir::tt::ttnn::stringifyBufferType(params.bufferType)); - res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( - params.tensorMemoryLayout)); - res += - ":" + std::string(mlir::tt::ttnn::stringifyLayout(params.memoryLayout)); - res += ":" + std::string(mlir::tt::DataTypeEnumToString(params.dataType)); + if (params.bufferType.has_value()) { + res += ":" + std::string(mlir::tt::ttnn::stringifyBufferType( + params.bufferType.value())); + } + if (params.tensorMemoryLayout.has_value()) { + res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( + params.tensorMemoryLayout.value())); + } + if (params.memoryLayout.has_value()) { + res += ":" + std::string(mlir::tt::ttnn::stringifyLayout( + params.memoryLayout.value())); + } + if (params.dataType.has_value()) { + res += ":" + std::string( + mlir::tt::DataTypeEnumToString(params.dataType.value())); + } if (++count < value.size()) { res += ","; } diff --git a/python/OptimizerOverrides.cpp b/python/OptimizerOverrides.cpp index b41d2081d2..bd5ce94f43 100644 --- a/python/OptimizerOverrides.cpp +++ b/python/OptimizerOverrides.cpp @@ -131,13 +131,13 @@ void populateOptimizerOverridesModule(py::module &m) { "grid", [](const mlir::tt::ttnn::OutputLayoutOverrideParams &obj) { // Getter: Convert SmallVector to std::vector - return std::vector<int64_t>(obj.grid.begin(), obj.grid.end()); + // Guard the optional: return an empty vector when grid is unset. + return obj.grid.has_value() + ? std::vector<int64_t>(obj.grid->begin(), obj.grid->end()) + : std::vector<int64_t>(); }, [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj, const std::vector<int64_t> &input) { // Setter: Convert std::vector to SmallVector - obj.grid.clear(); - obj.grid.append(input.begin(), input.end()); + // Assign a fresh value so an unset optional is initialized rather + // than dereferenced. + obj.grid = llvm::SmallVector<int64_t, 2>(input.begin(), input.end()); }) .def_readwrite("buffer_type", &mlir::tt::ttnn::OutputLayoutOverrideParams::bufferType) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc_input_layout_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir similarity index 86% rename from test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc_input_layout_override.mlir rename to test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir index b492a54c13..ec03a6ad59 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc_input_layout_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir @@ -4,13 +4,14 @@ module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: #[[LAYOUT_1:.*]] = 
#ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #l1_>, > + // CHECK-DAG: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) - // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> + // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6) - // CHECK: %{{.*}} = "ttnn.to_memory_config"(%[[C]]) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> + // CHECK: %[[IDX:.*]] = "ttnn.to_memory_config"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> + // CHECK: %{{.*}} = "ttnn.add"(%[[IDX]]{{.*}} %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7) %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir index 70ebaddb8d..5c34fe8548 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir @@ -3,25 +3,24 @@ module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<64x96xbf16>, %arg3: tensor<96x32xbf16>, %arg4: tensor<64x32xbf16>) -> tensor<64x32xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > - // CHECK: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > + // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > %0 = tensor.empty() : tensor<64x96xbf16> - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_7]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_L1]]> %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %2 = tensor.empty() : tensor<64x96xbf16> - // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_7]]> + // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_L1]]> %3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %4 = tensor.empty() : tensor<64x96xbf16> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_7]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_L1]]> %5 = "ttir.relu"(%3, 
%4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %6 = tensor.empty() : tensor<64x32xbf16> - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_10]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> %7 = "ttir.matmul"(%5, %arg3, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<96x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %8 = tensor.empty() : tensor<64x32xbf16> - // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_10]]> + // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %10 = tensor.empty() : tensor<64x32xbf16> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_10]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> return %11 : tensor<64x32xbf16> } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir index ca0ec90e6f..67c480d8c9 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir @@ -22,9 +22,8 @@ module attributes {} { func.func @forward(%arg0: tensor<64x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<64x32xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x8xbf16, #dram>, > - // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x4xbf16, #l1_>, > - // CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x8xbf16, #l1_>, > + // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #dram>, > + // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_3]]> %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> @@ -37,7 +36,7 @@ module attributes {} { %8 = tensor.empty() : tensor<64x32xbf16> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> - // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_6]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_5]]> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> %9 = "ttir.matmul"(%3, %7, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> return %9 : tensor<64x32xbf16> diff --git 
a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir index a4cee76569..f45c11c624 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir @@ -3,26 +3,24 @@ #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { - // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > - // CHECK: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > + // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_6]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_L1]]> %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9) - // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_6]]> + // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_L1]]> %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10) - // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_6]]> + // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_L1]]> %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11) - // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_7]]> + // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12) - // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_7]]> + // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) - // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_7]]> + // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, 
tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) return %11 : tensor<1x10xf32> loc(#loc7) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir index 74a2dc55c7..e5a4f3fa66 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir @@ -13,7 +13,7 @@ #any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<8192x8192xbf16>, %arg1: tensor<8192x8192xbf16>, %arg2: tensor<8192x8192xbf16>, %arg3: tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> { - // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<1024x1024xbf16, #dram>, + // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<{{.*}}>, #dram>, > %0 = tensor.empty() : tensor<8192x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir index 7b5f069640..ceca628400 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir @@ -14,14 +14,14 @@ module attributes {} { func.func @forward(%arg0: tensor<5120x4096xbf16>, %arg1: tensor<5120x4096xbf16>, %arg2: tensor<4096x5120xbf16>, %arg3: tensor<4096x5120xbf16>) -> tensor<5120x5120xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<512x640xbf16, #dram>, - // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<640x512xbf16, #dram>, - // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<640x640xbf16, #l1_>, + // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<5120x4096xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_6]]> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_4]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> %2 = tensor.empty() : tensor<4096x5120xbf16> - // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_4]]> + // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_6]]> 
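
The tiled memref shapes in these CHECK lines follow from dividing each tensor dimension by the 32x32 tile size and then by the 8x8 grid. A compile-time spot check of that arithmetic, illustrative only:

```cpp
// 5120x4096 bf16 on an 8x8 grid: (5120/32)/8 x (4096/32)/8 tiles per core,
// matching the memref<20x16x!tt.tile<32x32, bf16>> in LAYOUT_4 above.
static_assert(5120 / 32 / 8 == 20, "rows of tiles per core");
static_assert(4096 / 32 / 8 == 16, "cols of tiles per core");
// 5120x5120 likewise yields the 20x20 shards in LAYOUT_7.
```
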
%3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> %4 = tensor.empty() : tensor<5120x5120xbf16> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_7]]> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir index edc2182a73..74675e4e0b 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir @@ -14,8 +14,8 @@ module attributes {} { func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<4096x5120xbf16>, %arg2: tensor<5120x5120xbf16>, %arg3: tensor<5120x5120xbf16>) -> tensor<4096x5120xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<512x640xbf16, #dram>, - // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<640x640xbf16, #l1_>, + // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<4096x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir index b5715b5a13..c3cd2740bc 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir @@ -14,8 +14,8 @@ module attributes {} { func.func @forward(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x8192xbf16>, %arg3: tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type - // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<256x256xbf16, #dram>, - // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<256x1024xbf16, #l1_>, + // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > + // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<2048x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]> %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> diff --git 
index 43a2c1d8da..c9cd33f1c9 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir
@@ -14,8 +14,8 @@
 module attributes {} {
   func.func @forward(%arg0: tensor<5120x5120xbf16>, %arg1: tensor<5120x5120xbf16>, %arg2: tensor<5120x4096xbf16>, %arg3: tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
-    // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<640x512xbf16, #dram>,
-    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<640x640xbf16, #l1_>,
+    // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, >
+    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, >
     %0 = tensor.empty() : tensor<5120x5120xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir
index f32f6a5afe..760ea2b8a5 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir
@@ -14,8 +14,8 @@
 module attributes {} {
   func.func @forward(%arg0: tensor<8192x2048xbf16>, %arg1: tensor<8192x2048xbf16>, %arg2: tensor<2048x2048xbf16>, %arg3: tensor<2048x2048xbf16>) -> tensor<8192x2048xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
-    // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<256x256xbf16, #dram>,
-    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<1024x256xbf16, #l1_>,
+    // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, >
+    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, >
     %0 = tensor.empty() : tensor<8192x2048xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir
index 4c3358368e..5d95a6204a 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir
@@ -14,9 +14,9 @@
 module attributes {} {
   func.func @forward(%arg0: tensor<2048x8192xbf16>, %arg1: tensor<2048x8192xbf16>, %arg2: tensor<8192x2048xbf16>, %arg3: tensor<8192x2048xbf16>) -> tensor<2048x2048xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
-    // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<256x1024xbf16, #l1_>,
-    // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<1024x256xbf16, #l1_>,
-    // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<256x256xbf16, #dram>,
+    // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, >
+    // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, >
+    // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, >
     %0 = tensor.empty() : tensor<2048x8192xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_4]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir
index bf441ffbea..75b876dbf3 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir
@@ -14,7 +14,7 @@
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>, %arg3: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
-    // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<4x4xbf16, #l1_>,
+    // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, >
     %0 = tensor.empty() : tensor<32x32xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xbf16, #[[LAYOUT_2]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir
index 80c44648a7..8e25f97ca0 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir
@@ -3,15 +3,15 @@
 #loc = loc("test_ops.py:17_0_0":0:0)
 module attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
-    // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8>, memref<4x4xf32, #dram>, >
+    // CHECK: #[[LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, f32>, #dram>, >
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
     %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]>
     %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
     %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]>
     %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
     // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #ttnn_layout>, tensor<1x32x32xf32, #ttnn_layout>
     return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc_output_layout_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir
similarity index 86%
rename from test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc_output_layout_override.mlir
rename to test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir
index a43c21ab61..79bbae2753 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc_output_layout_override.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir
@@ -7,15 +7,15 @@ module attributes {} {
     // CHECK: #[[LAYOUT_0:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #system_memory>>
     // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <4x4>, memref<8x8xbf16, #dram>, >
     // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <4x4>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, >
-    // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8>, memref<4x4xf32, #dram>, >
+    // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<{{.*}} #dram>, >
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
     %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
     %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
     %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_3]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_3]]>
     %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
     // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #[[LAYOUT_0]]>, tensor<1x32x32xf32, #[[LAYOUT_0]]>
     return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/partial_output_layout_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/partial_output_layout_override.mlir
new file mode 100644
index 0000000000..c1e79c7a07
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/optimizer/partial_output_layout_override.mlir
@@ -0,0 +1,20 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true override-output-layout=add_1=row_major" %s | FileCheck %s
+#any_device = #tt.operand_constraint
+#loc = loc("test_ops.py:17_0_0":0:0)
+module attributes {} {
+  func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
+    // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<{{.*}}memref<4x4xf32{{.*}}
+    %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
+    %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
+    %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
+    %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
+    return %1, %3 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
+  } loc(#loc)
+} loc(#loc)
+#loc1 = loc("test_ops.py:17_0_0":0:4)
+#loc2 = loc("test_ops.py:17_0_0":0:6)
+#loc3 = loc("test_ops.py:17_0_0":0:3)
+#loc4 = loc(unknown)
+#loc5 = loc("add_1"(#loc1))
+#loc6 = loc("add_2"(#loc2))
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/test_grid_set.mlir b/test/ttmlir/Dialect/TTNN/optimizer/test_grid_set.mlir
deleted file mode 100644
index 814cd0c459..0000000000
--- a/test/ttmlir/Dialect/TTNN/optimizer/test_grid_set.mlir
+++ /dev/null
@@ -1,20 +0,0 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttnn-optimizer %s | FileCheck %s
-#device = #tt.device (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
-#dram = #ttnn.buffer_type
-#system_memory = #ttnn.buffer_type
-#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system_memory>>
-#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, >
-#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, >
-module attributes {tt.device = #device} {
-  func.func @forward(%arg0: tensor<64x128xf32, #ttnn_layout>, %arg1: tensor<64x128xf32, #ttnn_layout>) -> tensor<64x128xf32, #ttnn_layout> {
-    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
-    %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<64x128>>, >}> : (tensor<64x128xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1>
-    %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<64x128>>, >}> : (tensor<64x128xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout1>
-    %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<64x128>>, >, shape = #ttnn.shape<64x128>}> : (!tt.device<#device>) -> tensor<64x128xf32, #ttnn_layout2>
-    %4 = "ttnn.multiply"(%1, %2, %3) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout1>, tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout2>
-    // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x16xf32, #dram>, >
-    // CHECK: %{{.+}} = "ttnn.multiply"{{.+}} -> tensor<64x128xf32, #[[LAYOUT_2]]>
-    %5 = "ttnn.to_layout"(%4) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<64x128>>>}> : (tensor<64x128xf32, #ttnn_layout2>) -> tensor<64x128xf32, #ttnn_layout>
-    return %5 : tensor<64x128xf32, #ttnn_layout>
-  }
-}
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir b/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir
index 16986408cd..08e6da1165 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/test_override_reshard_edges.mlir
@@ -7,8 +7,8 @@
 module attributes {tt.device = #device} {
   func.func @main(%arg0: tensor<1x32x32xf32, #ttnn_layout>, %arg1: tensor<1x32x32xf32, #ttnn_layout>, %arg2: tensor<1x32x32xf32, #ttnn_layout>) -> tensor<1x32x32xf32, #ttnn_layout> {
     // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, >
-    // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #l1_>, >
-    // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8>, memref<4x4xf32, #dram>, >
+    // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, >
+    // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, f32>, #dram>, >
     %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device>
     %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1>
     %2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<, <<32x32>>, >}> : (tensor<1x32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x32x32xf32, #ttnn_layout1>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir b/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir
index 725d7b83f6..c12fc07714 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir
@@ -2,10 +2,9 @@
 #any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8>, memref<8x16xf32, #dram>, >
-    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    // CHECK: %{{.*}} = "ttnn.empty"{{.*}}
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_2]]>
+    // CHECK: %{{.*}} = "ttnn.multiply"{{.*}}
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/mnist_sharding.mlir b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir
similarity index 73%
rename from test/ttmlir/Dialect/TTNN/optimizer/mnist_sharding.mlir
rename to test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir
index 55f3a60548..3cf9c45817 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/mnist_sharding.mlir
+++ b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir
@@ -1,24 +1,26 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true" %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 #any_device = #tt.operand_constraint
 #loc = loc("MNISTLinear":4294967295:0)
 module @"tt-forge-graph" attributes {} {
   func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> {
-    // CHECK: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x8>, memref<1x32xf32, #l1_>, >
-    // CHECK: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x10xf32, #l1_>, >
+    // CHECK-DAG: #[[LAYOUT_10:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x8>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, >
+    // CHECK-DAG: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, >
     %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8)
-    // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]>
+    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]>
     %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8)
     %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]>
     %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9)
     %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10)
-    // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]>
+    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]>
     %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10)
     %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11)
-    // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_11]]>
+    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_11]]>
     %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11)
     %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_11]]>
+    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_11]]>
     %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12)
     %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13)
     %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13)
diff --git a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding_tiled.mlir b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding_tiled.mlir
deleted file mode 100644
index cf5a5b9553..0000000000
--- a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding_tiled.mlir
+++ /dev/null
@@ -1,42 +0,0 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" %s > %t.mlir
-// RUN: FileCheck %s --input-file=%t.mlir
-// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device = #tt.operand_constraint
-#loc = loc("MNISTLinear":4294967295:0)
-module @"tt-forge-graph" attributes {} {
-  func.func @main(%arg0: tensor<32x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<32xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x32xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<32x32xf32> {
-    // CHECK-DAG: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x8>, memref<32x32xf32, #l1_>, >
-    // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<32x32xf32, #l1_>, >
-    %0 = tensor.empty() : tensor<32x256xf32> loc(#loc8)
-    // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
-    %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x784xf32>, tensor<784x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc8)
-    %2 = tensor.empty() : tensor<32x256xf32> loc(#loc9)
-    // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
-    %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x256xf32>, tensor<256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc9)
-    %4 = tensor.empty() : tensor<32x256xf32> loc(#loc10)
-    // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]>
%{{.*}} = "ttnn.relu"{{.*}} -> tensor<32x256xf32, #[[LAYOUT_1]]> - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x256xf32>, tensor<32x256xf32>) -> tensor<32x256xf32> loc(#loc10) - %6 = tensor.empty() : tensor<32x32xf32> loc(#loc11) - // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<32x32xf32, #[[LAYOUT_2]]> - %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x256xf32>, tensor<256x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc11) - %8 = tensor.empty() : tensor<32x32xf32> loc(#loc12) - // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xf32, #[[LAYOUT_2]]> - %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc12) - %10 = tensor.empty() : tensor<32x32xf32> loc(#loc13) - %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> loc(#loc13) - return %11 : tensor<32x32xf32> loc(#loc7) - } loc(#loc) -} loc(#loc) -#loc1 = loc("MNISTLinear":4294967295:10) -#loc2 = loc("MNISTLinear":4294967295:8) -#loc3 = loc("MNISTLinear":4294967295:6) -#loc4 = loc("MNISTLinear":4294967295:4) -#loc5 = loc("MNISTLinear":4294967295:3) -#loc6 = loc("MNISTLinear":4294967295:2) -#loc7 = loc(unknown) -#loc8 = loc("matmul_1"(#loc1)) -#loc9 = loc("add_2"(#loc2)) -#loc10 = loc("relu_3"(#loc3)) -#loc11 = loc("matmul_5"(#loc4)) -#loc12 = loc("add_6"(#loc5)) -#loc13 = loc("softmax_7"(#loc6)) diff --git a/test/unittests/Optimizer/TestOptimizerOverrides.cpp b/test/unittests/Optimizer/TestOptimizerOverrides.cpp index c75fde21f9..31118262f5 100644 --- a/test/unittests/Optimizer/TestOptimizerOverrides.cpp +++ b/test/unittests/Optimizer/TestOptimizerOverrides.cpp @@ -2,13 +2,128 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "llvm/Support/CommandLine.h" #include -#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" using namespace mlir::tt::ttnn; -class TestOptimizerOverrides : public ::testing::Test { +class OutputLayoutOverrideTest : public ::testing::Test { +protected: + llvm::cl::opt OverrideOutputLayoutOption{ + "override-output-layout"}; + OutputLayoutOverrideParser parser{OverrideOutputLayoutOption}; + llvm::StringMap parsedOverride; +}; + +TEST_F(OutputLayoutOverrideTest, ParseFullOutputLayoutOverride) { + std::string arg = "op1=2x2:dram:interleaved:tile:f32"; + + bool result = parser.parse(OverrideOutputLayoutOption, + "override-output-layout", arg, parsedOverride); + ASSERT_FALSE(result); + ASSERT_EQ(parsedOverride.size(), 1); + ASSERT_TRUE(parsedOverride.count("op1")); + + const auto ¶ms = parsedOverride["op1"]; + ASSERT_TRUE(params.grid.has_value()); + ASSERT_EQ(params.grid->size(), 2); + ASSERT_EQ((*params.grid)[0], 2); + ASSERT_EQ((*params.grid)[1], 2); + ASSERT_TRUE(params.bufferType.has_value()); + ASSERT_EQ(params.bufferType.value(), BufferType::DRAM); + ASSERT_TRUE(params.tensorMemoryLayout.has_value()); + ASSERT_EQ(params.tensorMemoryLayout.value(), TensorMemoryLayout::Interleaved); + ASSERT_TRUE(params.memoryLayout.has_value()); + ASSERT_EQ(params.memoryLayout.value(), Layout::Tile); + ASSERT_TRUE(params.dataType.has_value()); + 
+}
+
+TEST_F(OutputLayoutOverrideTest, ParsePartialOutputLayoutOverride) {
+  std::string arg = "op1=2x2:block_sharded";
+
+  bool result = parser.parse(OverrideOutputLayoutOption,
+                             "override-output-layout", arg, parsedOverride);
+  ASSERT_FALSE(result);
+  ASSERT_EQ(parsedOverride.size(), 1);
+  ASSERT_TRUE(parsedOverride.count("op1"));
+
+  const auto &params = parsedOverride["op1"];
+  ASSERT_TRUE(params.grid.has_value());
+  ASSERT_EQ(params.grid->size(), 2);
+  ASSERT_EQ((*params.grid)[0], 2);
+  ASSERT_EQ((*params.grid)[1], 2);
+  ASSERT_FALSE(params.bufferType.has_value());
+  ASSERT_TRUE(params.tensorMemoryLayout.has_value());
+  ASSERT_EQ(params.tensorMemoryLayout.value(),
+            TensorMemoryLayout::BlockSharded);
+  ASSERT_FALSE(params.memoryLayout.has_value());
+  ASSERT_FALSE(params.dataType.has_value());
+}
+
+TEST_F(OutputLayoutOverrideTest, ParseInvalidOutputLayoutOverride) {
+  std::string arg = "op1=invalid_value";
+
+  bool result = parser.parse(OverrideOutputLayoutOption,
+                             "override-output-layout", arg, parsedOverride);
+  ASSERT_TRUE(result);
+}
+
+TEST_F(OutputLayoutOverrideTest, ParseMultipleInstancesOfSameParameter) {
+  std::string arg = "op1=2x2:2x2";
+
+  bool result = parser.parse(OverrideOutputLayoutOption,
+                             "override-output-layout", arg, parsedOverride);
+  ASSERT_TRUE(result);
+}
+
+TEST_F(OutputLayoutOverrideTest, ParseMultipleOps) {
+  std::string arg = "op1=2x2:dram:interleaved:tile:f32,op2=4x4:l1:block_"
+                    "sharded:row_major:f16";
+
+  bool result = parser.parse(OverrideOutputLayoutOption,
+                             "override-output-layout", arg, parsedOverride);
+  ASSERT_FALSE(result);
+  ASSERT_EQ(parsedOverride.size(), 2);
+  ASSERT_TRUE(parsedOverride.count("op1"));
+  ASSERT_TRUE(parsedOverride.count("op2"));
+
+  const auto &params1 = parsedOverride["op1"];
+  ASSERT_TRUE(params1.grid.has_value());
+  ASSERT_EQ(params1.grid->size(), 2);
+  ASSERT_EQ((*params1.grid)[0], 2);
+  ASSERT_EQ((*params1.grid)[1], 2);
+  ASSERT_TRUE(params1.bufferType.has_value());
+  ASSERT_EQ(params1.bufferType.value(), BufferType::DRAM);
+  ASSERT_TRUE(params1.tensorMemoryLayout.has_value());
+  ASSERT_EQ(params1.tensorMemoryLayout.value(),
+            TensorMemoryLayout::Interleaved);
+  ASSERT_TRUE(params1.memoryLayout.has_value());
+  ASSERT_EQ(params1.memoryLayout.value(), Layout::Tile);
+  ASSERT_TRUE(params1.dataType.has_value());
+  ASSERT_EQ(params1.dataType.value(), mlir::tt::DataType::Float32);
+
+  const auto &params2 = parsedOverride["op2"];
+  ASSERT_TRUE(params2.grid.has_value());
+  ASSERT_EQ(params2.grid->size(), 2);
+  ASSERT_EQ((*params2.grid)[0], 4);
+  ASSERT_EQ((*params2.grid)[1], 4);
+  ASSERT_TRUE(params2.bufferType.has_value());
+  ASSERT_EQ(params2.bufferType.value(), BufferType::L1);
+  ASSERT_TRUE(params2.tensorMemoryLayout.has_value());
+  ASSERT_EQ(params2.tensorMemoryLayout.value(),
+            TensorMemoryLayout::BlockSharded);
+  ASSERT_TRUE(params2.memoryLayout.has_value());
+  ASSERT_EQ(params2.memoryLayout.value(), Layout::RowMajor);
+  ASSERT_TRUE(params2.dataType.has_value());
+  ASSERT_EQ(params2.dataType.value(), mlir::tt::DataType::Float16);
+}
+
+class TestOptimizerOverrideHandler : public ::testing::Test {
 public:
 
   OptimizerOverridesHandler optimizerOverridesHandler;
@@ -73,8 +188,7 @@ class TestOptimizerOverrides : public ::testing::Test {
     // - tensor memory layout interleaved
     // - memory layout tile
    // - data type fp16.
-    outputLayoutOverrideParams.grid.push_back(2);
-    outputLayoutOverrideParams.grid.push_back(2);
+    outputLayoutOverrideParams.grid = llvm::SmallVector({2, 2});
     outputLayoutOverrideParams.bufferType = BufferType::DRAM;
     outputLayoutOverrideParams.tensorMemoryLayout =
         TensorMemoryLayout::Interleaved;
@@ -102,8 +216,7 @@ class TestOptimizerOverrides : public ::testing::Test {
     // - tensor memory layout block_sharded
     // - memory layout row_major
     // - data type fp16.
-    outputLayoutOverrideParams.grid.push_back(8);
-    outputLayoutOverrideParams.grid.push_back(4);
+    outputLayoutOverrideParams.grid = llvm::SmallVector({8, 4});
     outputLayoutOverrideParams.bufferType = BufferType::L1;
     outputLayoutOverrideParams.tensorMemoryLayout =
         TensorMemoryLayout::BlockSharded;
@@ -131,8 +244,7 @@ class TestOptimizerOverrides : public ::testing::Test {
     // - tensor memory layout height_sharded
     // - memory layout tile
     // - data type fp16.
-    outputLayoutOverrideParams.grid.push_back(3);
-    outputLayoutOverrideParams.grid.push_back(6);
+    outputLayoutOverrideParams.grid = llvm::SmallVector({3, 6});
     outputLayoutOverrideParams.bufferType = BufferType::SystemMemory;
     outputLayoutOverrideParams.tensorMemoryLayout =
         TensorMemoryLayout::HeightSharded;
@@ -196,7 +308,7 @@
 };
 
 // Test the setEnableOptimizer method
-TEST_F(TestOptimizerOverrides, TestSetOptimizerPass) {
+TEST_F(TestOptimizerOverrideHandler, TestSetOptimizerPass) {
 
   optimizerOverridesHandler.setEnableOptimizer(true);
   ASSERT_TRUE(optimizerOverridesHandler.getEnableOptimizer());
@@ -206,7 +318,7 @@
 }
 
 // Test the setMemoryConfig method
-TEST_F(TestOptimizerOverrides, TestSetMemoryConfig) {
+TEST_F(TestOptimizerOverrideHandler, TestSetMemoryConfig) {
 
   optimizerOverridesHandler.setMemoryReconfig(true);
   ASSERT_TRUE(optimizerOverridesHandler.getMemoryReconfig());
@@ -216,7 +328,7 @@
 }
 
 // Test the setMemoryLayoutAnalysis method
-TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysis) {
+TEST_F(TestOptimizerOverrideHandler, TestSetMemoryLayoutAnalysis) {
 
   optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true);
   ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis());
@@ -226,7 +338,7 @@
 }
 
 // Test the setEnableMemoryLayoutAnalysisPolicy method
-TEST_F(TestOptimizerOverrides, TestSetEnableMemoryLayoutAnalysisPolicy) {
+TEST_F(TestOptimizerOverrideHandler, TestSetEnableMemoryLayoutAnalysisPolicy) {
 
   optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(true);
   ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy());
@@ -236,7 +348,7 @@
 }
 
 // Test the setMemoryLayoutAnalysisPolicy method
-TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysisPolicy) {
+TEST_F(TestOptimizerOverrideHandler, TestSetMemoryLayoutAnalysisPolicy) {
 
   optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy(
       mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding);
@@ -250,7 +362,7 @@
 }
 
 // Test the setInputLayoutOverrides method
-TEST_F(TestOptimizerOverrides, TestSetInputLayoutOverrides) {
+TEST_F(TestOptimizerOverrideHandler, TestSetInputLayoutOverrides) {
 
   llvm::StringMap inputLayoutOverrides =
       createInputLayoutOverrides();
@@ -262,7 +374,7 @@ TEST_F(TestOptimizerOverrides, TestSetInputLayoutOverrides) {
 }
 
 // Test the setOutputLayoutOverrides method
-TEST_F(TestOptimizerOverrides, TestSetOutputLayoutOverrides) {
+TEST_F(TestOptimizerOverrideHandler, TestSetOutputLayoutOverrides) {
 
   llvm::StringMap outputLayoutOverrides =
       createOutputLayoutOverrides();
@@ -274,7 +386,7 @@ TEST_F(TestOptimizerOverrides, TestSetOutputLayoutOverrides) {
 }
 
 // Test the addInputLayoutOverride method passing the whole object
-TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideObject) {
+TEST_F(TestOptimizerOverrideHandler, TestAddInputLayoutOverrideObject) {
 
   // This method is implemented across two functions in the
   // OptimizerOverridesHandler class. The first function takes the whole object
@@ -299,7 +411,7 @@ TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideObject) {
 }
 
 // Test the addInputLayoutOverride method passing the individual parameters
-TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideParams) {
+TEST_F(TestOptimizerOverrideHandler, TestAddInputLayoutOverrideParams) {
 
   // This method is implemented across two functions in the
   // OptimizerOverridesHandler class. The first function takes the whole object
@@ -324,7 +436,7 @@ TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideParams) {
 }
 
 // Test the addOutputLayoutOverride method passing the whole object
-TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideObject) {
+TEST_F(TestOptimizerOverrideHandler, TestAddOutputLayoutOverrideObject) {
 
   // This method is implemented across two functions in the
   // OptimizerOverridesHandler class. The first function takes the whole object
@@ -349,7 +461,7 @@ TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideObject) {
 }
 
 // Test the addOutputLayoutOverride method passing the individual parameters
-TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideParams) {
+TEST_F(TestOptimizerOverrideHandler, TestAddOutputLayoutOverrideParams) {
 
   // This method is implemented across two functions in the
   // OptimizerOverridesHandler class. The first function takes the whole object
@@ -381,21 +493,21 @@ TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideParams) {
 }
 
 // Test the setSystemDescPath method
-TEST_F(TestOptimizerOverrides, TestSetSystemDescPath) {
+TEST_F(TestOptimizerOverrideHandler, TestSetSystemDescPath) {
 
   optimizerOverridesHandler.setSystemDescPath("system_desc_path");
   ASSERT_EQ(optimizerOverridesHandler.getSystemDescPath(), "system_desc_path");
 }
 
 // Test the setMaxLegalLayouts method
-TEST_F(TestOptimizerOverrides, TestSetMaxLegalLayouts) {
+TEST_F(TestOptimizerOverrideHandler, TestSetMaxLegalLayouts) {
 
   optimizerOverridesHandler.setMaxLegalLayouts(10);
   ASSERT_EQ(optimizerOverridesHandler.getMaxLegalLayouts(), 10);
 }
 
 // Test the setMeshShape method
-TEST_F(TestOptimizerOverrides, TestSetMeshShape) {
+TEST_F(TestOptimizerOverrideHandler, TestSetMeshShape) {
 
   std::vector meshShape;
   meshShape.push_back(1);
@@ -407,7 +519,7 @@ TEST_F(TestOptimizerOverrides, TestSetMeshShape) {
 }
 
 // Test the toString method
-TEST_F(TestOptimizerOverrides, TestToString) {
+TEST_F(TestOptimizerOverrideHandler, TestToString) {
 
   std::string options;
   options +=
diff --git a/tools/explorer/test/run_tests.py b/tools/explorer/test/run_tests.py
index ceff14ae0a..91167e86a7 100644
--- a/tools/explorer/test/run_tests.py
+++ b/tools/explorer/test/run_tests.py
@@ -17,7 +17,7 @@
     "tools/explorer/test/models/*.mlir",
 ]
 
 TEST_EXECUTE_MODEL_PATHS = [
-    "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding_tiled.mlir",
+    "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
 ]
 
@@ -97,14 +97,14 @@ def test_execute_model(model_path):
 
 
 def test_execute_mnist_l1_interleaved():
     execute_command(
-        "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding_tiled.mlir",
+        "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
         {"optimizationPolicy": "L1 Interleaved"},
     )
 
 
 def test_execute_mnist_optimizer_disabled():
     execute_command(
-        "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding_tiled.mlir",
+        "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
         {"optimizationPolicy": "Optimizer Disabled"},
     )