Replace div op with combination of recip and mul
mrakitaTT committed Sep 28, 2024
1 parent ddee2bd commit ec7d32e
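
Why this change: ttnn.div does not yet broadcast its inputs implicitly (see tt-metal issue 12798), so a ttir.div whose operand shapes differ is lowered to a reciprocal followed by a multiply, using the identity a / b == a * (1 / b) and the fact that ttnn.multiply already handles the broadcast. A minimal sketch of the rewrite in TTIR/TTNN terms (value names are illustrative and the layout/attribute details are trimmed; the exact IR is checked by the new test below):

    // Before: shapes differ (64x128 vs 1x128), so ttnn.div cannot be used.
    %q = "ttir.div"(%a, %b, %q_out) : (tensor<64x128xf32>, tensor<1x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>

    // After: q = a * (1 / b); ttnn.multiply broadcasts the 1x128 reciprocal.
    %r = "ttnn.reciprocal"(%b, %r_out) : (tensor<1x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32>
    %q = "ttnn.multiply"(%a, %r, %q_out) : (tensor<64x128xf32>, tensor<1x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>

Divides whose operands already have matching shapes are still lowered to ttnn.div directly.
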
Showing 2 changed files with 70 additions and 20 deletions.
77 changes: 57 additions & 20 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -490,24 +490,6 @@ class ConstantOpConversionPattern
   }
 };
 
-} // namespace
-
-// ANCHOR: adding_an_op_matmul_op_rewriter
-class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
-public:
-  using OpConversionPattern<ttir::MatmulOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ttir::MatmulOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<ttnn::MatmulOp>(
-        op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-        adaptor.getB(), adaptor.getOutput());
-    return success();
-  }
-};
-// ANCHOR_END: adding_an_op_matmul_op_rewriter
-
 class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
 public:
   using OpConversionPattern<ttir::Conv2dOp>::OpConversionPattern;
@@ -647,6 +629,61 @@ class BroadcastOpConversionPattern
   }
 };
 
+// TODO(issue #841): ttnn.div doesn't currently support implicit broadcast of
+// inputs (tt-metal/issues/12798) so we are using custom DivOpConversionPattern
+// to replace DivOp with combination of ReciprocalOp + MultiplyOp. Revert this
+// once the issue on Metal side is fixed.
+class DivOpConversionPattern : public OpConversionPattern<ttir::DivOp> {
+  using OpConversionPattern<ttir::DivOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(ttir::DivOp srcOp, ttir::DivOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType lhsType =
+        mlir::cast<RankedTensorType>(adaptor.getInputs().front().getType());
+    RankedTensorType rhsType =
+        mlir::cast<RankedTensorType>(adaptor.getInputs().back().getType());
+
+    if (lhsType.getShape() == rhsType.getShape()) {
+      rewriter.replaceOpWithNewOp<ttnn::DivOp>(
+          srcOp, adaptor.getInputs().front(), adaptor.getInputs().back(),
+          adaptor.getOutputs().front());
+    } else {
+      Value device = getOrInsertDevice(rewriter, srcOp);
+      tensor::EmptyOp recipEmptyOp = rewriter.create<tensor::EmptyOp>(
+          srcOp.getLoc(), this->getTypeConverter()->convertType(rhsType),
+          device);
+      ttnn::ReciprocalOp recipOp = rewriter.create<ttnn::ReciprocalOp>(
+          srcOp.getLoc(), adaptor.getInputs().back(), recipEmptyOp);
+
+      rewriter.replaceOpWithNewOp<ttnn::MultiplyOp>(
+          srcOp, adaptor.getInputs().front(), recipOp.getResults().front(),
+          adaptor.getOutputs().front());
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+// ANCHOR: adding_an_op_matmul_op_rewriter
+class MatmulOpConversionPattern : public OpConversionPattern<ttir::MatmulOp> {
+public:
+  using OpConversionPattern<ttir::MatmulOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::MatmulOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<ttnn::MatmulOp>(
+        op, this->getTypeConverter()->convertType(op.getType()), adaptor.getA(),
+        adaptor.getB(), adaptor.getOutput());
+    return success();
+  }
+};
+// ANCHOR_END: adding_an_op_matmul_op_rewriter
+
 namespace mlir::tt {
 
 void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
Expand Down Expand Up @@ -675,7 +712,6 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::TypecastOp, ttnn::TypecastOp>,
ElementwiseOpConversionPattern<ttir::ReciprocalOp, ttnn::ReciprocalOp>,
ElementwiseOpConversionPattern<ttir::ExpOp, ttnn::ExpOp>,
ElementwiseOpConversionPattern<ttir::DivOp, ttnn::DivOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
@@ -690,7 +726,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ConstantOpConversionPattern,
            MatmulOpConversionPattern,
            Conv2dOpConversionPattern,
-           MaxPool2dOpConversionPattern
+           MaxPool2dOpConversionPattern,
+           DivOpConversionPattern
           >(typeConverter, ctx);
   // ANCHOR_END: op_rewriter_pattern_set
   // clang-format on
13 changes: 13 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_div_with_broadcast.mlir
@@ -0,0 +1,13 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<1x128xf32>) -> tensor<64x128xf32> {
+    // CHECK: %[[MULTIPLY_OUT:[0-9]+]] = "ttnn.empty"{{.+}} -> tensor<64x128xf32, {{.+}}
+    %0 = tensor.empty() : tensor<64x128xf32>
+    // CHECK: %[[RECIP_OUT:[0-9]+]] = "ttnn.empty"{{.+}} -> tensor<1x128xf32, {{.+}}
+    // CHECK: %[[RECIP_RESULT:[0-9]+]] = "ttnn.reciprocal"(%{{[0-9]+}}, %[[RECIP_OUT]]){{.+}} -> tensor<1x128xf32, {{.+}}
+    // CHECK: %{{[0-9]+}} = "ttnn.multiply"(%{{[0-9]+}}, %[[RECIP_RESULT]], %[[MULTIPLY_OUT]]){{.+}} -> tensor<64x128xf32, {{.+}}
+    %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<1x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    return %1 : tensor<64x128xf32>
+  }
+}
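
The RUN line above doubles as a local repro. Assuming ttmlir-opt and FileCheck are on your PATH (lit substitutes %s with the test file's path), something like the following should exercise the new pattern:

    ttmlir-opt --ttir-to-ttnn-backend-pipeline test/ttmlir/Dialect/TTNN/simple_div_with_broadcast.mlir | FileCheck test/ttmlir/Dialect/TTNN/simple_div_with_broadcast.mlir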
