diff --git a/include/tpu_mlir/Dialect/Top/IR/TopOps.td b/include/tpu_mlir/Dialect/Top/IR/TopOps.td
index 1f6ead192..c7290798f 100755
--- a/include/tpu_mlir/Dialect/Top/IR/TopOps.td
+++ b/include/tpu_mlir/Dialect/Top/IR/TopOps.td
@@ -650,6 +650,23 @@ def Top_PermuteOp: Top_Op<"Permute"> {
   }];
 }
 
+def Top_TransposeOp: Top_Op<"Transpose"> {
+
+  let summary = "Transpose operator";
+
+  let description = [{
+    Transpose on input.
+  }];
+
+  let arguments = (
+    ins AnyTensor:$input,
+    I64Attr:$dim0,
+    I64Attr:$dim1
+  );
+  let results = (outs AnyTensor:$output);
+  let hasCanonicalizer = 1;
+}
+
 def Top_ShuffleChannelOp: Top_Op<"ShuffleChannel"> {
 
   let summary = "ShuffleChannel operator";
diff --git a/lib/Dialect/Top/Canonicalize/Transpose.cpp b/lib/Dialect/Top/Canonicalize/Transpose.cpp
new file mode 100644
index 000000000..94af94735
--- /dev/null
+++ b/lib/Dialect/Top/Canonicalize/Transpose.cpp
@@ -0,0 +1,73 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
+//
+// TPU-MLIR is licensed under the 2-Clause BSD License except for the
+// third-party components.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tpu_mlir/Dialect/Top/IR/TopOps.h"
+#include "tpu_mlir/Support/MathUtils.h"
+#include "tpu_mlir/Support/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace tpu_mlir::top;
+
+// Rewrite a rank >= 2 Transpose into the equivalent Permute whose order is
+// the identity permutation with dim0 and dim1 swapped.
+struct TransposeToPermutePattern : public OpRewritePattern<TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    const auto &input = op.getInput();
+    auto dim0_ = op.getDim0();
+    auto dim1_ = op.getDim1();
+    auto dims = static_cast<int64_t>(module::getShape(input).size());
+    if (dims < 2)
+      return failure();
+    // Identity order except dim0 <-> dim1.
+    std::vector<int64_t> order;
+    for (int64_t i = 0; i < dims; ++i) {
+      if (dim0_ == i) {
+        order.push_back(dim1_);
+      } else if (dim1_ == i) {
+        order.push_back(dim0_);
+      } else {
+        order.push_back(i);
+      }
+    }
+    // rewrite
+    std::vector<NamedAttribute> attrs;
+    attrs.push_back(
+        rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order)));
+    rewriter.replaceOpWithNewOp<PermuteOp>(
+        op, op.getResult().getType(), ValueRange{input}, attrs);
+    return success();
+  }
+};
+
+// Fold away a Transpose on a rank < 2 input: transposing a scalar or a 1-D
+// tensor is the identity.
+struct TransposeFusePattern : public OpRewritePattern<TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    const auto &input = op.getInput();
+    auto dims = module::getShape(input).size();
+    if (dims >= 2)
+      return failure();
+    // Replace through the rewriter so the driver records the mutation and
+    // erases the dead op (replaceAllUsesWith alone leaves the op behind and
+    // bypasses the rewriter's bookkeeping).
+    rewriter.replaceOp(op, {input});
+    return success();
+  }
+};
+
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.insert<TransposeToPermutePattern, TransposeFusePattern>(context);
+}
diff --git a/lib/Dialect/Top/Interfaces/Transpose.cpp b/lib/Dialect/Top/Interfaces/Transpose.cpp
new file mode 100644
index 000000000..ad96ea74a
--- /dev/null
+++ b/lib/Dialect/Top/Interfaces/Transpose.cpp
@@ -0,0 +1,34 @@
+//===----------------------------------------------------------------------===//
+//
+// Copyright (C) 2022 Sophgo
Technologies Inc. All rights reserved.
+//
+// TPU-MLIR is licensed under the 2-Clause BSD License except for the
+// third-party components.
+//
+//===----------------------------------------------------------------------===//
+
+#include "tpu_mlir/Dialect/Top/IR/TopOps.h"
+#include "tpu_mlir/Support/MathUtils.h"
+#include "tpu_mlir/Support/Module.h"
+
+int64_t top::TransposeOp::getFLOPs() { return 0; }  // pure data movement, no arithmetic
+
+LogicalResult top::TransposeOp::init(InferenceParameter &p) {return success();}
+
+void top::TransposeOp::deinit(InferenceParameter &p) {}
+
+LogicalResult top::TransposeOp::inference(InferenceParameter &p) {
+  llvm_unreachable("Not Implemented");  // canonicalized into Permute (or folded) before inference runs
+}
+
+void top::TransposeOp::shape_inference() {
+  auto dim0_ = getDim0();
+  auto dim1_ = getDim1();
+  auto in_shape = module::getShape(getInput());
+  std::vector<int64_t> out_shape(in_shape.begin(), in_shape.end());
+  if (in_shape.size() >= 2) {  // rank < 2: transpose is the identity, shape unchanged
+    out_shape[dim0_] = in_shape[dim1_];
+    out_shape[dim1_] = in_shape[dim0_];
+  }
+  module::setShapeOrVerify(getOutput(), out_shape);
+}
diff --git a/python/test/test_torch.py b/python/test/test_torch.py
index 06e67b097..b78b5e89b 100755
--- a/python/test/test_torch.py
+++ b/python/test/test_torch.py
@@ -42,6 +42,8 @@ def __init__(self, chip: str = "bm1684x", mode: str = "all"):
             "PRelu": self.test_PRelu,
             "Permute": self.test_Permute,
             "Sub": self.test_Sub,
