[Torch] add torch transpose & t op
Change-Id: Iffc03399a9b21294e3b6f4ae15f83c11881e8966
Silence-Zhang-beijng committed Mar 8, 2023
1 parent bea91c4 commit 2de83ca
Showing 6 changed files with 192 additions and 0 deletions.
17 changes: 17 additions & 0 deletions include/tpu_mlir/Dialect/Top/IR/TopOps.td
@@ -650,6 +650,23 @@ def Top_PermuteOp: Top_Op<"Permute"> {
}];
}

def Top_TransposeOp: Top_Op<"Transpose"> {

  let summary = "Transpose operator";

  let description = [{
    Swap the two dimensions dim0 and dim1 of the input tensor.
  }];

  let arguments = (ins
    AnyTensor:$input,
    I64Attr:$dim0,
    I64Attr:$dim1
  );
  let results = (outs AnyTensor:$output);
  let hasCanonicalizer = 1;
}
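
For orientation, the op's intended semantics match torch.transpose: the two named axes trade places and every other axis keeps its position. A minimal sketch in PyTorch (assuming a standard torch install):

import torch

x = torch.randn(1, 3, 32, 32)
y = torch.transpose(x, dim0=0, dim1=3)  # swap the first and last axes
assert y.shape == (32, 3, 32, 1)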

def Top_ShuffleChannelOp: Top_Op<"ShuffleChannel"> {
let summary = "ShuffleChannel operator";

66 changes: 66 additions & 0 deletions lib/Dialect/Top/Canonicalize/Transpose.cpp
@@ -0,0 +1,66 @@
//===----------------------------------------------------------------------===//
//
// Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//

#include "tpu_mlir/Dialect/Top/IR/TopOps.h"
#include "tpu_mlir/Support/MathUtils.h"
#include "tpu_mlir/Support/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace tpu_mlir::top;

// Lower a two-axis Transpose to the more general Permute op by materializing
// the full permutation order with dim0 and dim1 exchanged.
struct TransposeToPermutePattern : public OpRewritePattern<TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp op,
                                PatternRewriter &rewriter) const override {
    const auto &input = op.getInput();
    auto dim0_ = op.getDim0();
    auto dim1_ = op.getDim1();
    auto dims = static_cast<int64_t>(module::getShape(input).size());
    if (dims < 2)
      return failure();
    // Build the identity order, then place dim1 at position dim0 and vice
    // versa.
    std::vector<int64_t> order;
    for (int64_t i = 0; i < dims; ++i) {
      if (dim0_ == i) {
        order.push_back(dim1_);
      } else if (dim1_ == i) {
        order.push_back(dim0_);
      } else {
        order.push_back(i);
      }
    }
    // Rewrite as Permute with the computed order attribute.
    std::vector<NamedAttribute> attrs;
    attrs.push_back(
        rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order)));
    rewriter.replaceOpWithNewOp<PermuteOp>(
        op, op.getResult().getType(), ValueRange{input}, attrs);
    return success();
  }
};
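
The lowering hinges on a two-axis transpose being a special case of a full permutation. A small sketch of the same order construction in Python (numpy assumed), mirroring the loop above:

import numpy as np

def transpose_order(rank, dim0, dim1):
    # Identity permutation with dim0 and dim1 exchanged.
    order = list(range(rank))
    order[dim0], order[dim1] = order[dim1], order[dim0]
    return order

x = np.random.rand(2, 32, 16)
order = transpose_order(x.ndim, 0, 2)  # [2, 1, 0]
assert np.array_equal(np.swapaxes(x, 0, 2), x.transpose(order))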

// A Transpose on a tensor with fewer than two dimensions is a no-op, so fold
// it away by forwarding the input to all users.
struct TransposeFussPattern : public OpRewritePattern<TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransposeOp op,
                                PatternRewriter &rewriter) const override {
    const auto &input = op.getInput();
    auto dims = module::getShape(input).size();
    if (dims >= 2)
      return failure();
    // Replace through the rewriter so the dead op is erased and the change
    // is tracked by the pattern driver.
    rewriter.replaceOp(op, input);
    return success();
  }
};

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.insert<TransposeToPermutePattern, TransposeFussPattern>(context);
}
34 changes: 34 additions & 0 deletions lib/Dialect/Top/Interfaces/Transpose.cpp
@@ -0,0 +1,34 @@
//===----------------------------------------------------------------------===//
//
// Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//

#include "tpu_mlir/Dialect/Top/IR/TopOps.h"
#include "tpu_mlir/Support/MathUtils.h"
#include "tpu_mlir/Support/Module.h"

int64_t top::TransposeOp::getFLOPs() { return 0; }

LogicalResult top::TransposeOp::init(InferenceParameter &p) {
  return success();
}

void top::TransposeOp::deinit(InferenceParameter &p) {}

LogicalResult top::TransposeOp::inference(InferenceParameter &p) {
  // Transpose is rewritten to Permute (or folded away) by canonicalization,
  // so inference on the op itself should never be reached.
  llvm_unreachable("Not Implemented");
}

void top::TransposeOp::shape_inference() {
  auto dim0_ = getDim0();
  auto dim1_ = getDim1();
  auto in_shape = module::getShape(getInput());
  std::vector<int64_t> out_shape(in_shape);
  // Swap the extents of the two transposed axes; tensors with fewer than
  // two dimensions pass through unchanged.
  if (in_shape.size() >= 2) {
    out_shape[dim0_] = in_shape[dim1_];
    out_shape[dim1_] = in_shape[dim0_];
  }
  module::setShapeOrVerify(getOutput(), out_shape);
}
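
The shape rule is the same swap applied to the shape vector. A sketch of the equivalent computation in Python:

def transpose_shape(in_shape, dim0, dim1):
    # Output shape equals the input shape with the two axes' extents
    # swapped; tensors of rank < 2 pass through unchanged.
    out_shape = list(in_shape)
    if len(in_shape) >= 2:
        out_shape[dim0], out_shape[dim1] = out_shape[dim1], out_shape[dim0]
    return out_shape

assert transpose_shape([1, 3, 32, 32], 0, 3) == [32, 3, 32, 1]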
49 changes: 49 additions & 0 deletions python/test/test_torch.py
@@ -42,6 +42,8 @@ def __init__(self, chip: str = "bm1684x", mode: str = "all"):
            "PRelu": self.test_PRelu,
            "Permute": self.test_Permute,
            "Sub": self.test_Sub,
            "T": self.test_T,
            "Transpose": self.test_Transpose,
        }
        self.support_quant_modes = ["f32", "f16", "bf16"]
        #self.support_quant_modes = ["f32", "f16", "bf16", "int8", "int4"]
@@ -448,6 +450,53 @@ def forward(self, x):
        _test_permute((2, 32, 16), (2, 0, 1))
        _test_permute((32, 32), (1, 0))

    #######################################################################
    # T
    # ------------
    def test_T(self, case_name):
        """T"""

        def _test_t(in_shape):

            class Model(nn.Module):

                def __init__(self):
                    super(Model, self).__init__()

                def forward(self, x):
                    x = torch.concat((x, x))
                    x = torch.t(x)
                    y1 = torch.concat((x, x))
                    return y1

            self.convert_torch_and_compare([in_shape], case_name, Model().eval())

        _test_t((32, 32))
        _test_t((32, ))

    #######################################################################
    # Transpose
    # ------------
    def test_Transpose(self, case_name):
        """Transpose"""

        def _test_transpose(in_shape, dims):

            class Model(nn.Module):

                def __init__(self):
                    super(Model, self).__init__()

                def forward(self, x):
                    y1 = torch.transpose(x, dim0=dims[0], dim1=dims[1])
                    return y1

            self.convert_torch_and_compare([in_shape], case_name, Model().eval())

        _test_transpose((1, 3, 32, 32), (0, 3))
        _test_transpose((2, 32, 16), (2, 0))
        _test_transpose((32, 32), (1, 0))

    #######################################################################
    # Concat
    # ------------
11 changes: 11 additions & 0 deletions python/transform/MLIRImporter.py
@@ -86,6 +86,7 @@ class Top:
    TileOp = 'top.Tile'
    TanhOp = 'top.Tanh'
    TopKOp = 'top.TopK'
    TransposeOp = 'top.Transpose'
    UnpackOp = 'top.Unpack'
    UpsampleOp = 'top.Upsample'
    WeightOp = 'top.Weight'
@@ -457,6 +458,16 @@ def create_permute_op(self, operands, output_shape, **kargs):
        }
        return self.buildOp(Top.PermuteOp, operands, [output_type], **param)

    def create_transpose_op(self, operands, output_shape, **kargs):
        # get_value_type
        output_type = self.get_tensor_type(output_shape)
        param = {
            'name': kargs['name'],
            'dim0': IntegerAttr.get(self.mlir_type['INT64'], kargs['dim0']),
            'dim1': IntegerAttr.get(self.mlir_type['INT64'], kargs['dim1']),
        }
        return self.buildOp(Top.TransposeOp, operands, [output_type], **param)
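
As a point of reference, the converter below calls this helper with output_shape left as None (shape inference fills it in later). A hypothetical call site, with the operand and name values made up for illustration:

# Hypothetical lowering of a traced torch.transpose(x, 0, 3) node;
# 'op' and the node name come from the converter's bookkeeping.
new_op = self.mlir.create_transpose_op([op], None, name='x.1', dim0=0, dim1=3)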

def create_matmul_op(self, operands, output_shape, **kargs):
# get_value_type
output_type = self.get_tensor_type(output_shape)
15 changes: 15 additions & 0 deletions python/transform/TorchConverter.py
@@ -146,6 +146,8 @@ def __init__(self,
            "aten::prelu": lambda node: self.convert_prelu_op(node),
            "aten::permute": lambda node: self.convert_permute_op(node),
            "aten::sub": lambda node: self.convert_sub_op(node),
            "aten::t": lambda node: self.convert_transpose_op(node),
            "aten::transpose": lambda node: self.convert_transpose_op(node),
        }
        self.check_op_names()

@@ -433,6 +435,19 @@ def convert_permute_op(self, torch_node: TorchNode):
        new_op = self.mlir.create_permute_op([op], None, **p)
        self.addOperand(torch_node.name, new_op)

    def convert_transpose_op(self, torch_node: TorchNode):
        op = self.getOp(torch_node.inputs[0])
        # aten::t carries no dim arguments; it is equivalent to
        # transpose(0, 1), so default the axes accordingly.
        no_dims = len(torch_node.inputs) == 1
        dim0 = self.const_val[torch_node.inputs[1]] if not no_dims else 0
        dim1 = self.const_val[torch_node.inputs[2]] if not no_dims else 1
        p = {
            'name': torch_node.name,
            'dim0': dim0,
            'dim1': dim1,
        }
        new_op = self.mlir.create_transpose_op([op], None, **p)
        self.addOperand(torch_node.name, new_op)
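
Both aten::t and aten::transpose route here because torch.t(x) is defined as a transpose over the first two axes for 2-D inputs, and a no-op below that. A quick check of the equivalence (PyTorch assumed):

import torch

x = torch.randn(32, 32)
assert torch.equal(torch.t(x), torch.transpose(x, 0, 1))
# For 1-D inputs torch.t is the identity, matching the dims < 2 fold
# performed by the canonicalizer.
v = torch.randn(32)
assert torch.equal(torch.t(v), v)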

    def convert_layer_norm_op(self, torch_node: TorchNode):
        op0 = self.getOp(torch_node.inputs[0])
        scale_opd = self.mlir.none_op
