Commit ced0ad5

[Task]: Use op.dtype to create EmptyMemoryFormat during decomposition.
Parent: fd65a66

File tree: 5 files changed, +131 −20 lines

include/torch-mlir/Dialect/Torch/Utils/Utils.h

Lines changed: 3 additions & 0 deletions

@@ -36,6 +36,9 @@ Type getTypeForTorchType(
     MLIRContext *context, Type type,
     mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
 
+template <typename OpTy>
+FailureOr<Value> getDtypeFromOp(PatternRewriter &rewriter, OpTy op);
+
 FailureOr<Type> getTorchTypeForScalarType(MLIRContext *context,
                                           torch_upstream::ScalarType dtypeInt);
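
The new helper gives decomposition patterns a single way to materialize a concrete dtype value: use the op's dtype operand when present, otherwise fall back to the dtype carried by the op's tensor type. A minimal sketch of the intended call pattern (mirroring the call sites updated in DecomposeComplexOps.cpp below; the surrounding pattern boilerplate is illustrative):

    // Inside some matchAndRewrite(OpTy op, PatternRewriter &rewriter):
    // Resolve the dtype operand, falling back to the tensor type's dtype.
    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
    if (failed(dtype))
      return rewriter.notifyMatchFailure(
          op, "could not determine dtype from the op.");
    // *dtype now holds a concrete !torch.int dtype value, never `none`.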

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 32 additions & 20 deletions

@@ -7065,9 +7065,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
         Torch::ListType::get(Torch::IntType::get(op.getContext()));
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
+
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
+
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
+        op, op.getType(), sizeList, *dtype, op.getLayout(), op.getDevice(),
+        op.getPinMemory(), op.getMemoryFormat());
     return success();
   }
 };

@@ -7816,18 +7823,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
   LogicalResult matchAndRewrite(AtenNewEmptyOp op,
                                 PatternRewriter &rewriter) const override {
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
-    Value dtype = op.getDtype();
-    if (isa<Torch::NoneType>(dtype.getType())) {
-      BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
-      if (!tensorType.hasDtype()) {
-        return rewriter.notifyMatchFailure(
-            op, "expected input tensor to have a dtype");
-      }
-      dtype =
-          getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
     }
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(),
+        op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
         op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
   }

@@ -9286,12 +9288,12 @@ class DecomposeAtenRandnGeneratorOp
     Location loc = op.getLoc();
     auto resultType = cast<BaseTensorType>(op.getType());
 
-    if (!resultType.hasDtype()) {
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
       return rewriter.notifyMatchFailure(
-          op, "expected result type to have a dtype");
+          op, "could not determine dtype from the op.");
     }
 
-    Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
     Value none = rewriter.create<ConstantNoneOp>(loc);
     Value low = rewriter.create<Torch::ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)0.0));

@@ -9303,12 +9305,12 @@ class DecomposeAtenRandnGeneratorOp
         loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));
 
     Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
     Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);

@@ -9406,8 +9408,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
         loc, rewriter.getF64FloatAttr((double)0.0));
     Value high = rewriter.create<Torch::ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)1.0));
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
     Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/noneVal);

@@ -9565,9 +9572,14 @@ class DecomposeAtenEmptyStridedOp
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
 
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
+        op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
+        op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
   }
 };
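
Taken together, the five patterns now share one code path: instead of forwarding op.getDtype() (which may be the `none` value) directly into aten.empty.memory_format, each pattern resolves a concrete dtype through getDtypeFromOp and emits a match failure when neither the dtype operand nor the tensor type supplies one.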

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 30 additions & 0 deletions

@@ -237,6 +237,36 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                      rewriter.getI64IntegerAttr(intType));
 }
 
+template <typename OpTy>
+FailureOr<Value> Torch::getDtypeFromOp(PatternRewriter &rewriter, OpTy op) {
+  Value dtype = op.getDtype();
+  if (isa<Torch::NoneType>(dtype.getType())) {
+    BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
+    if (!tensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input tensor to have a dtype");
+    }
+    dtype =
+        getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
+  }
+  return dtype;
+}
+// Template instantiations.
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenEmptyLikeOp>(PatternRewriter &rewriter,
+                                       AtenEmptyLikeOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenNewEmptyOp>(PatternRewriter &rewriter,
+                                      AtenNewEmptyOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenRandOp>(PatternRewriter &rewriter, AtenRandOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenEmptyStridedOp>(PatternRewriter &rewriter,
+                                          AtenEmptyStridedOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenRandnGeneratorOp>(PatternRewriter &rewriter,
+                                            AtenRandnGeneratorOp op);
+
 // Helper to convert a tensor to a specific scalar type.
 Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
                                   Value input, Type dtype) {
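
Because the template is defined here in Utils.cpp rather than in the header, only the op types explicitly instantiated above can link against getDtypeFromOp; OpTy must also expose a getDtype() accessor and have a tensor result type. Supporting another op later would take one more instantiation along these lines (the op named here is illustrative, not part of this commit):

    // Hypothetical future instantiation for another dtype-carrying op.
    template FailureOr<Value>
    Torch::getDtypeFromOp<AtenNewEmptyStridedOp>(PatternRewriter &rewriter,
                                                 AtenNewEmptyStridedOp op);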

projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py

Lines changed: 20 additions & 0 deletions

@@ -641,6 +641,26 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
 
+class EmptyLikeDefaultDtypeFloat64InputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float64, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.empty_like(x).fill_(0)
+
+
+@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeFloat64InputModule())
+def EmptyLikeDefaultDtypeFloat64InputModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones((200, 200, 26), dtype=torch.float64))
+
+
 # ==============================================================================
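
The new e2e test pins down the eager-mode behavior the decomposition must match: with no explicit dtype argument, empty_like inherits the input's dtype rather than the global default. A standalone sanity check in plain PyTorch (not part of the commit):

    import torch

    x = torch.ones((200, 200, 26), dtype=torch.float64)
    # With dtype=None, empty_like takes its dtype from the input tensor.
    y = torch.empty_like(x).fill_(0)
    assert y.dtype == torch.float64  # not the float32 default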

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 46 additions & 0 deletions

@@ -278,3 +278,49 @@ func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
   torch.aten._assert_scalar %3, %str : !torch.int, !torch.str
   return %arg0 : !torch.int
 }
+// CHECK-LABEL: func.func @emptyLikeNoneDtype(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[C200:.*]] = torch.constant.int 200
+// CHECK: %[[C26:.*]] = torch.constant.int 26
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+func.func @emptyLikeNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+  %none = torch.constant.none
+  %none_0 = torch.constant.none
+  %none_1 = torch.constant.none
+  %false = torch.constant.bool false
+  %none_2 = torch.constant.none
+  %0 = torch.aten.empty_like %arg0, %none, %none_0, %none_1, %false, %none_2 : !torch.vtensor<[200,200,26],f64>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+  return %0 : !torch.vtensor<[200,200,26],f64>
+}
+
+// -----
+// CHECK-LABEL: func.func @randNoneDtype(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
+// CHECK: %[[C1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[C0:.*]] = torch.constant.float 0.000000e+00
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[C200:.*]] = torch.constant.int 200
+// CHECK: %[[C26:.*]] = torch.constant.int 26
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
+// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[CPU]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+// CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[MEM_FMT]], %[[C0]], %[[C1]], %[[NONE]] : !torch.vtensor<[200,200,26],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[200,200,26],f64>
+// CHECK: return %[[UNIFORM]] : !torch.vtensor<[200,200,26],f64>
+func.func @randNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+  %int200 = torch.constant.int 200
+  %int200_0 = torch.constant.int 200
+  %int26 = torch.constant.int 26
+  %0 = torch.prim.ListConstruct %int200, %int200_0, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %none_1 = torch.constant.none
+  %cpu = torch.constant.device "cpu"
+  %false = torch.constant.bool false
+  %1 = torch.aten.rand %0, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[200,200,26],f64>
+  return %1 : !torch.vtensor<[200,200,26],f64>
+}
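
In both tests the decomposition must materialize torch.constant.int 7 for the dtype operand: 7 is the ScalarType code for float64 (mirroring c10::ScalarType, where float32 is 6 and float64 is 7), so the CHECK lines confirm that the dtype is now derived from the f64 result type instead of being forwarded as `none` into torch.aten.empty.memory_format.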