+            "T": self.test_T,
+            "Transpose": self.test_Transpose,
         }
         self.support_quant_modes = ["f32", "f16", "bf16"]
         #self.support_quant_modes = ["f32", "f16", "bf16", "int8", "int4"]
@@ -448,6 +450,53 @@ def forward(self, x):
         _test_permute((2, 32, 16), (2, 0, 1))
         _test_permute((32, 32), (1, 0))
 
+    #######################################################################
+    # T
+    # ------------
+    def test_T(self, case_name):
+        """T"""
+
+        def _test_t(in_shape):
+
+            class Model(nn.Module):
+
+                def __init__(self):
+                    super(Model, self).__init__()
+
+                def forward(self, x):
+                    x = torch.concat((x, x))
+                    x = torch.t(x)
+                    y1 = torch.concat((x, x))
+                    return y1
+
+            self.convert_torch_and_compare([in_shape], case_name, Model().eval())
+
+        _test_t((32, 32))
+        _test_t((32, ))
+
+    #######################################################################
+    # Transpose
+    # ------------
+    def test_Transpose(self, case_name):
+        """Transpose"""
+
+        def _test_transpose(in_shape, dims):
+
+            class Model(nn.Module):
+
+                def __init__(self):
+                    super(Model, self).__init__()
+
+                def forward(self, x):
+                    y1 = torch.transpose(x, dim0=dims[0], dim1=dims[1])
+                    return y1
+
+            self.convert_torch_and_compare([in_shape], case_name, Model().eval())
+
+        _test_transpose((1, 3, 32, 32), (0, 3))
+        _test_transpose((2, 32, 16), (2, 0))
+        _test_transpose((32, 32), (1, 0))
+
     #######################################################################
     # Concat
     # ------------
diff --git a/python/transform/MLIRImporter.py b/python/transform/MLIRImporter.py
index bcce84961..35df17989 100755
--- a/python/transform/MLIRImporter.py
+++ b/python/transform/MLIRImporter.py
@@ -86,6 +86,7 @@ class Top:
     TileOp = 'top.Tile'
     TanhOp = 'top.Tanh'
     TopKOp = 'top.TopK'
+    TransposeOp = 'top.Transpose'
     UnpackOp = 'top.Unpack'
     UpsampleOp = 'top.Upsample'
     WeightOp = 'top.Weight'
@@ -457,6 +458,16 @@ def create_permute_op(self, operands, output_shape, **kargs):
         }
         return self.buildOp(Top.PermuteOp, operands, [output_type], **param)
 
+    def create_transpose_op(self, operands, output_shape, **kargs):
+        # get_value_type
+        output_type = self.get_tensor_type(output_shape)
+        param = {
+            'name': kargs['name'],
+            'dim0': IntegerAttr.get(self.mlir_type['INT64'], kargs['dim0']),
+            'dim1': IntegerAttr.get(self.mlir_type['INT64'], kargs['dim1']),
+        }
+        return self.buildOp(Top.TransposeOp, operands, [output_type], **param)
+
     def create_matmul_op(self, operands, output_shape, **kargs):
         # get_value_type
         output_type = self.get_tensor_type(output_shape)
diff --git a/python/transform/TorchConverter.py b/python/transform/TorchConverter.py
index fbf53f3e0..30067f842 100644
--- a/python/transform/TorchConverter.py
+++ b/python/transform/TorchConverter.py
@@ -146,6 +146,8 @@ def __init__(self,
             "aten::prelu": lambda node: self.convert_prelu_op(node),
             "aten::permute": lambda node: self.convert_permute_op(node),
             "aten::sub": lambda node: self.convert_sub_op(node),
+            "aten::t": lambda node: self.convert_transpose_op(node),
+            "aten::transpose": lambda node: self.convert_transpose_op(node),
         }
 
         self.check_op_names()
@@ -433,6 +435,19 @@ def convert_permute_op(self, torch_node: TorchNode):
         new_op = self.mlir.create_permute_op([op], None, **p)
         self.addOperand(torch_node.name, new_op)
 
+    def convert_transpose_op(self, torch_node: TorchNode):
+        op = self.getOp(torch_node.inputs[0])
+        no_dims = len(torch_node.inputs) == 1  # aten::t carries no dim operands
+        dim0 = self.const_val[torch_node.inputs[1]] if not no_dims else 0
+        dim1 = self.const_val[torch_node.inputs[2]] if not no_dims else 1
+        p = {
+            'name': torch_node.name,
+            'dim0': dim0,
+            'dim1': dim1,
+        }
+        new_op = self.mlir.create_transpose_op([op], None, **p)
+        self.addOperand(torch_node.name, new_op)
+
     def convert_layer_norm_op(self, torch_node: TorchNode):
         op0 = self.getOp(torch_node.inputs[0])
         scale_opd = self.mlir.none_op