From 42719aa6274cd333ab8bc123ebc2c93b70bfcdf3 Mon Sep 17 00:00:00 2001 From: SaeHie Park Date: Wed, 26 Feb 2025 09:54:15 +0900 Subject: [PATCH] [circle-mlir/dialect] Enable AddOp IR This will enable AddOp IR in dialect. ONE-DCO-1.0-Signed-off-by: SaeHie Park --- .../circle-mlir/lib/dialect/mlir/CircleOps.td | 38 +++++++++++ .../lib/dialect/src/CircleDialect.cpp | 2 +- .../lib/dialect/src/ShapeInference.cpp | 17 ++++- .../circle-mlir/lib/dialect/src/ops/AddOp.h | 66 +++++++++++++++++++ 4 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 circle-mlir/circle-mlir/lib/dialect/src/ops/AddOp.h diff --git a/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td b/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td index 1c0989feec7..d2f07f977ca 100644 --- a/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td +++ b/circle-mlir/circle-mlir/lib/dialect/mlir/CircleOps.td @@ -296,6 +296,44 @@ class CIR_ConvOp($_op))">>, + ResultsBroadcastableShape, + DeclareOpInterfaceMethods, + Pure, + Commutative, + // TODO enable QuantizableResult, + ]> { + let summary = "Addition operator"; + + let description = [{ + Element-wise addition operation. + }]; + + let arguments = ( + // TODO add more dtypes + ins CIR_TensorOf<[F32, I32, I64]>:$lhs, + CIR_TensorOf<[F32, I32, I64]>:$rhs, + CIR_AFAttr:$fused_activation_function); + + let results = (outs CIR_TensorOf<[F32, I32, I64]>:$output); + + let hasFolder = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDefinition = [{ + ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { + return parseOneResultSameOperandTypeOp(parser, result); + } + void $cppClass::print(OpAsmPrinter &p) { + return printOneResultOp(getOperation(), p); + } + }]; + + let hasOptions = 1; +} def CIR_ConstOp : Op bool inferBinShapes(BINOP &op, SmallVector } // namespace -// TODO add AddOp +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::inferShapes() +{ + AddOp op = *this; + SmallVector inferred; + if (!inferBinShapes(op, inferred)) + return; + + auto input0_op = getOperand(0); + auto input0_type = input0_op.getType().cast(); + RankedTensorType inferred_type = RankedTensorType::get(inferred, input0_type.getElementType()); + getResult().setType(inferred_type); +} //===----------------------------------------------------------------------===// // CustomOp diff --git a/circle-mlir/circle-mlir/lib/dialect/src/ops/AddOp.h b/circle-mlir/circle-mlir/lib/dialect/src/ops/AddOp.h new file mode 100644 index 00000000000..89ea40fe1c9 --- /dev/null +++ b/circle-mlir/circle-mlir/lib/dialect/src/ops/AddOp.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// from tensorflow/compiler/mlir/lite/ir/tfl_ops.cc + +#ifndef __CIRCLE_MLIR_DIALECT_OPS_ADD_OP_H__ +#define __CIRCLE_MLIR_DIALECT_OPS_ADD_OP_H__ + +#include "circle-mlir/dialect/CircleDialect.h" + +namespace mlir +{ +namespace Circle +{ + +// Return true if the given Add operation has the CPU kernel supported shapes. +bool VerifyAddOpShapeConstraints(AddOp op) +{ + auto element_type = getElementTypeOrSelf(op.getOutput().getType()); + + // Allows F32 and I32 outputs when the operands have valid shapes, + // which are broadcastable shapes up to four dimensions or have same shapes. + // TODO support Quantized Type + if (element_type.isF32() || IsI32Type(element_type) || IsI64Type(element_type)) + { + return VerifyOperandsHaveSameShapesOrBroadcastableShape( + /*op=*/op.getOperation(), /*indices=*/ArrayRef{0, 1}, + /*max_bcast_rank=*/4); + } + + return false; +} + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddOp::fold(FoldAdaptor adaptor) +{ + auto operands = adaptor.getOperands(); + // TODO(b/142478136): Handle fused ops. + if (getFusedActivationFunction() != "NONE") + return {}; + return ConstFoldBinaryOp( + getType(), operands, [](APFloat a, APFloat b) { return a + b; }, + [](APInt a, APInt b) { return a + b; }); +} + +} // namespace Circle +} // namespace mlir + +#endif // __CIRCLE_MLIR_DIALECT_OPS_ADD_OP_H__