@@ -7065,9 +7065,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
70657065 Torch::ListType::get (Torch::IntType::get (op.getContext ()));
70667066 Value sizeList =
70677067 rewriter.create <AtenSizeOp>(op.getLoc (), sizeListType, op.getSelf ());
7068+
7069+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
7070+ if (failed (dtype)) {
7071+ return rewriter.notifyMatchFailure (
7072+ op, " could not determine dtype from the op." );
7073+ }
7074+
70687075 rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
7069- op, op.getType (), sizeList, op.getDtype (), op.getLayout (),
7070- op.getDevice (), op. getPinMemory (), op.getMemoryFormat ());
7076+ op, op.getType (), sizeList, *dtype, op.getLayout (), op.getDevice (),
7077+ op.getPinMemory (), op.getMemoryFormat ());
70717078 return success ();
70727079 }
70737080};
@@ -7816,18 +7823,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
78167823 LogicalResult matchAndRewrite (AtenNewEmptyOp op,
78177824 PatternRewriter &rewriter) const override {
78187825 Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
7819- Value dtype = op.getDtype ();
7820- if (isa<Torch::NoneType>(dtype.getType ())) {
7821- BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf ().getType ());
7822- if (!tensorType.hasDtype ()) {
7823- return rewriter.notifyMatchFailure (
7824- op, " expected input tensor to have a dtype" );
7825- }
7826- dtype =
7827- getDtypeIntValueForType (rewriter, op.getLoc (), tensorType.getDtype ());
7826+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
7827+ if (failed (dtype)) {
7828+ return rewriter.notifyMatchFailure (
7829+ op, " could not determine dtype from the op." );
78287830 }
78297831 rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
7830- op, op.getType (), op.getSize (), dtype, op.getLayout (), op.getDevice (),
7832+ op, op.getType (), op.getSize (), * dtype, op.getLayout (), op.getDevice (),
78317833 op.getPinMemory (), /* memoryFormat=*/ noneVal);
78327834 return success ();
78337835 }
@@ -9286,12 +9288,12 @@ class DecomposeAtenRandnGeneratorOp
92869288 Location loc = op.getLoc ();
92879289 auto resultType = cast<BaseTensorType>(op.getType ());
92889290
9289- if (!resultType.hasDtype ()) {
9291+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9292+ if (failed (dtype)) {
92909293 return rewriter.notifyMatchFailure (
9291- op, " expected result type to have a dtype " );
9294+ op, " could not determine dtype from the op. " );
92929295 }
92939296
9294- Value dtype = getDtypeIntValueForType (rewriter, loc, resultType.getDtype ());
92959297 Value none = rewriter.create <ConstantNoneOp>(loc);
92969298 Value low = rewriter.create <Torch::ConstantFloatOp>(
92979299 loc, rewriter.getF64FloatAttr ((double )0.0 ));
@@ -9303,12 +9305,12 @@ class DecomposeAtenRandnGeneratorOp
93039305 loc, rewriter.getF64FloatAttr ((double )(2.0 * 3.14159 )));
93049306
93059307 Value emptyTensorA = rewriter.create <AtenEmptyMemoryFormatOp>(
9306- loc, resultType, op.getSize (), /* dtype=*/ dtype,
9308+ loc, resultType, op.getSize (), /* dtype=*/ * dtype,
93079309 /* layout=*/ op.getLayout (),
93089310 /* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
93099311 /* memory_format=*/ none);
93109312 Value emptyTensorB = rewriter.create <AtenEmptyMemoryFormatOp>(
9311- loc, resultType, op.getSize (), /* dtype=*/ dtype,
9313+ loc, resultType, op.getSize (), /* dtype=*/ * dtype,
93129314 /* layout=*/ op.getLayout (),
93139315 /* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
93149316 /* memory_format=*/ none);
@@ -9406,8 +9408,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
94069408 loc, rewriter.getF64FloatAttr ((double )0.0 ));
94079409 Value high = rewriter.create <Torch::ConstantFloatOp>(
94089410 loc, rewriter.getF64FloatAttr ((double )1.0 ));
9411+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9412+ if (failed (dtype)) {
9413+ return rewriter.notifyMatchFailure (
9414+ op, " could not determine dtype from the op." );
9415+ }
94099416 Value emptyTensor = rewriter.create <AtenEmptyMemoryFormatOp>(
9410- loc, resultType, op.getSize (), /* dtype=*/ op. getDtype () ,
9417+ loc, resultType, op.getSize (), /* dtype=*/ *dtype ,
94119418 /* layout=*/ op.getLayout (),
94129419 /* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
94139420 /* memory_format=*/ noneVal);
@@ -9565,9 +9572,14 @@ class DecomposeAtenEmptyStridedOp
95659572
95669573 Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
95679574
9575+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9576+ if (failed (dtype)) {
9577+ return rewriter.notifyMatchFailure (
9578+ op, " could not determine dtype from the op." );
9579+ }
95689580 rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
9569- op, op.getType (), op.getSize (), op.getDtype (), op.getLayout (),
9570- op.getDevice (), op. getPinMemory (), /* memoryFormat=*/ noneVal);
9581+ op, op.getType (), op.getSize (), *dtype, op.getLayout (), op.getDevice (),
9582+ op.getPinMemory (), /* memoryFormat=*/ noneVal);
95719583 return success ();
95729584 }
95739585};
0 commit comments