From 36a886a97e069e3f6dbe9c4f8ac45a81e0c9afae Mon Sep 17 00:00:00 2001
From: Vlad Roubtsov
Date: Thu, 27 Feb 2025 23:00:56 +0000
Subject: [PATCH 1/6] simplify DPS inheritance using utils::getDpsOutputs()

---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.h  |   2 +
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 154 ++--------------------
 include/ttmlir/Utils.h                    |  43 ++++++
 3 files changed, 56 insertions(+), 143 deletions(-)

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h
index e876935500..62239a5000 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h
@@ -9,6 +9,8 @@
 #include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h"
 
+#include "ttmlir/Utils.h"
+
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 6f97a2dfcb..baa419c9cf 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -23,7 +23,8 @@ include "mlir/IR/OpBase.td"
 class TTIR_DPSOp<string mnemonic, list<Trait> traits = []> : TTIR_Op<mnemonic, !listconcat(traits, [DestinationStyleOpInterface])> {
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+    // base implementation that will detect getOutputMutable() or getOutputsMutable()
+    MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
   }];
 }
 
@@ -91,10 +92,7 @@ def TTIR_GenericOp : TTIR_BufferizableOp<"generic", [AttrSizedOperandSegments, N
       grid = #tt.grid<1x1>,  // The grid range of cores to dispatch work to.
       indexing_maps = [#map, #map, #map],  // Affine maps for indexing into the input/output tensors. See linalg.generic
       iterator_types = [#parallel, #parallel],  // Iterator types for the input/output tensors. See linalg.generic
-      operandSegmentSizes = array<i32: 2, 1, 1>,  // Sizes of the operand segments, i.e. 2 inputs, 1 cb and 1 output.
-      operand_cb_mapping = array<i32: -1, 0, -1>,  // Mapping of input & output operands to cbs. -1 means no mapping.
-                                                   // Mapped operands correspond to buffers in streaming mode.
-                                                   // Non-mapped operands correspond to buffers in alias mode.
+      operandSegmentSizes = array<i32: 2, 0, 1>  // Sizes of the operand segments, i.e. 2 inputs and 1 output.
     ({
     ^bb0(%arg2: tensor<64x128xf32, #tt.buffer, alias>>,
          %arg3: tensor<64x128xf32, #tt.buffer, stream>>,
@@ -105,29 +103,25 @@ def TTIR_GenericOp : TTIR_BufferizableOp<"generic", [AttrSizedOperandSegments, N
   }];
 
   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
-                       Variadic:$cbs,
+                       Variadic:$cbs,
                        Variadic<AnyRankedTensor>:$outputs,
                        TT_GridAttr:$grid,
                        AffineMapArrayAttr:$indexing_maps,
                        TT_IteratorTypeArrayAttr:$iterator_types,
-                       DefaultValuedOptionalAttr:$operand_cb_mapping); // index of input operand and index of cb go together
+                       DefaultValuedOptionalAttr:$operand_cb_mapping);
   let results = (outs Variadic<AnyRankedTensor>:$results);
 
   let regions = (region VariadicRegion<AnyRegion>:$regions);
 
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    // For a given block argument index, return the corresponding operand of the surrounding generic op.
-    // This is needed because extra CB operands may be present in between the inputs and outputs.
+    MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
+
     Value getMatchingOperand(size_t blockArgIndex) {
       assert(blockArgIndex < getInputs().size() + getOutputs().size() &&
              "blockArgIndex should be within the range of inputs and outputs");
       return blockArgIndex < getInputs().size() ?
         getOperand(blockArgIndex) : getOperand(blockArgIndex + getCbs().size());
     }
-
-    MutableOperandRange getDpsInitsMutable() {
-      return getOutputsMutable();
-    }
   }];
 }
 
@@ -153,7 +147,7 @@ def TTIR_ToLayoutOp : TTIR_BufferizableOp<"to_layout"> {
   let results = (outs Variadic<AnyRankedTensor>:$results);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+    MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
 
     struct CompoundComponents {
       bool isLayoutChange = false;
@@ -796,8 +790,7 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
   let results = (outs AnyRankedTensor:$result);
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-
+    MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
     void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);
 
     // Returns the indexing maps and iterator types for the reduction op.
@@ -984,10 +977,6 @@ def TTIR_EmbeddingOp : TTIR_NamedOp<"embedding"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1004,10 +993,6 @@ def TTIR_EmbeddingBackwardOp : TTIR_NamedOp<"embedding_backward"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1035,10 +1020,6 @@ def TTIR_CumSumOp : TTIR_NamedOp<"cumsum"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1054,10 +1035,6 @@ def TTIR_SoftmaxOp : TTIR_NamedOp<"softmax"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1074,10 +1051,6 @@ def TTIR_TransposeOp : TTIR_NamedOp<"transpose"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 
   let hasCanonicalizer = 1;
@@ -1095,10 +1068,6 @@ def TTIR_ConcatOp : TTIR_NamedOp<"concat"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1130,10 +1099,6 @@ def TTIR_RepeatOp : TTIR_NamedOp<"repeat"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1157,10 +1122,6 @@ def TTIR_RepeatInterleaveOp : TTIR_NamedOp<"repeat_interleave"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1223,10 +1184,6 @@ def TTIR_BroadcastOp : TTIR_NamedOp<"broadcast"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; let hasFolder = 1; @@ -1296,10 +1253,6 @@ def TTIR_Conv2dOp : TTIR_NamedOp<"conv2d"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } @@ -1361,10 +1314,6 @@ def TTIR_ConvTranspose2dOp : TTIR_NamedOp<"conv_transpose2d"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } @@ -1395,10 +1344,6 @@ def TTIR_ConvolutionOp : TTIR_NamedOp<"convolution"> { let results = (outs AnyRankedTensor); let hasVerifier = 1; - - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; } def TTIR_GatherOp: TTIR_NamedOp<"gather"> { @@ -1419,9 +1364,6 @@ def TTIR_GatherOp: TTIR_NamedOp<"gather"> { DenseI64ArrayAttr:$slice_sizes, BoolAttr:$indices_are_sorted); let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; } def TTIR_PoolingOp : TTIR_NamedOp<"pooling", [AttrSizedOperandSegments]> { @@ -1468,10 +1410,6 @@ def TTIR_MaxPool2dOp : TTIR_NamedOp<"max_pool2d"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } @@ -1487,10 +1425,6 @@ def TTIR_ReshapeOp: TTIR_NamedOp<"reshape"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; let hasFolder = 1; @@ -1547,10 +1481,6 @@ def TTIR_SliceOp: TTIR_NamedOp<"slice"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } @@ -1573,10 +1503,6 @@ def TTIR_SelectOp: TTIR_NamedOp<"select"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } @@ -1598,10 +1524,6 @@ def TTIR_IndexOp: TTIR_NamedOp<"index"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } // ANCHOR_END: decomposing_an_op_index_ttir @@ -1618,10 +1540,6 @@ def TTIR_SqueezeOp : TTIR_NamedOp<"squeeze"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } @@ -1637,10 +1555,6 @@ def TTIR_UnsqueezeOp : TTIR_NamedOp<"unsqueeze"> { let results = (outs AnyRankedTensor:$result); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let hasVerifier = 1; } @@ -1663,10 +1577,6 @@ def TTIR_ClampOp : TTIR_NamedOp<"clamp"> { F32Attr:$min, F32Attr:$max); - let extraClassDeclaration = [{ - MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } - }]; - let results = (outs AnyRankedTensor:$result); let hasVerifier = 1; @@ -1775,10 +1685,6 @@ def TTIR_ReverseOp : TTIR_NamedOp<"reverse", [AllShapesMatch<["input", "result"] let results = (outs 
AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 
   let hasCanonicalizer = 1;
 
@@ -1820,10 +1726,6 @@ def TTIR_FillOp : TTIR_NamedOp<"fill", [AllShapesMatch<["value", "result"]>]> {
                        ElementsAttr:$value);
 
   let results = (outs AnyRankedTensor:$result);
-
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
 }
 
 def TTIR_LinearOp : TTIR_NamedOp<"linear"> {
@@ -1848,10 +1750,6 @@
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 
   let hasCanonicalizer = 1;
 
@@ -1872,10 +1770,6 @@ def TTIR_MatmulOp : TTIR_NamedOp<"matmul"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 
   let hasCanonicalizer = 1;
 
@@ -1902,10 +1796,6 @@ def TTIR_PermuteOp : TTIR_NamedOp<"permute"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 
   let hasCanonicalizer = 1;
 
@@ -1934,10 +1824,6 @@ def TTIR_Upsample2dOp : TTIR_NamedOp<"upsample2d"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1949,8 +1835,7 @@
 class TTIR_GenericElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
     TTIR_ElementwiseUnaryOp<mnemonic, traits> {
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
-
+    MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
     void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);
 
     std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
@@ -1980,8 +1865,7 @@ class TTIR_GenericElementwiseBinaryOp<string mnemonic, list<Trait> traits = []>
     TTIR_ElementwiseBinaryOp<mnemonic, traits> {
 
   let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
-
+    MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
     void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);
 
     std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
@@ -2059,10 +1943,6 @@ def TTIR_ScatterOp: TTIR_NamedOp<"scatter"> {
   let results = (outs AnyRankedTensor:$result);
 
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -2100,10 +1980,6 @@ def TTIR_AllGatherOp : TTIR_NamedOp<"all_gather"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -2120,10 +1996,6 @@ def TTIR_AllReduceOp : TTIR_NamedOp<"all_reduce"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -2189,10 +2061,6 @@ def TTIR_MeshShardOp : TTIR_NamedOp<"mesh_shard"> {
 
   let results = (outs AnyRankedTensor:$result);
 
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
-  }];
-
   let hasVerifier = 1;
 }
 
diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h
index c031b1eb67..cd28d94e25 100644
--- a/include/ttmlir/Utils.h
+++ b/include/ttmlir/Utils.h
@@ -459,6 +459,49 @@ OpTy replaceOpWithNewDPSOp(mlir::PatternRewriter &rewriter, mlir::Operation *op,
   return newOp;
 }
 
+// detect the presence of 'getOutputsMutable()' in 'Op':
+template <typename Op, typename = void>
+inline constexpr bool has_variadic_outputs = false;
+
+template <typename Op>
+inline constexpr bool has_variadic_outputs<
+    Op, std::void_t<decltype(std::declval<Op>().getOutputsMutable())>> = true;
+
+namespace impl {
+
+template <typename Op, typename = void>
+struct getDpsOutputs {
+  static mlir::MutableOperandRange evaluate(Op *op) {
+    return op->getOutputMutable();
+  }
+};
+
+template <typename Op>
+struct getDpsOutputs<Op, std::enable_if_t<has_variadic_outputs<Op>>> {
+  static mlir::MutableOperandRange evaluate(Op *op) {
+    return op->getOutputsMutable();
+  }
+};
+
+} // namespace impl
+
+// A helper for simplifying DPS tablegen derivations with 'arguments' of any
+// form in {AnyRankedTensor:$output, Variadic<AnyRankedTensor>:$outputs}.
+//
+// If a base tablegen 'class' adds this extra class declaration, derived 'def's
+// don't need to override it just to switch from single to variadic type of
+// '$outputs' (or vice versa):
+// ...
+// clang-format off
+//   let extraClassDeclaration = [{
+//     MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
+//   }]
+// clang-format on
+template <typename Op>
+mlir::MutableOperandRange getDpsOutputs(Op *op) {
+  return impl::getDpsOutputs<Op>::evaluate(op);
+}
+
 } // namespace ttmlir::utils
 
 #endif

From ffdd6271450749815d094d783928d6ec771e0e0d Mon Sep 17 00:00:00 2001
From: Vlad Roubtsov
Date: Wed, 5 Mar 2025 19:07:56 +0000
Subject: [PATCH 2/6] rm CB operands from ttir.generic

---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td     |   11 +-
 .../TTIRToTTMetal/TTIRToTTMetal.cpp           | 1310 -----------------
 lib/Dialect/TTIR/IR/TTIROps.cpp               |   12 +-
 .../TTIR/bufferization/bufferization.mlir     |    2 +-
 .../TTIR/bufferization/memory_effects.mlir    |    4 +-
 .../TTIR/generic/generic_negative.mlir        |   10 +-
 .../TTIR/generic/generic_region_ops.mlir      |    2 +-
 .../generic/generic_region_ops_negative.mlir  |    7 +-
 .../Dialect/TTIR/loops/linearize_memref.mlir  |    4 +-
 .../Dialect/TTIR/ttir_generic_hoist.mlir      |    3 +-
 10 files changed, 19 insertions(+), 1346 deletions(-)

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index baa419c9cf..9cffa8bfa1 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -103,25 +103,16 @@ def TTIR_GenericOp : TTIR_BufferizableOp<"generic", [AttrSizedOperandSegments, N
   }];
 
   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
-                       Variadic:$cbs,
                        Variadic<AnyRankedTensor>:$outputs,
                        TT_GridAttr:$grid,
                        AffineMapArrayAttr:$indexing_maps,
-                       TT_IteratorTypeArrayAttr:$iterator_types,
-                       DefaultValuedOptionalAttr:$operand_cb_mapping);
+                       TT_IteratorTypeArrayAttr:$iterator_types);
   let results = (outs Variadic<AnyRankedTensor>:$results);
 
   let regions = (region VariadicRegion<AnyRegion>:$regions);
 
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
-
-    Value getMatchingOperand(size_t blockArgIndex) {
-      assert(blockArgIndex < getInputs().size() + getOutputs().size() &&
-             "blockArgIndex should be within the range of inputs and outputs");
-      return blockArgIndex < getInputs().size() ?
- getOperand(blockArgIndex) : getOperand(blockArgIndex + getCbs().size()); - } }]; } diff --git a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp index ec63ef8f55..e118e6ec96 100644 --- a/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp +++ b/lib/Conversion/TTIRToTTMetal/TTIRToTTMetal.cpp @@ -544,1315 +544,6 @@ class TTIRToTTMetalLayoutRewriter : public OpRewritePattern { }; } // namespace -namespace { -class TTIRToTTMetalEnqueueProgramRewriter - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - bool hasUnloweredTTIRKernel(ttir::GenericOp op) const { - bool exists = false; - op->getRegion(0).walk([&exists](Operation *op) { - if (isa(op)) { - exists = true; - } - }); - return exists; - } - - ttkernel::CBPort getPort(unsigned argNumber, - std::int64_t numDPSInputs) const { - std::int64_t operandInOutPartition = numDPSInputs; - std::uint32_t portIdx = 0; - if (argNumber < static_cast(operandInOutPartition)) { - assert(argNumber < 8 && "Exceeds max 8 input ports"); - portIdx = ttmlir::utils::enum_as_int(ttkernel::CBPort::In0) + argNumber; - } else { - assert((argNumber - operandInOutPartition) < 8 && - "Exceeds max 8 output ports"); - portIdx = ttmlir::utils::enum_as_int(ttkernel::CBPort::Out0) + - (argNumber - operandInOutPartition); - } - std::optional maybePort = - ttkernel::symbolizeCBPort(portIdx); - assert(maybePort.has_value() && "Expected legal port value"); - return maybePort.value(); - } - - // Returns permutation of the memref affine map such that reduced dims are - // placed at the end. - AffineMap getPermutedAffineMap(AffineMap map, - ArrayRef reductionDims) const { - SmallVector permutation; - for (size_t i = 0; i < reductionDims.size(); i++) { - if (!reductionDims[i]) { - permutation.push_back(i); - } - } - if (permutation.size() == map.getNumDims()) { - return map; - } - for (size_t i = 0; i < reductionDims.size(); i++) { - if (reductionDims[i]) { - permutation.push_back(i); - } - } - return map.getPermutationMap(permutation, map.getContext()); - } - - // This routine evaluates the memref's affine map with it's shape to return a - // single result affine map, e.g.: - // - Given: shape{2, 4} and affine_map<(d0, d1) -> (d0, d1)> - // - Becomes: affine_map<(d0, d1) -> (d0 * 4 + d1) - // This is useful for evaluating iterator increment steps between each loop. - AffineMap getAffineIterator(MemRefType memref, - ArrayRef reducedMemrefDims) const { - ArrayRef shape = memref.getShape(); - SmallVector physShape(shape); - AffineMap physShapeMap = getPermutedAffineMap( - memref.getLayout().getAffineMap(), reducedMemrefDims); - - mlir::AffineExpr resultExpr = getAffineConstantExpr(0, memref.getContext()); - int volume = 1; - for (int i = static_cast(physShape.size()) - 1; i >= 0; i--) { - mlir::AffineExpr dimExpr = physShapeMap.getResult(i); - mlir::AffineExpr strideExpr = - getAffineConstantExpr(volume, memref.getContext()); - resultExpr = dimExpr * strideExpr + resultExpr; - volume *= physShape[i]; - } - return AffineMap::get(physShape.size(), 0, resultExpr, memref.getContext()); - } - - Value i32(std::int32_t value, OpBuilder &builder) const { - return builder - .create(builder.getUnknownLoc(), - builder.getI32Type(), - builder.getI32IntegerAttr(value)) - .getResult(); - } - - struct LoopNest { - SmallVector loops; - SmallVector loopRegions; - SmallVector blockArgIteratorMapping; - }; - - // Creates a loop nest that walks the input/output operand tiles in the shard. 
- // Converts this: - // %0 = arith.add(%1, %2) : tensor<2x4x!tile>, tensor<2x4x!tile> - // -> tensor<2x4x!tile> - // Into this: - // for (%i0 = 0; %i0 < 2; %i0++) - // for (%i1 = 0; %i1 < 4; %i1++) - // %ii = %i0 * 4 + %i1 - // %3 = ttkernel.add_tiles(%1, %2, %ii, %ii) - LoopNest createLoopNest(ArrayRef blockArguments, - ArrayRef reducedMemrefDims, - std::int64_t numDPSInputs, OpBuilder &builder) const { - Value output = blockArguments[numDPSInputs]; - ttkernel::CBType outputTy = mlir::cast(output.getType()); - MemRefType outputMemref = outputTy.getMemref(); - AffineMap outputAffineMap = outputMemref.getLayout().getAffineMap(); - size_t mapRank = outputAffineMap.getNumDims(); - - // Uniquify the iterators, i.e. operands that have identical access pattern - // can be shared. - llvm::MapVector iteratorMaps; - auto getOrInsertIterator = [&iteratorMaps, &builder, - this](AffineMap affineIterator) { - if (iteratorMaps.find(affineIterator) == iteratorMaps.end()) { - iteratorMaps[affineIterator] = i32(0, builder); - } - return iteratorMaps[affineIterator]; - }; - - // Map block arguments to their respective unique iterators Values - SmallVector iterators; - iterators.resize(blockArguments.size()); - for (BlockArgument operand : blockArguments) { - auto cbType = mlir::cast(operand.getType()); - AffineMap affineIterator = - getAffineIterator(cbType.getMemref(), reducedMemrefDims); - - assert(affineIterator.getNumDims() == mapRank); - iterators[operand.getArgNumber()] = getOrInsertIterator(affineIterator); - } - - // Walking shape is the shape which should be traversed with loop nest. For - // eltwise ops, all operands have the same shape so we can just use the - // first operand. Reduce ops are kind of unary ops so we can use the first - // operand as well. - auto firstArg = - mlir::cast(blockArguments.front().getType()); - SmallVector walkingShape = - getPermutedAffineMap(firstArg.getMemref().getLayout().getAffineMap(), - reducedMemrefDims) - .compose(firstArg.getMemref().getShape()); - - // Map block arguments to their respective unique iterator offset in the - // map. This is needed by the caller to know how to wire the iterators into - // the ttkernel tile operation. - SmallVector blockArgIteratorMapping; - blockArgIteratorMapping.resize(blockArguments.size()); - for (BlockArgument operand : blockArguments) { - auto cbType = mlir::cast(operand.getType()); - AffineMap affineIterator = - getAffineIterator(cbType.getMemref(), reducedMemrefDims); - auto *match = iteratorMaps.find(affineIterator); - assert(match != iteratorMaps.end()); - blockArgIteratorMapping[operand.getArgNumber()] = - std::distance(iteratorMaps.begin(), match); - } - - // Convert the map data structure into a vector because it's easier to work - // with when creating the loop nest below. - SmallVector uniqueIterators; - for (auto [affineMap, iterator] : iteratorMaps) { - uniqueIterators.push_back(iterator); - } - - // Create loop nest - // The loop nest is created from outermost to innermost. The innermost loop - // is special in the sense that it implements the actual iterator increment - // and the tile operation. The outer loops are responsible for fixing up the - // iterator offset for the current dimension if there was a stride or we're - // accessing the tiles in non-row-major order. - // - // iterators are just ints that correspond to absolute offsets in the CB. - // They walk the order defined by the affine map associated with the memref. 
- LoopNest loopNest; - loopNest.blockArgIteratorMapping = blockArgIteratorMapping; - SmallVector loops; - SmallVector loopRegions; - SmallVector> iteratorsNest = {uniqueIterators}; - for (unsigned dim = 0; dim < mapRank; ++dim) { - OpBuilder regionBuilder(builder); - if (!loopNest.loopRegions.empty()) { - regionBuilder = OpBuilder(loopNest.loopRegions.back()); - } - // Loop variables, these are decoupled from the iterators - Value lowerBound = i32(0, regionBuilder); - Value upperBound = i32(walkingShape[dim], regionBuilder); - Value loopStep = i32(1, regionBuilder); - scf::ForOp forOp = regionBuilder.create( - output.getLoc(), lowerBound, upperBound, loopStep, - iteratorsNest.back()); - loopNest.loops.push_back(forOp); - - SmallVector innerIndexStep(mapRank, 0); - innerIndexStep[dim] = 1; - bool innerLoop = dim == (mapRank - 1); - - if (innerLoop) { - OpBuilder innerLoopRegion(loopNest.loops.back().getRegion()); - SmallVector innerIndices; - int i = 0; - for (auto [affineMap, iterator] : iteratorMaps) { - // Calculate how far a single step in the inner dim is. - SmallVector innerOffset = - affineMap.compose(innerIndexStep); - assert(innerOffset.size() == 1); - innerIndices.push_back(innerLoopRegion.create( - output.getLoc(), forOp.getRegionIterArg(i), - i32(innerOffset[0], innerLoopRegion), - arith::IntegerOverflowFlagsAttr::get( - innerLoopRegion.getContext(), - arith::IntegerOverflowFlags::nsw))); - ++i; - } - innerLoopRegion.create(output.getLoc(), innerIndices); - } - - // Backpedal and adjust the iterator offset for the current dimension. - if (dim > 0) { - SmallVector outerIndices; - SmallVector outerIndexStep(mapRank, 0); - outerIndexStep[dim - 1] = 1; - int i = 0; - for (auto [affineMap, iterator] : iteratorMaps) { - // Calculate how far a single step in the inner dim is. - SmallVector innerOffset = - affineMap.compose(innerIndexStep); - assert(innerOffset.size() == 1); - // Calculate how far a single step in the outer dim is. - SmallVector outerOffset = - affineMap.compose(outerIndexStep); - assert(outerOffset.size() == 1); - // Multiply by the number of steps that the inner loop took. - // FIXME: test this for higher dims - std::int64_t offset = - outerOffset[0] - innerOffset[0] * walkingShape[dim]; - outerIndices.push_back(regionBuilder.create( - output.getLoc(), forOp.getResult(i), i32(offset, regionBuilder), - arith::IntegerOverflowFlagsAttr::get( - regionBuilder.getContext(), - arith::IntegerOverflowFlags::nsw))); - ++i; - } - regionBuilder.create(output.getLoc(), outerIndices); - } - - loopNest.loopRegions.push_back(&loopNest.loops.back().getRegion()); - iteratorsNest.emplace_back(forOp.getRegionIterArgs()); - } - - return loopNest; - } - - void convertInitUnaryOp(Operation &arithOrMathOp, - ArrayRef cbOperands, - OpBuilder &builder) const { - assert(cbOperands.size() == 2 && - "Expected one input and one output CB for unary op."); - - auto inCB = cbOperands[0]; - auto outCB = cbOperands[1]; - - // All unary ops have common init function and specialized init function. 
- builder.create(arithOrMathOp.getLoc(), inCB, - outCB); - - if (mlir::isa(arithOrMathOp)) { - builder.create(arithOrMathOp.getLoc()); - } else { - llvm_unreachable("Unhandled unary op init conversion."); - } - } - - void convertInitBinaryOp(Operation &arithOrMathOp, - ArrayRef cbOperands, - OpBuilder &builder) const { - assert(cbOperands.size() == 3 && - "Expected two input and one output CB for binary op."); - - auto inCB0 = cbOperands[0]; - auto inCB1 = cbOperands[1]; - auto outCB = cbOperands[2]; - - // All binary ops have common init function and specialized init function. - builder.create(arithOrMathOp.getLoc(), - inCB0, inCB1, outCB); - - if (mlir::isa(arithOrMathOp)) { - builder.create(arithOrMathOp.getLoc(), inCB0, - inCB1); - } else if (mlir::isa(arithOrMathOp)) { - builder.create(arithOrMathOp.getLoc(), inCB0, - inCB1); - } else if (mlir::isa(arithOrMathOp)) { - builder.create(arithOrMathOp.getLoc()); - } else if (mlir::isa(arithOrMathOp)) { - builder.create(arithOrMathOp.getLoc()); - } else { - llvm_unreachable("Unhandled binary op init conversion."); - } - } - - void convertInitReduceOp(Operation &reduceOp, ttkernel::ReduceDim reduceDim, - ArrayRef cbOperands, - OpBuilder &builder) const { - assert(cbOperands.size() == 3 && - "Expected two inputs and one output CB for reduce op."); - - auto kernelOp = mlir::cast(reduceOp); - assert(kernelOp.getOp() == "reduce"); - auto type = kernelOp.getKind() == "max" ? ttkernel::ReduceType::Max - : ttkernel::ReduceType::Sum; - builder.create( - reduceOp.getLoc(), cbOperands[0], cbOperands[2], cbOperands[1], - ttkernel::ReduceTypeAttr::get(builder.getContext(), type), - ttkernel::ReduceDimAttr::get(builder.getContext(), reduceDim)); - - // Wait for scaler to be ready. - auto one = i32(1, builder); - builder.create(reduceOp.getLoc(), cbOperands[2], - one); - } - - // Convert arith and math dialect operations into ttkernel init tile - // operations. HLK requires the FPU to be initialized before any tile ops get - // executed. We separate the init tile operation from the actual tile - // operation so that we can hoist the init tile operation outside of the loop - // nest. - void convertComputeInitOp(Operation &arithOrMathOp, - ArrayRef cbOperands, - std::int64_t numDpsInputs, - ttkernel::ReduceDim reduceDim, - OpBuilder &builder) const { - if (reduceDim != ttkernel::ReduceDim::None) { - convertInitReduceOp(arithOrMathOp, reduceDim, cbOperands, builder); - } else if (numDpsInputs == 1) { - convertInitUnaryOp(arithOrMathOp, cbOperands, builder); - } else if (numDpsInputs == 2) { - convertInitBinaryOp(arithOrMathOp, cbOperands, builder); - } else { - llvm_unreachable("Unhandled conversion for operation which is neither " - "unary nor binary nor reduce."); - } - } - - void convertComputeUnaryOp(Operation &arithOrMathOp, - ArrayRef cbOperands, - ArrayRef iterators, - SmallVector blockArgIteratorMapping, - OpBuilder &builder) const { - assert(cbOperands.size() == 2 && - "Expected one input and one output CB for unary op."); - - auto inCBTileIndex = iterators[blockArgIteratorMapping[0]]; - auto inCB = cbOperands[0]; - auto outCBTileIndex = iterators[blockArgIteratorMapping.back()]; - auto outCB = cbOperands.back(); - - auto location = arithOrMathOp.getLoc(); - - // We always operate on the first and only tile in DST register. - Value dstTileIndex = i32(0, builder); - - // MATH acquires lock on DST register. - builder.create(location); - - // For all unary ops first copy tile from input CB at inCBTileIndex to DST - // register at dstTileIndex. 
- builder.create(location, inCB); - builder.create(location, inCB, inCBTileIndex, - dstTileIndex); - - // Perform computation on tile in DST register on dstTileIndex (the only - // tile in DST). - if (mlir::isa(arithOrMathOp)) { - builder.create(location, dstTileIndex); - } else { - llvm_unreachable("Unhandled unary op compute conversion."); - } - - // MATH releases lock on DST. - builder.create(location); - - // PACK acquires lock on DST register. Blocked until MATH releases it. - builder.create(location); - - // Copy tile from DST at dstTileIndex to outCB at outCBTileIndex. - // outCBTileIndex increments as loops iterate, thus placing one result tile - // after another in outCB. - builder.create(location, dstTileIndex, outCB, - outCBTileIndex); - - // PACK releases lock on DST. - builder.create(location); - } - - void convertComputeBinaryOp(Operation &arithOrMathOp, - ArrayRef cbOperands, - ArrayRef iterators, - SmallVector blockArgIteratorMapping, - OpBuilder &builder) const { - assert(cbOperands.size() == 3 && - "Expected two input and one output CB for binary op."); - - // Perform computation C = A (*) B on tile A from cbOperands[0] and tile B - // from cbOperands[1] and store the result C in DST register on - // dstTileIndex. - if (mlir::isa(arithOrMathOp)) { - convertComputeBinaryFPUOp( - arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, - builder); - } else if (mlir::isa(arithOrMathOp)) { - commonComputeMulOp(arithOrMathOp, cbOperands, iterators, - blockArgIteratorMapping, builder); - } else if (mlir::isa(arithOrMathOp)) { - - SmallVector operandIndicesRecip; - // For DIV, input 1 is going through reciprocal. - operandIndicesRecip.push_back(1); - commonComputeRecipOp(arithOrMathOp, cbOperands, iterators, - blockArgIteratorMapping, builder, - operandIndicesRecip); - - auto inCB0 = cbOperands[0]; - auto inCB1 = cbOperands[1]; - auto location = arithOrMathOp.getLoc(); - - Value one = i32(1, builder); - builder.create(location, inCB1, one); - - builder.create(location, inCB0, inCB1); - - commonComputeMulOp(arithOrMathOp, cbOperands, iterators, - blockArgIteratorMapping, builder); - - builder.create(location, inCB1, one); - } else if (mlir::isa(arithOrMathOp)) { - convertComputeBinarySFPUOp( - arithOrMathOp, cbOperands, iterators, blockArgIteratorMapping, - builder); - } else { - llvm_unreachable("Unhandled conversion for operation which is neither " - "unary nor binary."); - } - } - - template - void convertComputeBinaryFPUOp( - Operation &arithOrMathOp, ArrayRef cbOperands, - ArrayRef iterators, - const SmallVector &blockArgIteratorMapping, - OpBuilder &builder) const { - auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; - auto inCB0 = cbOperands[0]; - auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; - auto inCB1 = cbOperands[1]; - auto outCB = cbOperands[2]; - auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; - - auto location = arithOrMathOp.getLoc(); - - Value dstIndex = i32(0, builder); - - // acquire DST register lock (MATH) - builder.create(location); - { - builder.create(location, inCB0, inCB1, inCB0TileIndex, - inCB1TileIndex, dstIndex); - } - builder.create(location); - // release DST register lock (MATH) - - // acquire DST register lock (PACK) - builder.create(location); - { - builder.create(location, dstIndex, outCB, - outCBTileIndex); - } - builder.create(location); - // release DST register lock (PACK) - } - - template - void convertComputeBinarySFPUOp( - Operation &arithOrMathOp, ArrayRef cbOperands, - ArrayRef iterators, - 
const SmallVector &blockArgIteratorMapping, - OpBuilder &builder) const { - auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; - auto inCB0 = cbOperands[0]; - auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; - auto inCB1 = cbOperands[1]; - auto outCB = cbOperands[2]; - auto outCBTileIndex = iterators[blockArgIteratorMapping[2]]; - - auto location = arithOrMathOp.getLoc(); - - Value dstLhsTileIndex = i32(0, builder); - Value dstRhsTileIndex = i32(1, builder); // note: rhs is always lhs+1 - - // acquire DST register lock (MATH) - builder.create(location); - { - // copy inCB0[inCB0TileIndex] and inCB1[inCB1TileIndex] to DST: - builder.create(location, inCB0); - builder.create(location, inCB0, inCB0TileIndex, - dstLhsTileIndex); - builder.create(location, inCB1); - builder.create(location, inCB1, inCB1TileIndex, - dstRhsTileIndex); - // SFPU operates on DST tiles: - builder.create(location, dstLhsTileIndex, - dstRhsTileIndex); - } - builder.create(location); - // release DST register lock (MATH) - - // acquire DST register lock (PACK) - builder.create(location); - { - builder.create(location, dstLhsTileIndex, outCB, - outCBTileIndex); - } - builder.create(location); - // release DST register lock (PACK) - } - - void commonComputeMulOp(Operation &op, ArrayRef cbOperands, - ArrayRef iterators, - SmallVector blockArgIteratorMapping, - OpBuilder &builder) const { - - auto inCB0 = cbOperands[0]; - auto inCB1 = cbOperands[1]; - auto outCB = cbOperands[2]; - auto inCB0TileIndex = iterators[blockArgIteratorMapping[0]]; - auto inCB1TileIndex = iterators[blockArgIteratorMapping[1]]; - - Value dstIndex = i32(0, builder); - - builder.create(op.getLoc()); - if (mlir::isa(op)) { - builder.create( - op.getLoc(), inCB0, inCB1, inCB0TileIndex, inCB1TileIndex, dstIndex); - } else if (mlir::isa(op)) { - // Source index for CB input 1 is 0(dstIndex), because of sync needed with - // recip. - builder.create(op.getLoc(), inCB0, inCB1, - inCB0TileIndex, dstIndex, dstIndex); - } else { - llvm_unreachable("Common compute for multiplying tiles should be called " - "only on MulFOp and DivFOp"); - } - - builder.create(op.getLoc()); - builder.create(op.getLoc()); - builder.create(op.getLoc(), dstIndex, outCB, - iterators[blockArgIteratorMapping[2]]); - builder.create(op.getLoc()); - } - - void commonComputeRecipOp(Operation &op, ArrayRef cbOperands, - ArrayRef iterators, - SmallVector blockArgIteratorMapping, - OpBuilder &builder, - SmallVector &operandIndices) const { - Value dstIndex = i32(0, builder); - Value one = i32(1, builder); - - auto inputCB = cbOperands[operandIndices[0]]; - auto outputCB = inputCB; - - builder.create(op.getLoc(), inputCB); - builder.create(op.getLoc(), inputCB, one); - builder.create(op.getLoc()); - builder.create(op.getLoc()); - builder.create(op.getLoc(), inputCB, dstIndex, - dstIndex); - builder.create(op.getLoc(), dstIndex); - builder.create(op.getLoc()); - - builder.create(op.getLoc()); - builder.create(op.getLoc(), dstIndex, outputCB, - dstIndex); - builder.create(op.getLoc()); - builder.create(op.getLoc(), outputCB, one); - } - - void convertComputeReduceOp(Block *computeBlock, Operation &op, - ArrayRef cbOperands, - ArrayRef iterators, - SmallVector blockArgIteratorMapping, - ttkernel::ReduceDim reduceDim, - LoopNest &loopNest) const { - assert(reduceDim != ttkernel::ReduceDim::None); - - auto kernelOp = mlir::cast(op); - assert(kernelOp.getOp() == "reduce"); - auto type = kernelOp.getKind() == "max" ? 
ttkernel::ReduceType::Max - : ttkernel::ReduceType::Sum; - OpBuilder mainBuilder(computeBlock, computeBlock->begin()); - mainBuilder.setInsertionPointToEnd(computeBlock); - - OpBuilder innerLoopBuilder(&loopNest.loopRegions.back()->front(), - loopNest.loopRegions.back()->front().begin()); - auto dstIndex = i32(0, innerLoopBuilder); - - innerLoopBuilder.create( - op.getLoc(), cbOperands[0], cbOperands[2], - iterators[blockArgIteratorMapping[0]], - iterators[blockArgIteratorMapping[0]], dstIndex, - ttkernel::ReduceTypeAttr::get(innerLoopBuilder.getContext(), type), - ttkernel::ReduceDimAttr::get(innerLoopBuilder.getContext(), reduceDim)); - - size_t numLoopRegions = loopNest.loopRegions.size(); - size_t numReducedDims = reduceDim == ttkernel::ReduceDim::Row || - reduceDim == ttkernel::ReduceDim::Col - ? 1 - : 2; - - Block *packingBlock = - numReducedDims == numLoopRegions - ? computeBlock - : &loopNest.loopRegions[numLoopRegions - 1 - numReducedDims] - ->getBlocks() - .front(); - OpBuilder packingBuilder(packingBlock, packingBlock->begin()); - - packingBuilder.create( - computeBlock->front().getLoc()); - - Value packSingleTile = i32(0, packingBuilder); - Value packingTileIndex = - numReducedDims == numLoopRegions - ? packSingleTile - : loopNest.loops[numLoopRegions - 1 - numReducedDims] - .getRegionIterArgs()[blockArgIteratorMapping.back()]; - - if (packingBlock->mightHaveTerminator()) { - packingBuilder.setInsertionPoint(packingBlock->getTerminator()); - } else { - packingBuilder.setInsertionPointToEnd(packingBlock); - } - - packingBuilder.create( - computeBlock->front().getLoc()); - packingBuilder.create( - computeBlock->front().getLoc()); - packingBuilder.create( - computeBlock->front().getLoc(), i32(0, packingBuilder), cbOperands[1], - packingTileIndex); - packingBuilder.create( - computeBlock->front().getLoc()); - } - - // Convert arith and math dialect operations into ttkernel tile operations. - // Here iterators are the block arguments from the innermost scf.for loop. - // The iterators are unique-ified so we need blockArgIteratorMapping to - // recover which top level tensor operand is associated with which iterator. - void convertComputeOp(Block *computeBlock, Operation &arithOrMathOp, - LoopNest &loopNest, ArrayRef cbOperands, - ArrayRef iterators, - ttkernel::ReduceDim reduceDim, - SmallVector blockArgIteratorMapping, - OpBuilder &innerLoopBuilder, - std::int64_t numDpsInputs) const { - - if (reduceDim != ttkernel::ReduceDim::None) { - convertComputeReduceOp(computeBlock, arithOrMathOp, cbOperands, iterators, - blockArgIteratorMapping, reduceDim, loopNest); - } else if (numDpsInputs == 1) { - convertComputeUnaryOp(arithOrMathOp, cbOperands, iterators, - blockArgIteratorMapping, innerLoopBuilder); - } else if (numDpsInputs == 2) { - convertComputeBinaryOp(arithOrMathOp, cbOperands, iterators, - blockArgIteratorMapping, innerLoopBuilder); - } else { - llvm_unreachable("Unhandled conversion for operation which is neither " - "unary nor binary."); - } - } - - // Builds instructions to execute before looping over tiles has started. - void buildInitSection(Operation &arithOrMathOp, OpBuilder &builder, - ArrayRef cbOperands, - ttkernel::ReduceDim reduceDim, - std::int64_t numDPSInputs) const { - convertComputeInitOp(arithOrMathOp, cbOperands, numDPSInputs, reduceDim, - builder); - } - - // Builds nested loops which loop over tensor tiles after initalization is - // done and computation to perform on each tile over which loops iterate. 
- void buildLoopsAndComputation(Block *computeBlock, Operation &arithOrMathOp, - OpBuilder &builder, - ArrayRef reducedMemrefDims, - ttkernel::ReduceDim reduceDim, - ArrayRef &cbOperands, - std::int64_t numDPSInputs) const { - // Create loops which iterate over tiles in tensor. - LoopNest loopNest = - createLoopNest(cbOperands, reducedMemrefDims, numDPSInputs, builder); - assert(loopNest.loops.size() == 2 && "Expected only two loops!"); - - // The loop nest is created from outermost to innermost. Get the inner loop - // and place computation calls inside it. - Region *innerLoopRegion = loopNest.loopRegions.back(); - ArrayRef iterators = - loopNest.loops.back().getRegionIterArgs(); - SmallVector blockArgIteratorMapping = - loopNest.blockArgIteratorMapping; - - OpBuilder innerLoopBuilder(&innerLoopRegion->front(), - innerLoopRegion->front().begin()); - - // Call compute function to execute on each tile. Result will be stored in - // DST. - convertComputeOp(computeBlock, arithOrMathOp, loopNest, cbOperands, - iterators, reduceDim, blockArgIteratorMapping, - innerLoopBuilder, numDPSInputs); - } - - // Builds instructions to execute after loops are finished. - void buildEndSection(OpBuilder &enqueueProgramBlockBuilder, - Block *origGenericOpBlock) const { - // Place return op at the end of block, after loops. - enqueueProgramBlockBuilder.create( - origGenericOpBlock->getTerminator()->getLoc()); - } - - ttkernel::ReduceDim getReduceDim(ArrayRef reducedMemrefDims) const { - if (reducedMemrefDims.size() == 1) { - return reducedMemrefDims[0] ? ttkernel::ReduceDim::Scalar - : ttkernel::ReduceDim::None; - } - if (reducedMemrefDims.size() == 2) { - if (reducedMemrefDims[0] && reducedMemrefDims[1]) { - return ttkernel::ReduceDim::Scalar; - } - if (reducedMemrefDims[0]) { - return ttkernel::ReduceDim::Col; - } - if (reducedMemrefDims[1]) { - return ttkernel::ReduceDim::Row; - } - return ttkernel::ReduceDim::None; - } - llvm_unreachable("Unhandled reduction dims"); - } - - // Convert the original block into a lowered block that contains a fully - // expanded loop nest and inner loop that implements the underlying arith or - // math operation as a tile operation. - void lowerBlock(Block *origGenericOpBlock, Block *computeBlock, - ArrayAttr iteratorTypes, ArrayAttr indexingMaps, - std::int64_t numDPSInputs) const { - Block::OpListType &operations = origGenericOpBlock->getOperations(); - assert(operations.size() == 2); - Operation::user_range users = operations.front().getUsers(); - assert(users.begin() != users.end()); - assert(mlir::isa(*users.begin())); - assert(computeBlock->getNumArguments() > numDPSInputs); - - auto outputMemref = mlir::cast( - computeBlock->getArgument(numDPSInputs).getType()) - .getMemref() - .getLayout() - .getAffineMap(); - size_t j = iteratorTypes.size() - 1; - uint32_t outputRank = outputMemref.getNumDims(); - SmallVector reducedMemrefDims(outputRank, false); - - // Collect the reduction dims going from innermost to outermost. 
- assert(outputRank <= iteratorTypes.size()); - for (int i = outputRank - 1; i >= 0; --i, --j) { - auto itType = iteratorTypes[j]; - if (mlir::cast(itType).getValue() == - IteratorType::Reduction) { - reducedMemrefDims[i] = true; - } - } - - auto kernelReduceDim = getReduceDim(reducedMemrefDims); - - OpBuilder builder(computeBlock, computeBlock->begin()); - Operation &arithOrMathOp = operations.front(); - auto cbOperands = computeBlock->getArguments(); - - buildInitSection(arithOrMathOp, builder, cbOperands, kernelReduceDim, - numDPSInputs); - buildLoopsAndComputation(computeBlock, arithOrMathOp, builder, - reducedMemrefDims, kernelReduceDim, cbOperands, - numDPSInputs); - buildEndSection(builder, origGenericOpBlock); - } - - struct StreamedOperand { - uint64_t srcAddress; - uint64_t dstAddress; - size_t blockArgIndex; - bool hasDataMovement; - llvm::MapVector> - dataMovement; - uint64_t numTiles; - PhysicalCoreCoordMapping coordMappping; - - StreamedOperand( - uint64_t srcAddress, uint64_t dstAddress, size_t blockArgIndex, - bool hasDataMovement, - llvm::MapVector> - dataMovement, - uint64_t numTiles, const PhysicalCoreCoordMapping &coordMappping) - : srcAddress(srcAddress), dstAddress(dstAddress), - blockArgIndex(blockArgIndex), hasDataMovement(hasDataMovement), - dataMovement(dataMovement), numTiles(numTiles), - coordMappping(coordMappping) {} - }; - - std::pair, SmallVector> - getBlockArgumentTypesAsCBs(ttir::GenericOp op, - mlir::Block::BlockArgListType blockArguments, - PatternRewriter &rewriter) const { - - SmallVector rewrittenBlockArgumentTypes; - SmallVector streamedOperands; - - for (auto arg : blockArguments) { - auto port = getPort(arg.getArgNumber(), op.getInputs().size()); - auto tensor = mlir::cast(arg.getType()); - auto layout = mlir::cast(tensor.getEncoding()); - auto memref = layout.getMemref(); - - int32_t cbMapping = op.getOperandCbMapping()[arg.getArgNumber()]; - - // Operand that is directly mapped to block argument. - auto matchingOperand = op.getMatchingOperand(arg.getArgNumber()); - - // Operand that is either directly mapped to block argument or - // corresponding CB operand if it should be streamed. - auto correspondingOperand = - cbMapping == -1 ? matchingOperand : op.getCbs()[cbMapping]; - auto address = lookupAddress(correspondingOperand); - assert(address && "Expected valid address"); - - rewrittenBlockArgumentTypes.push_back( - rewriter.getType(port, address, memref)); - - uint64_t numTiles = memref.getShape()[memref.getRank() - 1] * - memref.getShape()[memref.getRank() - 2]; - - llvm::MapVector> - dataMovement; - if (layout.getStreamMode() == StreamMode::Stream) { - dataMovement = calculateDataMovement( - op.getIteratorTypes(), - mlir::cast(matchingOperand.getType()), - mlir::cast(correspondingOperand.getType()), - op.getDevice()); - } else { - dataMovement[PhysicalCoreCoord()] = - SmallVector(); - } - streamedOperands.push_back(StreamedOperand( - lookupAddress(matchingOperand), lookupAddress(correspondingOperand), - arg.getArgNumber(), layout.getStreamMode() == StreamMode::Stream, - dataMovement, numTiles, - // TODO(rpavlovic) fix the assumption that input is always in L1. 
- PhysicalCoreCoordMapping::getMemorySpaceMapping( - op.getDevice().getChipIds(), op.getSystemDesc().getChipDescs(), - MemorySpace::DeviceL1))); - } - - return {rewrittenBlockArgumentTypes, streamedOperands}; - } - - llvm::MapVector> - calculateDataMovement(ArrayAttr iteratorTypes, const RankedTensorType &src, - const RankedTensorType &dst, DeviceAttr device) const { - auto srcLayout = mlir::cast(src.getEncoding()); - assert(srcLayout.isTiled()); - - auto dstLayout = mlir::cast(dst.getEncoding()); - assert(dstLayout.isTiled()); - - assert(iteratorTypes.size() >= 2 && "Expected at least 2 iterator types"); - - auto lastDimIterType = - mlir::cast(iteratorTypes[iteratorTypes.size() - 1]); - auto penLastDimIterType = - mlir::cast(iteratorTypes[iteratorTypes.size() - 2]); - bool transposeLast2Dims = false; - if (penLastDimIterType.getValue() == IteratorType::Reduction && - lastDimIterType.getValue() == IteratorType::Parallel) { - transposeLast2Dims = true; - } - - auto srcMap = srcLayout.getIdentityTileLinearMap(); - if (transposeLast2Dims) { - auto mapRank = srcMap.getNumResults(); - SmallVector permutation; - for (size_t i = 0; i < mapRank; i++) { - permutation.push_back(i); - } - permutation[mapRank - 1] = mapRank - 2; - permutation[mapRank - 2] = mapRank - 1; - srcMap = srcMap.getPermutationMap(permutation, srcMap.getContext()); - } - - auto srcShape = srcMap.compose(srcLayout.getTiledShape(src.getShape())); - auto srcProjection = srcLayout.projectOnto( - srcMap, device.getMapForMemorySpace(srcLayout.getMemorySpace())); - - auto dstMap = dstLayout.getIdentityTileLinearMap(); - auto dstShape = dstLayout.getTiledShape(dst.getShape()); - auto dstProjection = dstLayout.projectOnto( - dstMap, device.getMapForMemorySpace(dstLayout.getMemorySpace())); - - // dstProjection is composed with srcMap to cover the case where srcMap is - // transposed. Then its shape is transposed too, therefore dstProjection - // must work with transposed shape. - auto dm = TTIRToTTMetalLayoutRewriter::calculateDataMovement( - srcShape, srcLayout.getElementSizeBytes(), srcProjection, - dstProjection.compose(srcMap), - TTIRToTTMetalLayoutRewriter::NocTx::Type::Read, - dstLayout.getMemrefSizeBytes()); - - return dm; - } - - static bool needScaler(ttir::GenericOp op) { - Block &block = op.getRegion(0).front(); - Operation &firstOp = block.getOperations().front(); - if (!mlir::isa(firstOp)) { - return false; - } - - auto kernelOp = mlir::cast(firstOp); - if (kernelOp.getOp() != "reduce") { - return false; - } - - return true; - } - - static Type addReduceScaler(ttir::GenericOp op, Block *dmBlock, - ArrayRef rewrittenBlockArgumentTypes, - PatternRewriter &rewriter) { - Block &block = op.getRegion(0).front(); - Operation &firstOp = block.getOperations().front(); - if (!mlir::isa(firstOp)) { - return nullptr; - } - - auto kernelOp = mlir::cast(firstOp); - if (kernelOp.getOp() != "reduce") { - return nullptr; - } - - // Take the port after the last input arg. - auto scalerCBPort = ttkernel::symbolizeCBPort( - ttmlir::utils::enum_as_int(ttkernel::CBPort::In0) + - op.getInputs().size()) - .value(); - auto inputCB = - mlir::cast(*rewrittenBlockArgumentTypes.begin()); - auto inputTT = mlir::cast(inputCB.getMemref().getElementType()); - assert(inputTT); - - // Single tile memref that will be used for the scaler. 
- MemRefType tileMemref = MemRefType::get( - {1, 1}, inputTT, - mlir::AffineMap::getMultiDimIdentityMap(2, op->getContext()), - MemorySpaceAttr::get(op->getContext(), MemorySpace::DeviceL1)); - - auto scalerCBType = ttkernel::CBType::get( - op.getContext(), scalerCBPort, - op.getSystemDesc().getChipDescs().front().getScratchL1RegionAddress(), - tileMemref); - auto scalerCB = dmBlock->addArgument(scalerCBType, op.getLoc()); - - auto reduceKind = kernelOp.getKind(); - float scalerValue = 0.; - if (reduceKind == "sum" || reduceKind == "max") { - scalerValue = 1.; - } - - if (reduceKind == "mean") { - int64_t numElements = 1; - auto inputType = - mlir::cast(op.getInputs()[0].getType()); - - for (int64_t dim = 0; dim < inputType.getRank(); ++dim) { - auto iteratorType = - mlir::cast(op.getIteratorTypes()[dim]); - if (iteratorType.getValue() == IteratorType::Reduction) { - numElements *= inputType.getShape()[dim]; - } - } - scalerValue = 1. / numElements; - } - - generateScaler(op.getLoc(), dmBlock, scalerCB, scalerValue); - - return scalerCBType; - } - - // Given float value fValue, generate a scaler tile with value fValue. Tile - // should be populated in pattern such that only first row of each face is - // populated with fValue, while the rest of the tile is zeroed out. - static void generateScaler(Location loc, Block *block, - BlockArgument &scalerCB, float fValue) { - OpBuilder builder(block, block->begin()); - - // Assumption is that scalerCB is tile of F16/BF16 values. Converting from - // float to 2byte float is truncating mantissa bits. To support lesser - // fortmats we need to add conversion from float to those formats, which is - // trickier than just truncating mantissa bits. - uint32_t iValue = *reinterpret_cast(&fValue); - iValue >>= 16; - uint16_t iValue16 = iValue & 0xFFFF; - auto scalerConst = builder.create( - loc, builder.getI32Type(), - builder.getI32IntegerAttr(iValue16 | (iValue16 << 16))); - - // Reserve single tile. - auto oneConst = builder.create( - loc, builder.getI32Type(), builder.getI32IntegerAttr(1)); - builder.create(loc, scalerCB, oneConst); - - // Prepare zero region read. - auto zerosBase = - builder.create(loc, builder.getI32Type()); - auto zerosNocAddr = - builder.create(loc, zerosBase->getResult(0)); - auto memZerosSize = - builder.create(loc, builder.getI32Type()); - builder.create(loc, zerosNocAddr, - memZerosSize); - - // Prepare pointer to scalerCB tile. - auto writeAddr = builder.create(loc, scalerCB); - auto ptr = builder.create(loc, writeAddr); - - // Zeros are read in few packets, so we need to calculate how many reads. - // Assumption is that scalerCB is 2048 bytes. - auto lowerBound = builder.create( - loc, builder.getI32Type(), builder.getI32IntegerAttr(0)); - auto cbSizeBytes = builder.create( - loc, builder.getI32Type(), builder.getI32IntegerAttr(2048)); - auto numZerosReads = builder.create( - loc, builder.getI32Type(), cbSizeBytes, memZerosSize); - auto step = builder.create(loc, builder.getI32Type(), - builder.getI32IntegerAttr(1)); - - // Generate loop to read zeros and fill tile. Move address by memZerosSize - // in each iteration. 
- auto forOp = builder.create(loc, lowerBound, numZerosReads, - step, ValueRange(writeAddr)); - builder.setInsertionPointToStart(forOp.getBody()); - builder.create( - loc, zerosNocAddr, forOp.getRegionIterArg(0)); - SmallVector newAddr; - newAddr.push_back(builder.create( - loc, forOp.getRegionIterArg(0), memZerosSize, - mlir::arith::IntegerOverflowFlagsAttr::get( - builder.getContext(), mlir::arith::IntegerOverflowFlags::nsw))); - builder.create(loc, newAddr); - - builder.setInsertionPointAfter(forOp); - builder.create(loc); - - // Fill the tile in 2 nested loops. Outer loop is for 4 faces. - auto numFaces = builder.create( - loc, builder.getI32Type(), builder.getI32IntegerAttr(4)); - auto fillingOuterLoop = builder.create( - loc, lowerBound, numFaces, step, ValueRange{}); - - auto faceIndex = fillingOuterLoop.getInductionVar(); - builder.setInsertionPointToStart(fillingOuterLoop.getBody()); - - // In each face, we want to fill first row of 16 datums, each datum being - // sized 2B. So we need to fill 32B in each face. Since we packed 4B in - // scalerConst, this gives us 8 stores in each face. - // After each face, we need to move to next face, which is 512B away (or - // 128x4B). - auto bytesStride = builder.create( - loc, builder.getI32Type(), builder.getI32IntegerAttr(128)); - auto offset = builder.create( - loc, faceIndex, bytesStride, - mlir::arith::IntegerOverflowFlagsAttr::get( - builder.getContext(), mlir::arith::IntegerOverflowFlags::nsw)); - - auto numStores = builder.create( - loc, builder.getI32Type(), builder.getI32IntegerAttr(8)); - auto fillingInnerLoop = builder.create( - loc, lowerBound, numStores, step, ValueRange{}); - - builder.setInsertionPointToStart(fillingInnerLoop.getBody()); - auto fillingIdx = builder.create( - loc, offset, fillingInnerLoop.getInductionVar(), - mlir::arith::IntegerOverflowFlagsAttr::get( - builder.getContext(), mlir::arith::IntegerOverflowFlags::nsw)); - builder.create(loc, scalerConst, ptr, fillingIdx); - - // Notify that we filled the tile. - builder.setInsertionPointAfter(fillingOuterLoop); - builder.create(loc, scalerCB, oneConst); - } - - void generateDataMovementThreads( - ttir::GenericOp op, Block *tensixBlock, - ArrayRef streamedOperands, PatternRewriter &rewriter, - ttmetal::EnqueueProgramOp &metalEnqueueProgram, - ArrayRef rewrittenBlockArgumentTypes) const { - - // TODO(1159) move TTIRToTTMetalLayoutRewriter::createDataMovementThreads - // (& other data movement logic) into common place - int dmThreadIdx = 1; - assert(!streamedOperands.empty()); - - llvm::DenseMap coordToBlock; - for (auto operand : streamedOperands) { - if (!operand.hasDataMovement) { - continue; - } - for (auto [dstCoord, srcs] : operand.dataMovement) { - Block *block = coordToBlock.find(dstCoord) == coordToBlock.end() - ? rewriter.createBlock( - &metalEnqueueProgram.getRegion(dmThreadIdx++)) - : coordToBlock[dstCoord]; - coordToBlock[dstCoord] = block; - - block->addArgument(rewrittenBlockArgumentTypes[operand.blockArgIndex], - op.getLoc()); - - auto streamedCB = block->getArgument(0); - - TTIRToTTMetalLayoutRewriter::createDataMovementThread( - op->getLoc(), block, operand.srcAddress, operand.dstAddress, srcs, - operand.coordMappping, op.getSystemDesc().getAddressAlignBytes(), - &streamedCB); - } - } - - if (needScaler(op)) { - if (coordToBlock.empty()) { - // No data movement, so we need to add a block for the scaler. 
- coordToBlock[PhysicalCoreCoord()] = - rewriter.createBlock(&metalEnqueueProgram.getRegion(dmThreadIdx++)); - } - - Type scalerCBType; - for (auto [coord, block] : coordToBlock) { - scalerCBType = - addReduceScaler(op, block, rewrittenBlockArgumentTypes, rewriter); - } - - // Propagate the scalerCBType to the compute thread. - tensixBlock->addArgument(scalerCBType, op.getLoc()); - } - - // Finish all blocks with return op. - for (auto [coord, block] : coordToBlock) { - OpBuilder builder = OpBuilder::atBlockEnd(block); - builder.create(op.getLoc()); - } - } - - static void - addSyncronizationForDataMovement(ttir::GenericOp op, Block *tensixBlock, - ArrayRef streamedOperands) { - for (auto operand : streamedOperands) { - if (operand.hasDataMovement) { - // There is some data movement. Let's just insert waiting command at the - // start of compute block. We assume whole block is streamed. - OpBuilder builder(tensixBlock, tensixBlock->begin()); - auto numPages = builder.create( - op.getLoc(), builder.getI32Type(), - builder.getI32IntegerAttr(operand.numTiles)); - - auto streamedCB = tensixBlock->getArgument(operand.blockArgIndex); - builder.create(op.getLoc(), streamedCB, - numPages); - - builder.setInsertionPoint(tensixBlock->getTerminator()); - builder.create(op.getLoc(), streamedCB, - numPages); - } - } - } - - LogicalResult matchAndRewrite(ttir::GenericOp op, - PatternRewriter &rewriter) const final { - // Temporary fix that allows ttir::KernelOp to be lowered directly into - // ttkernel dialect. - // if (hasUnloweredTTIRKernel(op)) { - // return failure(); - // } - - // Ensure tt-mlir/tt-metal agree on number, and set UnpackToDestMode per CB - uint32_t chipNumCBs = op.getSystemDesc().getChipDescs()[0].getNumCBs(); - constexpr uint32_t kNumCBs = 1 + ttkernel::getMaxEnumValForCBPort(); - assert(chipNumCBs == kNumCBs && "Mismatch between tt-mlir and tt-metal " - "number of CBs"); - - llvm::SmallVector unpackToDestModes( - kNumCBs, ttkernel::UnpackToDestMode::Default); - - auto tensixAttr = rewriter.getAttr( - ttkernel::MathFidelity::HiFi4, false, false, unpackToDestModes); - SmallVector kernelConfigs = {tensixAttr}; - SmallVector coreRanges = { - rewriter.getAttr(op.getGrid()), - }; - - SmallVector rewrittenBlockArgumentTypes; - SmallVector streamedOperands; - std::tie(rewrittenBlockArgumentTypes, streamedOperands) = - getBlockArgumentTypesAsCBs(op, op->getRegion(0).getArguments(), - rewriter); - - assert(!streamedOperands.empty()); - - // Minimal mapping to cover whole generic op grid w.r.t. data movement - // requirements of all operands. Today, each destination core gets its own - // unique kernel for data movement, while in the future we'll want to merge - // them into less kernels if possible and configure them through kernel - // configs. Operands that have no data movement simply cover whole op grid. - llvm::DenseMap> allDstCoords; - for (auto operand : streamedOperands) { - for (auto [dstCoord, srcs] : operand.dataMovement) { - if (allDstCoords.find(dstCoord) != allDstCoords.end() && - !operand.hasDataMovement) { - // Some other operand already added this dstCoord, we don't want to - // overwrite it by this operand which has no data movement. - continue; - } - allDstCoords.try_emplace(dstCoord, operand.hasDataMovement - ? 
SmallVector({1, 1}) - : op.getGrid().getShape()); - } - } - - assert(!allDstCoords.empty() && "There should be at least one dstCoord"); - - for (auto [dstCoord, gridShape] : allDstCoords) { - kernelConfigs.push_back(ttkernel::NocConfigAttr::get( - rewriter.getContext(), ttkernel::NocIndex::Noc0)); - // If no data movement transactions are needed, let's cover whole op - // grid. - coreRanges.push_back(ttmetal::CoreRangeAttr::get( - getContext(), {dstCoord.y, dstCoord.x}, gridShape)); - } - - // Wire generic's operands to enqueue program op's operands with respect to - // the CB mapping. - SmallVector inputsToEnqueueProgramOp; - for (size_t i = 0; i < op.getInputs().size(); ++i) { - auto operand = op.getOperandCbMapping()[i] == -1 - ? op.getMatchingOperand(i) - : op.getCbs()[op.getOperandCbMapping()[i]]; - inputsToEnqueueProgramOp.push_back(operand); - } - - auto metalEnqueueProgram = rewriter.create( - op.getLoc(), op.getResults().getTypes(), inputsToEnqueueProgramOp, - op.getOutputs(), rewriter.getArrayAttr(coreRanges), - rewriter.getArrayAttr(kernelConfigs), kernelConfigs.size()); - - Block *tensixBlock = &metalEnqueueProgram.getRegion(0).emplaceBlock(); - - for (auto ty : rewrittenBlockArgumentTypes) { - tensixBlock->addArgument(ty, op.getLoc()); - } - - generateDataMovementThreads(op, tensixBlock, streamedOperands, rewriter, - metalEnqueueProgram, - rewrittenBlockArgumentTypes); - - lowerBlock(&op->getRegion(0).front(), tensixBlock, op.getIteratorTypes(), - op.getIndexingMaps(), op.getInputs().size()); - - addSyncronizationForDataMovement(op, tensixBlock, streamedOperands); - - // Regions for enqueue program op are allocated up-front, but some of them - // may be empty at the end of lowering due to no data movement requirements. - // Insert return op in those regions. - for (Region ® : metalEnqueueProgram->getRegions()) { - if (reg.empty()) { - auto &block = reg.emplaceBlock(); - OpBuilder builder = OpBuilder::atBlockEnd(&block); - builder.create(op.getLoc()); - } - } - - rewriter.replaceOp(op, metalEnqueueProgram); - - return success(); - } -}; -} // namespace - namespace { class TTIRToTTMetalAllocRewriter : public OpRewritePattern { public: @@ -1903,7 +594,6 @@ void populateTTIRToTTMetalPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter & /*typeConverter*/) { patterns.add(ctx); diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 1c268a9c1c..651ecb18b2 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -2771,11 +2771,6 @@ void mlir::tt::ttir::PermuteOp::getCanonicalizationPatterns( // GenericOp verification ::mlir::LogicalResult mlir::tt::ttir::GenericOp::verify() { - // Validate CB mappings. - if (getCbs().size()) { - return emitOpError("CB mappings are deprecated and should not be used"); - } - // Output grid shape must equal the GenericOp grid shape. 
auto opGridShape = getGrid().getShape(); for (auto output : getOutputs()) { @@ -2871,8 +2866,6 @@ mlir::LogicalResult mlir::tt::ttir::GenericOp::bufferize( assert(getNumResults() == 1 && "GenericOp should have exactly one result"); assert(getOutputs().size() == 1 && "GenericOp should have exactly one output"); - assert(getCbs().size() == 0 && - "GenericOp should not have any cb, these are deprecated"); if (!mlir::isa(getResult(0).getType())) { return failure(); @@ -2896,9 +2889,8 @@ mlir::LogicalResult mlir::tt::ttir::GenericOp::bufferize( bufferOutputs.push_back(*maybeValue); } auto bufferGeneric = rewriter.create( - getLoc(), ValueRange(), bufferInputs, ValueRange(), bufferOutputs, - getGrid(), getIndexingMaps(), getIteratorTypes(), getOperandCbMapping(), - getNumRegions()); + getLoc(), ValueRange(), bufferInputs, bufferOutputs, getGrid(), + getIndexingMaps(), getIteratorTypes(), getNumRegions()); for (mlir::Region ®ion : bufferGeneric.getRegions()) { region.takeBody(getRegion(region.getRegionNumber())); } diff --git a/test/ttmlir/Dialect/TTIR/bufferization/bufferization.mlir b/test/ttmlir/Dialect/TTIR/bufferization/bufferization.mlir index c1bdf3aca9..3f8180a02a 100644 --- a/test/ttmlir/Dialect/TTIR/bufferization/bufferization.mlir +++ b/test/ttmlir/Dialect/TTIR/bufferization/bufferization.mlir @@ -11,7 +11,7 @@ func.func @matmul(%arg0: tensor<1x1x2x4x!tt.tile<32x32, f32>>, %arg1: tensor<1x1 // CHECK: = memref.alloc() {{.*}} : memref<1x1x2x2x!tt.tile<32x32, f32>, #l1_> %0 = tensor.empty() : tensor<1x1x2x2x!tt.tile<32x32, f32>> // CHECK: {{^ "ttir.generic".*}} - %3 = "ttir.generic"(%arg0, %arg1, %0) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map1, #map2], iterator_types = [#parallel, #parallel, #reduction], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + %3 = "ttir.generic"(%arg0, %arg1, %0) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map1, #map2], iterator_types = [#parallel, #parallel, #reduction], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<4x2x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x2x!tt.tile<32x32, f32>, #l1_>): "ttir.tile_matmul_block"(%arg2, %arg3, %arg4) : (memref<2x4x!tt.tile<32x32, f32>, #l1_>, memref<4x2x!tt.tile<32x32, f32>, #l1_>, memref<2x2x!tt.tile<32x32, f32>, #l1_>) -> () }) : (tensor<1x1x2x4x!tt.tile<32x32, f32>>, tensor<1x1x4x2x!tt.tile<32x32, f32>>, tensor<1x1x2x2x!tt.tile<32x32, f32>>) -> tensor<1x1x2x2x!tt.tile<32x32, f32>> diff --git a/test/ttmlir/Dialect/TTIR/bufferization/memory_effects.mlir b/test/ttmlir/Dialect/TTIR/bufferization/memory_effects.mlir index dd7f945919..2f5a867102 100644 --- a/test/ttmlir/Dialect/TTIR/bufferization/memory_effects.mlir +++ b/test/ttmlir/Dialect/TTIR/bufferization/memory_effects.mlir @@ -11,7 +11,7 @@ func.func @matmul_pure_tensors(%arg0: tensor<2x4x!tt.tile<32x32, f32>>, %arg1: t %0 = tensor.empty() : tensor<2x2x!tt.tile<32x32, f32>> // No uses of %3, so it should be removed. 
// CHECK-NOT: "ttir.generic" - %3 = "ttir.generic"(%arg0, %arg1, %0) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map1, #map2], iterator_types = [#parallel, #parallel, #reduction], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + %3 = "ttir.generic"(%arg0, %arg1, %0) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map1, #map2], iterator_types = [#parallel, #parallel, #reduction], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<4x2x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x2x!tt.tile<32x32, f32>, #l1_>): "ttir.tile_matmul_block"(%arg2, %arg3, %arg4) : (memref<2x4x!tt.tile<32x32, f32>, #l1_>, memref<4x2x!tt.tile<32x32, f32>, #l1_>, memref<2x2x!tt.tile<32x32, f32>, #l1_>) -> () }) : (tensor<2x4x!tt.tile<32x32, f32>>, tensor<4x2x!tt.tile<32x32, f32>>, tensor<2x2x!tt.tile<32x32, f32>>) -> tensor<2x2x!tt.tile<32x32, f32>> @@ -30,7 +30,7 @@ func.func @matmul_memref(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>, %arg %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x2x!tt.tile<32x32, f32>, #l1_> // Ensure that the generic op is not removed. // CHECK: "ttir.generic" - "ttir.generic"(%arg0, %arg1, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map1, #map2], iterator_types = [#parallel, #parallel, #reduction], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %arg1, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map1, #map2], iterator_types = [#parallel, #parallel, #reduction], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<4x2x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x2x!tt.tile<32x32, f32>, #l1_>): "ttir.tile_matmul_block"(%arg2, %arg3, %arg4) : (memref<2x4x!tt.tile<32x32, f32>, #l1_>, memref<4x2x!tt.tile<32x32, f32>, #l1_>, memref<2x2x!tt.tile<32x32, f32>, #l1_>) -> () }) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>, memref<1x1x4x2x!tt.tile<32x32, f32>, #l1_>, memref<1x1x2x2x!tt.tile<32x32, f32>, #l1_>) -> () diff --git a/test/ttmlir/Dialect/TTIR/generic/generic_negative.mlir b/test/ttmlir/Dialect/TTIR/generic/generic_negative.mlir index 66f44199aa..7c5a398d5d 100644 --- a/test/ttmlir/Dialect/TTIR/generic/generic_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/generic/generic_negative.mlir @@ -9,7 +9,7 @@ func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> - "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<2x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<2x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>): }) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> () return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> @@ -25,7 +25,7 @@ func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1 func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> - "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = 
[#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>): }) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> () return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> @@ -41,7 +41,7 @@ func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1 func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> - "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg4: tensor<2x4x!tt.tile<32x32, f32>, #l1_>): }) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> () return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> @@ -58,7 +58,7 @@ func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1 func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #dram_> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #dram_> - "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>): }) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #dram_>) -> () return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #dram_> @@ -74,7 +74,7 @@ func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1 func.func @matmul(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> - "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x2x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>): }) : (memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>, memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_>) -> () return %alloc : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> diff --git a/test/ttmlir/Dialect/TTIR/generic/generic_region_ops.mlir b/test/ttmlir/Dialect/TTIR/generic/generic_region_ops.mlir index 05e7a6f72f..d9ba0d3a27 100644 --- a/test/ttmlir/Dialect/TTIR/generic/generic_region_ops.mlir +++ b/test/ttmlir/Dialect/TTIR/generic/generic_region_ops.mlir @@ -11,7 +11,7 @@ func.func @reduce_max(%arg0: 
tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> t grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], - operandSegmentSizes = array + operandSegmentSizes = array }> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_alias>, %arg3: memref<2x4x!tt.tile<32x32, f32>, #l1_alias>, diff --git a/test/ttmlir/Dialect/TTIR/generic/generic_region_ops_negative.mlir b/test/ttmlir/Dialect/TTIR/generic/generic_region_ops_negative.mlir index 6faa22b94d..9f3f79bd65 100644 --- a/test/ttmlir/Dialect/TTIR/generic/generic_region_ops_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/generic/generic_region_ops_negative.mlir @@ -10,21 +10,22 @@ func.func @reduce_dim_arg(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], - operandSegmentSizes = array + operandSegmentSizes = array }> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_alias>, %arg3: memref<2x4x!tt.tile<32x32, f32>, #l1_alias>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_alias>): linalg.generic { indexing_maps = [#map, #map], - iterator_types = ["parallel", "parallel"]} + iterator_types = ["parallel", "parallel"] + } ins(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_alias>) outs(%arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_alias>) { ^bb0(%a: !tt.tile<32x32, f32>, %b: !tt.tile<32x32, f32>): // CHECK: error: 'ttir.tile_reduce_max' op requires attribute 'reduce_dim' %4 = "ttir.tile_reduce_max" (%a, %b) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32> linalg.yield %4: !tt.tile<32x32, f32> - } + } "ttir.yield"() : () -> () }) : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> diff --git a/test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir b/test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir index 306534ecef..a15ca2f08c 100644 --- a/test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir +++ b/test/ttmlir/Dialect/TTIR/loops/linearize_memref.mlir @@ -6,7 +6,7 @@ func.func @add(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, %arg1: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> - "ttir.generic"(%arg0, %arg1, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %arg1, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>): // CHECK: = memref.collapse_shape %arg3 // CHECK: = memref.collapse_shape %arg4 @@ -27,7 +27,7 @@ func.func @add(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d func.func @addT(%arg0: memref<1x1x2x4x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>, %arg1T: memref<1x1x4x2x!tt.tile<32x32, f32>, #tt.stream<(d0, d1, d2, d3) -> (d0, d1, d2, d3), alias>, #l1_>) -> memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x2x4x!tt.tile<32x32, f32>, #l1_> - "ttir.generic"(%arg0, %arg1T, %alloc) <{grid = 
#tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array, operand_cb_mapping = array}> ({ + "ttir.generic"(%arg0, %arg1T, %alloc) <{grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<2x4x!tt.tile<32x32, f32>, #l1_>, %arg3: memref<4x2x!tt.tile<32x32, f32>, #l1_>, %arg4: memref<2x4x!tt.tile<32x32, f32>, #l1_>): // CHECK: = memref.collapse_shape %arg3 // CHECK: = memref.collapse_shape %arg4 diff --git a/test/ttmlir/Dialect/TTIR/ttir_generic_hoist.mlir b/test/ttmlir/Dialect/TTIR/ttir_generic_hoist.mlir index 2e5ff1ab8b..b49f9f954b 100644 --- a/test/ttmlir/Dialect/TTIR/ttir_generic_hoist.mlir +++ b/test/ttmlir/Dialect/TTIR/ttir_generic_hoist.mlir @@ -7,8 +7,7 @@ func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tens grid = #tt.grid<1x1>, indexing_maps = [#map, #map, #map], iterator_types = [#parallel, #parallel], - operandSegmentSizes = array, - operand_cb_mapping = array}> ({ + operandSegmentSizes = array}> ({ ^bb0(%arg2: memref<64x128xf32>, %arg3: memref<64x128xf32>, %arg4: memref<64x128xf32>): // lit CHECK to make sure this constant stays inside the generic region // CHECK: ttir.generic From 46612ac10a069e01137ac0325fbf849c2546008c Mon Sep 17 00:00:00 2001 From: Vlad Roubtsov Date: Wed, 5 Mar 2025 19:33:28 +0000 Subject: [PATCH 3/6] move getDpsOutputs() helper into tt::ttir and a dialect-specific Utils.h --- include/ttmlir/Dialect/TTIR/IR/TTIROps.h | 23 ++++----- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 12 ++--- include/ttmlir/Dialect/TTIR/IR/Utils.h | 59 +++++++++++++++++++++ include/ttmlir/Utils.h | 62 ++++------------------- 4 files changed, 86 insertions(+), 70 deletions(-) create mode 100644 include/ttmlir/Dialect/TTIR/IR/Utils.h diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h index 62239a5000..6cba434256 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h @@ -8,18 +8,17 @@ #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.h" #include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h" - -#include "ttmlir/Utils.h" - -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "ttmlir/Dialect/TTIR/IR/Utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace mlir::tt::ttir { diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 9cffa8bfa1..0cb5c9e199 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -24,7 +24,7 @@ class TTIR_DPSOp traits = []> : TTIR_Op { let extraClassDeclaration = [{ // base implementation that will detect getOutputMutable() or getOutputsMutable() - MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); } + MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); } }]; } @@ -112,7 +112,7 @@ def TTIR_GenericOp : TTIR_BufferizableOp<"generic", [AttrSizedOperandSegments, N 
let hasVerifier = 1;

 let extraClassDeclaration = [{
- MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
+ MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
 }];
}

@@ -138,7 +138,7 @@ def TTIR_ToLayoutOp : TTIR_BufferizableOp<"to_layout"> {
 let results = (outs Variadic:$results);

 let extraClassDeclaration = [{
- MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
+ MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }

 struct CompoundComponents {
 bool isLayoutChange = false;
@@ -781,7 +781,7 @@ class TTIR_ReductionOp traits = []> :
 let results = (outs AnyRankedTensor:$result);

 let extraClassDeclaration = [{
- MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
+ MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
 void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

 // Returns the indexing maps and iterator types for the reduction op.
@@ -1826,7 +1826,7 @@ class TTIR_GenericElementwiseUnaryOp traits = []> : TTIR_ElementwiseUnaryOp {

 let extraClassDeclaration = [{
- MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
+ MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
 void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

 std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
@@ -1856,7 +1856,7 @@ class TTIR_GenericElementwiseBinaryOp traits = []> TTIR_ElementwiseBinaryOp {

 let extraClassDeclaration = [{
- MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); }
+ MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); }
 void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

 std::pair<::mlir::ArrayAttr, ::mlir::ArrayAttr> getIndexingMaps(Builder &builder) {
diff --git a/include/ttmlir/Dialect/TTIR/IR/Utils.h b/include/ttmlir/Dialect/TTIR/IR/Utils.h
new file mode 100644
index 0000000000..8e9dc2c711
--- /dev/null
+++ b/include/ttmlir/Dialect/TTIR/IR/Utils.h
@@ -0,0 +1,59 @@
+// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTMLIR_DIALECT_TTIR_IR_UTILS_H
+#define TTMLIR_DIALECT_TTIR_IR_UTILS_H
+
+#include
+
+#include
+
+namespace mlir::tt::ttir {
+
+// detect the presence of 'getOutputsMutable()' in 'Op':
+template
+inline constexpr bool has_variadic_outputs = false;
+
+template
+inline constexpr bool has_variadic_outputs<
+ Op, std::void_t().getOutputsMutable())>> = true;
+
+namespace impl {
+
+template
+struct getDpsOutputs {
+ static mlir::MutableOperandRange evaluate(Op *op) {
+ return op->getOutputMutable();
+ }
+};
+
+template
+struct getDpsOutputs>> {
+ static mlir::MutableOperandRange evaluate(Op *op) {
+ return op->getOutputsMutable();
+ }
+};
+
+} // namespace impl
+
+// A helper for simplifying DPS tablegen derivations with 'arguments' of any
+// form in {AnyRankedTensor:$output, Variadic:$outputs}.
+//
+// If a base tablegen 'class' adds this extra class declaration, derived 'def's
+// don't need to override it just to switch from single to variadic type of
+// '$outputs' (or vice versa):
+// ...
+// clang-format off +// let extraClassDeclaration = [{ +// MutableOperandRange getDpsInitsMutable() { return ttir::getDpsOutputs(this); } +// }] +// clang-format on +template +mlir::MutableOperandRange getDpsOutputs(Op *op) { + return impl::getDpsOutputs::evaluate(op); +} + +} // namespace mlir::tt::ttir + +#endif // TTMLIR_DIALECT_TTIR_IR_UTILS_H \ No newline at end of file diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h index cd28d94e25..bef42610bd 100644 --- a/include/ttmlir/Utils.h +++ b/include/ttmlir/Utils.h @@ -5,19 +5,20 @@ #ifndef TTMLIR_UTILS_H #define TTMLIR_UTILS_H -#include "mlir-c/IR.h" -#include "mlir/CAPI/IR.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/Error.h" +#include +#include +#include +#include +#include +#include +#include +#include #include #include namespace ttmlir::utils { + template T alignUp(T ptr, T alignment) { return (ptr + alignment - 1) & ~(alignment - 1); @@ -459,49 +460,6 @@ OpTy replaceOpWithNewDPSOp(mlir::PatternRewriter &rewriter, mlir::Operation *op, return newOp; } -// detect the presence of 'getOutputsMutable()' in 'Op': -template -inline constexpr bool has_variadic_outputs = false; - -template -inline constexpr bool has_variadic_outputs< - Op, std::void_t().getOutputsMutable())>> = true; - -namespace impl { - -template -struct getDpsOutputs { - static mlir::MutableOperandRange evaluate(Op *op) { - return op->getOutputMutable(); - } -}; - -template -struct getDpsOutputs>> { - static mlir::MutableOperandRange evaluate(Op *op) { - return op->getOutputsMutable(); - } -}; - -} // namespace impl - -// A helper for simplifying DPS tablegen derivations with 'arguments' of any -// form in {AnyRankedTensor:$output, Variadic:$outputs}. -// -// If a base tablegen 'class' adds this extra class declaration, derived 'def's -// don't need to overrride it just to switch from single to variadic type of -// '$outputs' (or vice versa): -// ... 
-// clang-format off -// let extraClassDeclaration = [{ -// MutableOperandRange getDpsInitsMutable() { return ::ttmlir::utils::getDpsOutputs(this); } -// }] -// clang-format on -template -mlir::MutableOperandRange getDpsOutputs(Op *op) { - return impl::getDpsOutputs::evaluate(op); -} - } // namespace ttmlir::utils -#endif +#endif // TTMLIR_UTILS_H From a169fc7da26bd509b9a4ff618d894cb8a02237f2 Mon Sep 17 00:00:00 2001 From: Vlad Roubtsov Date: Wed, 5 Mar 2025 19:48:28 +0000 Subject: [PATCH 4/6] add newline at EOF --- include/ttmlir/Dialect/TTIR/IR/Utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ttmlir/Dialect/TTIR/IR/Utils.h b/include/ttmlir/Dialect/TTIR/IR/Utils.h index 8e9dc2c711..63643e93d4 100644 --- a/include/ttmlir/Dialect/TTIR/IR/Utils.h +++ b/include/ttmlir/Dialect/TTIR/IR/Utils.h @@ -56,4 +56,4 @@ mlir::MutableOperandRange getDpsOutputs(Op *op) { } // namespace mlir::tt::ttir -#endif // TTMLIR_DIALECT_TTIR_IR_UTILS_H \ No newline at end of file +#endif // TTMLIR_DIALECT_TTIR_IR_UTILS_H From f14ba54b944484be09d4b8b2a1ab9e26ef77f0e6 Mon Sep 17 00:00:00 2001 From: Vlad Roubtsov Date: Wed, 5 Mar 2025 23:53:47 +0000 Subject: [PATCH 5/6] edits to address PR feedback --- include/ttmlir/Dialect/TTIR/IR/TTIROps.h | 18 +++++++++--------- include/ttmlir/Dialect/TTIR/IR/Utils.h | 8 ++++---- include/ttmlir/Utils.h | 16 ++++++++-------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h index 6cba434256..694cde785c 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h @@ -10,15 +10,15 @@ #include "ttmlir/Dialect/TTIR/IR/TTIRTraits.h" #include "ttmlir/Dialect/TTIR/IR/Utils.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir::tt::ttir { diff --git a/include/ttmlir/Dialect/TTIR/IR/Utils.h b/include/ttmlir/Dialect/TTIR/IR/Utils.h index 63643e93d4..cb4e1146f8 100644 --- a/include/ttmlir/Dialect/TTIR/IR/Utils.h +++ b/include/ttmlir/Dialect/TTIR/IR/Utils.h @@ -5,7 +5,7 @@ #ifndef TTMLIR_DIALECT_TTIR_IR_UTILS_H #define TTMLIR_DIALECT_TTIR_IR_UTILS_H -#include +#include "mlir/IR/ValueRange.h" #include @@ -13,10 +13,10 @@ namespace mlir::tt::ttir { // detect the presence of 'getOutputsMutable()' in 'Op': template -inline constexpr bool has_variadic_outputs = false; +inline constexpr bool has_variadic_outputs_v = false; template -inline constexpr bool has_variadic_outputs< +inline constexpr bool has_variadic_outputs_v< Op, std::void_t().getOutputsMutable())>> = true; namespace impl { @@ -29,7 +29,7 @@ struct getDpsOutputs { }; template -struct getDpsOutputs>> { +struct getDpsOutputs>> { static mlir::MutableOperandRange evaluate(Op *op) { return op->getOutputsMutable(); } diff --git a/include/ttmlir/Utils.h b/include/ttmlir/Utils.h index bef42610bd..c99edeb284 100644 --- a/include/ttmlir/Utils.h +++ b/include/ttmlir/Utils.h @@ -5,14 +5,14 @@ #ifndef TTMLIR_UTILS_H #define TTMLIR_UTILS_H -#include -#include -#include -#include -#include -#include 
-#include -#include +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Error.h" #include #include From 513445eaa858567753fca222451d823c74154e51 Mon Sep 17 00:00:00 2001 From: Vlad Roubtsov Date: Thu, 6 Mar 2025 00:09:46 +0000 Subject: [PATCH 6/6] final simplification --- include/ttmlir/Dialect/TTIR/IR/Utils.h | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/include/ttmlir/Dialect/TTIR/IR/Utils.h b/include/ttmlir/Dialect/TTIR/IR/Utils.h index cb4e1146f8..ba7780408d 100644 --- a/include/ttmlir/Dialect/TTIR/IR/Utils.h +++ b/include/ttmlir/Dialect/TTIR/IR/Utils.h @@ -19,24 +19,6 @@ template inline constexpr bool has_variadic_outputs_v< Op, std::void_t().getOutputsMutable())>> = true; -namespace impl { - -template -struct getDpsOutputs { - static mlir::MutableOperandRange evaluate(Op *op) { - return op->getOutputMutable(); - } -}; - -template -struct getDpsOutputs>> { - static mlir::MutableOperandRange evaluate(Op *op) { - return op->getOutputsMutable(); - } -}; - -} // namespace impl - // A helper for simplifying DPS tablegen derivations with 'arguments' of any // form in {AnyRankedTensor:$output, Variadic:$outputs}. // @@ -51,7 +33,11 @@ struct getDpsOutputs>> { // clang-format on template mlir::MutableOperandRange getDpsOutputs(Op *op) { - return impl::getDpsOutputs::evaluate(op); + if constexpr (has_variadic_outputs_v) { + return op->getOutputsMutable(); + } else { + return op->getOutputMutable(); + } } } // namespace mlir::tt::ttir
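
The final form of ttir::getDpsOutputs() above pairs a std::void_t member-detection idiom with if constexpr dispatch: an op class generated from a tablegen 'def' with a single $output argument exposes getOutputMutable(), one with variadic $outputs exposes getOutputsMutable(), and the helper selects the right accessor at compile time. The following is a minimal, self-contained C++17 sketch of the same pattern outside of MLIR; PlainOp, VariadicOp, and FakeRange are hypothetical stand-ins for the generated op classes and for MutableOperandRange, not real tt-mlir or MLIR APIs:

#include <iostream>
#include <type_traits>
#include <utility>

// Stand-in for mlir::MutableOperandRange (illustrative only).
struct FakeRange {
  const char *label;
};

// Mock op with a single-output accessor, like a 'def' declaring $output.
struct PlainOp {
  FakeRange getOutputMutable() { return {"single $output"}; }
};

// Mock op with a variadic-output accessor, like a 'def' declaring $outputs.
struct VariadicOp {
  FakeRange getOutputsMutable() { return {"variadic $outputs"}; }
};

// Primary template: assume no 'getOutputsMutable()' member.
template <typename Op, typename = void>
inline constexpr bool has_variadic_outputs_v = false;

// Partial specialization: selected only when Op has 'getOutputsMutable()'.
template <typename Op>
inline constexpr bool has_variadic_outputs_v<
    Op, std::void_t<decltype(std::declval<Op>().getOutputsMutable())>> = true;

// Compile-time dispatch, shaped like the final ttir::getDpsOutputs().
template <typename Op>
FakeRange getDpsOutputs(Op *op) {
  if constexpr (has_variadic_outputs_v<Op>) {
    return op->getOutputsMutable();
  } else {
    return op->getOutputMutable();
  }
}

int main() {
  PlainOp p;
  VariadicOp v;
  std::cout << getDpsOutputs(&p).label << '\n'; // prints: single $output
  std::cout << getDpsOutputs(&v).label << '\n'; // prints: variadic $outputs
  return 0;
}

Built as a plain C++17 program this prints the two labels above, mirroring how derived tablegen 'def's can switch between single and variadic $outputs without overriding getDpsInitsMutable().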