-
Notifications
You must be signed in to change notification settings - Fork 168
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change-Id: Iffc03399a9b21294e3b6f4ae15f83c11881e8966
- Loading branch information
1 parent
bea91c4
commit 2de83ca
Showing
6 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. | ||
// | ||
// TPU-MLIR is licensed under the 2-Clause BSD License except for the | ||
// third-party components. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "tpu_mlir/Dialect/Top/IR/TopOps.h" | ||
#include "tpu_mlir/Support/MathUtils.h" | ||
#include "tpu_mlir/Support/Module.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
|
||
using namespace tpu_mlir::top; | ||
|
||
struct TransposeToPermutePattern : public OpRewritePattern<TransposeOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(TransposeOp op, | ||
PatternRewriter &rewriter) const override { | ||
const auto &input = op.getInput(); | ||
auto dim0_ = op.getDim0(); | ||
auto dim1_ = op.getDim1(); | ||
auto dims = module::getShape(input).size(); | ||
if (dims < 2) | ||
return failure(); | ||
std::vector<int64_t> order; | ||
for (int i = 0; i < dims; ++i) { | ||
if (dim0_ == i) { | ||
order.push_back(dim1_); | ||
} else if (dim1_ == i) { | ||
order.push_back(dim0_); | ||
} else { | ||
order.push_back(i); | ||
} | ||
} | ||
// rewrite | ||
std::vector<NamedAttribute> attrs; | ||
attrs.push_back( | ||
rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order))); | ||
rewriter.replaceOpWithNewOp<PermuteOp>( | ||
op, op.getResult().getType(), ValueRange{input}, attrs); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct TransposeFussPattern : public OpRewritePattern<TransposeOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(TransposeOp op, | ||
PatternRewriter &rewriter) const override { | ||
const auto &input = op.getInput(); | ||
auto dims = module::getShape(input).size(); | ||
if (dims >= 2) | ||
return failure(); | ||
op.getOutput().replaceAllUsesWith(input); | ||
return success(); | ||
} | ||
}; | ||
|
||
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
MLIRContext *context) { | ||
results.insert<TransposeToPermutePattern, TransposeFussPattern>(context); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Copyright (C) 2022 Sophgo Technologies Inc. All rights reserved. | ||
// | ||
// TPU-MLIR is licensed under the 2-Clause BSD License except for the | ||
// third-party components. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "tpu_mlir/Dialect/Top/IR/TopOps.h" | ||
#include "tpu_mlir/Support/MathUtils.h" | ||
#include "tpu_mlir/Support/Module.h" | ||
|
||
int64_t top::TransposeOp::getFLOPs() { return 0; } | ||
|
||
LogicalResult top::TransposeOp::init(InferenceParameter &p) {return success();} | ||
|
||
void top::TransposeOp::deinit(InferenceParameter &p) {} | ||
|
||
LogicalResult top::TransposeOp::inference(InferenceParameter &p) { | ||
llvm_unreachable("Not Implemented"); | ||
} | ||
|
||
void top::TransposeOp::shape_inference() { | ||
auto dim0_ = getDim0(); | ||
auto dim1_ = getDim1(); | ||
auto in_shape = module::getShape(getInput()); | ||
std::vector<int64_t> out_shape(in_shape); | ||
if (in_shape.size() >= 2) { | ||
out_shape[dim0_] = in_shape[dim1_]; | ||
out_shape[dim1_] = in_shape[dim0_]; | ||
} | ||
module::setShapeOrVerify(getOutput(), out_shape); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters