[Task]: Use op.dtype to create EmptyMemoryFormat during decomposition.
sahas3 committed Feb 5, 2025
1 parent fd65a66 commit ced0ad5
Showing 5 changed files with 131 additions and 20 deletions.
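
In short, this commit adds a Torch::getDtypeFromOp helper to the Torch utilities. It returns the op's explicit dtype operand, or, when that operand is none, materializes a dtype integer constant from the op's result tensor type. The decompositions that build AtenEmptyMemoryFormatOp (empty_like, new_empty, rand, randn.generator, and empty.strided) now all obtain their dtype through this helper instead of forwarding op.getDtype() (or an op-specific fallback) directly. The call pattern at each rewrite site, condensed here from the hunks below rather than being additional code in the commit, is:

    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
    if (failed(dtype)) {
      return rewriter.notifyMatchFailure(
          op, "could not determine dtype from the op.");
    }
    // *dtype is a !torch.int dtype value that AtenEmptyMemoryFormatOp accepts.
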
3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -36,6 +36,9 @@ Type getTypeForTorchType(
MLIRContext *context, Type type,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);

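// Returns the op's `dtype` operand when one is provided; otherwise derives a
// dtype integer constant from the op's result tensor type, and fails if that
// result type carries no dtype.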
template <typename OpTy>
FailureOr<Value> getDtypeFromOp(PatternRewriter &rewriter, OpTy op);

FailureOr<Type> getTorchTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt);

52 changes: 32 additions & 20 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7065,9 +7065,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());

FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}

rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
op, op.getType(), sizeList, *dtype, op.getLayout(), op.getDevice(),
op.getPinMemory(), op.getMemoryFormat());
return success();
}
};
@@ -7816,18 +7823,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
LogicalResult matchAndRewrite(AtenNewEmptyOp op,
PatternRewriter &rewriter) const override {
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value dtype = op.getDtype();
if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a dtype");
}
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(),
op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
op.getPinMemory(), /*memoryFormat=*/noneVal);
return success();
}
@@ -9286,12 +9288,12 @@ class DecomposeAtenRandnGeneratorOp
Location loc = op.getLoc();
auto resultType = cast<BaseTensorType>(op.getType());

if (!resultType.hasDtype()) {
FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
op, "could not determine dtype from the op.");
}

Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
Value none = rewriter.create<ConstantNoneOp>(loc);
Value low = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)0.0));
@@ -9303,12 +9305,12 @@ class DecomposeAtenRandnGeneratorOp
loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));

Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/dtype,
loc, resultType, op.getSize(), /*dtype=*/*dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);
Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/dtype,
loc, resultType, op.getSize(), /*dtype=*/*dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);
@@ -9406,8 +9408,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
loc, rewriter.getF64FloatAttr((double)0.0));
Value high = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)1.0));
FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
loc, resultType, op.getSize(), /*dtype=*/*dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/noneVal);
@@ -9565,9 +9572,14 @@ class DecomposeAtenEmptyStridedOp

Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());

FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
op.getPinMemory(), /*memoryFormat=*/noneVal);
return success();
}
};
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -237,6 +237,36 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
rewriter.getI64IntegerAttr(intType));
}

template <typename OpTy>
FailureOr<Value> Torch::getDtypeFromOp(PatternRewriter &rewriter, OpTy op) {
Value dtype = op.getDtype();
if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a dtype");
}
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
return dtype;
}
// Explicit template instantiations for the op types that call getDtypeFromOp.
template FailureOr<Value>
Torch::getDtypeFromOp<AtenEmptyLikeOp>(PatternRewriter &rewriter,
AtenEmptyLikeOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenNewEmptyOp>(PatternRewriter &rewriter,
AtenNewEmptyOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenRandOp>(PatternRewriter &rewriter, AtenRandOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenEmptyStridedOp>(PatternRewriter &rewriter,
AtenEmptyStridedOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenRandnGeneratorOp>(PatternRewriter &rewriter,
AtenRandnGeneratorOp op);

// Helper to convert a tensor to a specific scalar type.
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
Value input, Type dtype) {
@@ -641,6 +641,26 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))


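# With no dtype argument, empty_like should inherit float64 from the input;
# fill_(0) overwrites the uninitialized memory so the result can be compared
# against eager mode deterministically.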
class EmptyLikeDefaultDtypeFloat64InputModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float64, True),
]
)
def forward(self, x):
return torch.empty_like(x).fill_(0)


@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeFloat64InputModule())
def EmptyLikeDefaultDtypeFloat64InputModule_basic(module, tu: TestUtils):
module.forward(torch.ones((200, 200, 26), dtype=torch.float64))


# ==============================================================================


46 changes: 46 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -278,3 +278,49 @@ func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
torch.aten._assert_scalar %3, %str : !torch.int, !torch.str
return %arg0 : !torch.int
}
// CHECK-LABEL: func.func @emptyLikeNoneDtype(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[C200:.*]] = torch.constant.int 200
// CHECK: %[[C26:.*]] = torch.constant.int 26
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
func.func @emptyLikeNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
%none = torch.constant.none
%none_0 = torch.constant.none
%none_1 = torch.constant.none
%false = torch.constant.bool false
%none_2 = torch.constant.none
%0 = torch.aten.empty_like %arg0, %none, %none_0, %none_1, %false, %none_2 : !torch.vtensor<[200,200,26],f64>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
return %0 : !torch.vtensor<[200,200,26],f64>
}

// -----
// CHECK-LABEL: func.func @randNoneDtype(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
// CHECK: %[[C1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[C0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[C200:.*]] = torch.constant.int 200
// CHECK: %[[C26:.*]] = torch.constant.int 26
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[CPU]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
// CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[MEM_FMT]], %[[C0]], %[[C1]], %[[NONE]] : !torch.vtensor<[200,200,26],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[200,200,26],f64>
// CHECK: return %[[UNIFORM]] : !torch.vtensor<[200,200,26],f64>
func.func @randNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
%int200 = torch.constant.int 200
%int200_0 = torch.constant.int 200
%int26 = torch.constant.int 26
%0 = torch.prim.ListConstruct %int200, %int200_0, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%none = torch.constant.none
%none_1 = torch.constant.none
%cpu = torch.constant.device "cpu"
%false = torch.constant.bool false
%1 = torch.aten.rand %0, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[200,200,26],f64>
return %1 : !torch.vtensor<[200,200,26],f64>
}
