@@ -7065,9 +7065,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
7065
7065
Torch::ListType::get (Torch::IntType::get (op.getContext ()));
7066
7066
Value sizeList =
7067
7067
rewriter.create <AtenSizeOp>(op.getLoc (), sizeListType, op.getSelf ());
7068
+
7069
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
7070
+ if (failed (dtype)) {
7071
+ return rewriter.notifyMatchFailure (
7072
+ op, " could not determine dtype from the op." );
7073
+ }
7074
+
7068
7075
rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
7069
- op, op.getType (), sizeList, op.getDtype (), op.getLayout (),
7070
- op.getDevice (), op. getPinMemory (), op.getMemoryFormat ());
7076
+ op, op.getType (), sizeList, *dtype, op.getLayout (), op.getDevice (),
7077
+ op.getPinMemory (), op.getMemoryFormat ());
7071
7078
return success ();
7072
7079
}
7073
7080
};
@@ -7816,18 +7823,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
7816
7823
LogicalResult matchAndRewrite (AtenNewEmptyOp op,
7817
7824
PatternRewriter &rewriter) const override {
7818
7825
Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
7819
- Value dtype = op.getDtype ();
7820
- if (isa<Torch::NoneType>(dtype.getType ())) {
7821
- BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf ().getType ());
7822
- if (!tensorType.hasDtype ()) {
7823
- return rewriter.notifyMatchFailure (
7824
- op, " expected input tensor to have a dtype" );
7825
- }
7826
- dtype =
7827
- getDtypeIntValueForType (rewriter, op.getLoc (), tensorType.getDtype ());
7826
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
7827
+ if (failed (dtype)) {
7828
+ return rewriter.notifyMatchFailure (
7829
+ op, " could not determine dtype from the op." );
7828
7830
}
7829
7831
rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
7830
- op, op.getType (), op.getSize (), dtype, op.getLayout (), op.getDevice (),
7832
+ op, op.getType (), op.getSize (), * dtype, op.getLayout (), op.getDevice (),
7831
7833
op.getPinMemory (), /* memoryFormat=*/ noneVal);
7832
7834
return success ();
7833
7835
}
@@ -9286,12 +9288,12 @@ class DecomposeAtenRandnGeneratorOp
9286
9288
Location loc = op.getLoc ();
9287
9289
auto resultType = cast<BaseTensorType>(op.getType ());
9288
9290
9289
- if (!resultType.hasDtype ()) {
9291
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9292
+ if (failed (dtype)) {
9290
9293
return rewriter.notifyMatchFailure (
9291
- op, " expected result type to have a dtype " );
9294
+ op, " could not determine dtype from the op. " );
9292
9295
}
9293
9296
9294
- Value dtype = getDtypeIntValueForType (rewriter, loc, resultType.getDtype ());
9295
9297
Value none = rewriter.create <ConstantNoneOp>(loc);
9296
9298
Value low = rewriter.create <Torch::ConstantFloatOp>(
9297
9299
loc, rewriter.getF64FloatAttr ((double )0.0 ));
@@ -9303,12 +9305,12 @@ class DecomposeAtenRandnGeneratorOp
9303
9305
loc, rewriter.getF64FloatAttr ((double )(2.0 * 3.14159 )));
9304
9306
9305
9307
Value emptyTensorA = rewriter.create <AtenEmptyMemoryFormatOp>(
9306
- loc, resultType, op.getSize (), /* dtype=*/ dtype,
9308
+ loc, resultType, op.getSize (), /* dtype=*/ * dtype,
9307
9309
/* layout=*/ op.getLayout (),
9308
9310
/* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
9309
9311
/* memory_format=*/ none);
9310
9312
Value emptyTensorB = rewriter.create <AtenEmptyMemoryFormatOp>(
9311
- loc, resultType, op.getSize (), /* dtype=*/ dtype,
9313
+ loc, resultType, op.getSize (), /* dtype=*/ * dtype,
9312
9314
/* layout=*/ op.getLayout (),
9313
9315
/* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
9314
9316
/* memory_format=*/ none);
@@ -9406,8 +9408,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
9406
9408
loc, rewriter.getF64FloatAttr ((double )0.0 ));
9407
9409
Value high = rewriter.create <Torch::ConstantFloatOp>(
9408
9410
loc, rewriter.getF64FloatAttr ((double )1.0 ));
9411
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9412
+ if (failed (dtype)) {
9413
+ return rewriter.notifyMatchFailure (
9414
+ op, " could not determine dtype from the op." );
9415
+ }
9409
9416
Value emptyTensor = rewriter.create <AtenEmptyMemoryFormatOp>(
9410
- loc, resultType, op.getSize (), /* dtype=*/ op. getDtype () ,
9417
+ loc, resultType, op.getSize (), /* dtype=*/ *dtype ,
9411
9418
/* layout=*/ op.getLayout (),
9412
9419
/* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
9413
9420
/* memory_format=*/ noneVal);
@@ -9565,9 +9572,14 @@ class DecomposeAtenEmptyStridedOp
9565
9572
9566
9573
Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
9567
9574
9575
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9576
+ if (failed (dtype)) {
9577
+ return rewriter.notifyMatchFailure (
9578
+ op, " could not determine dtype from the op." );
9579
+ }
9568
9580
rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
9569
- op, op.getType (), op.getSize (), op.getDtype (), op.getLayout (),
9570
- op.getDevice (), op. getPinMemory (), /* memoryFormat=*/ noneVal);
9581
+ op, op.getType (), op.getSize (), *dtype, op.getLayout (), op.getDevice (),
9582
+ op.getPinMemory (), /* memoryFormat=*/ noneVal);
9571
9583
return success ();
9572
9584
}
9573
9585
};
0 commit comments