Skip to content

Commit 71af168

Browse files
committed
[AutoBump] Merge with fixes of 481da8d (Jan 22)
2 parents a0fedd1 + 481da8d commit 71af168

File tree

4 files changed

+87
-12
lines changed

4 files changed

+87
-12
lines changed

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
308308
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
309309
Value src, Type destType, Value &result) {
310310

311-
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
311+
TensorType srcType = dyn_cast<TensorType>(src.getType());
312+
Type srcElemTy = srcType.getElementType();
312313
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
313314

314315
// Temporarily disable checkValidityOfCast as it's currently strictly
@@ -370,6 +371,23 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
370371
result = rewriter.create<tosa::LogicalNotOp>(op->getLoc(), destType,
371372
equalToZero);
372373
} else {
374+
if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger()) {
375+
// for float->int conversion, tosa.cast performs round-to-nearest
376+
// torch performs round-to-zero instead
377+
// generate round-to-zero conversion prior to tosa.cast to match with
378+
// expected torch behavior
379+
auto floor = rewriter.create<tosa::FloorOp>(op->getLoc(), srcType, src);
380+
auto ceil = rewriter.create<tosa::CeilOp>(op->getLoc(), srcType, src);
381+
382+
auto zeroValue =
383+
tosa::getConstTensor<float>(rewriter, op, 0, {}, srcElemTy).value();
384+
385+
auto boolType = srcType.clone(rewriter.getIntegerType(1));
386+
auto isNegative = tosa::CreateOpAndInfer<tosa::GreaterOp>(
387+
rewriter, op->getLoc(), boolType, zeroValue, src);
388+
src = tosa::CreateOpAndInfer<tosa::SelectOp>(
389+
rewriter, op->getLoc(), srcType, isNegative, ceil, floor);
390+
}
373391
result = rewriter.create<tosa::CastOp>(op->getLoc(), destType, src);
374392
}
375393
return success();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,8 @@
18361836
"TriuIndicesNegativeOffsetModule_basic",
18371837
"BmmFloat16Module_basic",
18381838
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
1839+
"LinspaceDtypeModule_basic",
1840+
"Aten_CastLongModule_basic",
18391841
"Unfold_Module_Rank_4",
18401842
"Unfold_Module_Rank_Zero_basic",
18411843
"Unfold_Module_basic",
@@ -2743,6 +2745,7 @@
27432745
}
27442746

27452747
ONNX_XFAIL_SET = {
2748+
"ToDtypeIntFromFloatModule_basic",
27462749
# This test is expected to time out
27472750
"TimeOutModule_basic",
27482751
# Failure - cast error
@@ -3456,6 +3459,7 @@
34563459
}
34573460

34583461
FX_IMPORTER_TOSA_XFAIL_SET = {
3462+
"ScatterAddDynamicModule_basic",
34593463
"UniformModule_basic",
34603464
"UniformStaticShapeModule_basic",
34613465
"AtenFftRfft2DLastDim_basic",
@@ -3560,7 +3564,6 @@
35603564
"AtenSubFloatModule_basic",
35613565
"AtenTopKModule_basic",
35623566
"AtenTopKSmallestModule_basic",
3563-
"Aten_CastLongModule_basic",
35643567
"Aten_EmbeddingBagExample_basic",
35653568
"AvgPool1dFloatModule_basic",
35663569
"AvgPool1dIntModule_basic",
@@ -3633,7 +3636,6 @@
36333636
"ConvolutionModule2DTransposeStridedStatic_basic",
36343637
"ConvolutionModule2DTransposeStrided_basic",
36353638
"ConvolutionModule2DTranspose_basic",
3636-
"CopyWithDifferentDTypesModule_basic",
36373639
"CumsumModule_basic",
36383640
"CumprodModule_basic",
36393641
"CumprodInputDtypeInt32Module_basic",
@@ -3679,7 +3681,6 @@
36793681
"ElementwiseQuantizePerTensorUIntModule_basic",
36803682
"ElementwiseSinhIntModule_basic",
36813683
"ElementwiseSinhModule_basic",
3682-
"ElementwiseToDtypeF32ToI64Module_basic",
36833684
"ElementwiseToDtypeI64ToUI8Module_basic",
36843685
"ElementwiseSignbitModule_basic",
36853686
"EmbeddingModule1DIndices_basic",
@@ -3715,8 +3716,6 @@
37153716
"IndexPutImpl2DNoneIndexStaticModule_basic",
37163717
"IndexPutImpl3DFloatAccumulateModule_basic",
37173718
"IndexPutImplIndexWithNoneModule_basic",
3718-
"InterpolateDynamicModule_sizes_bilinear",
3719-
"InterpolateDynamicModule_scales_recompute_bilinear",
37203719
"IntFloatModule_basic",
37213720
"IntImplicitModule_basic",
37223721
"IsFloatingPointFloat_True",
@@ -3728,7 +3727,6 @@
37283727
"LenStrModule_basic",
37293728
"LinalgNormKeepDimComplexModule_basic",
37303729
"LinalgVectorNormComplexModule_basic",
3731-
"LinspaceDtypeModule_basic",
37323730
"MaskedScatterStaticBasic_basic",
37333731
"MaxPool1dCeilModeTrueModule_basic",
37343732
"MaxPool1dModule_basic",
@@ -3791,7 +3789,6 @@
37913789
"PrimMaxIntModule_basic",
37923790
"PrimMinIntDynamicModule_basic",
37933791
"PrimMinIntModule_basic",
3794-
"PrimsConvertElementTypeModule_basic",
37953792
"PrimsSqueezeEmptyDimensionsModule_basic",
37963793
"PrimsSqueezeModule_basic",
37973794
"PrimsViewOfModule_basic",
@@ -3880,8 +3877,6 @@
38803877
"TensorToInt_basic",
38813878
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
38823879
"ThresholdBackward2dMixedModule_basic",
3883-
"ToCopyWithDTypeFalsePinMemoryModule_basic",
3884-
"ToCopyWithDTypeModule_basic",
38853880
"TorchPrimLoopForLikeModule_basic",
38863881
"TorchPrimLoopWhileLikeModule_basic",
38873882
"TraceModule_empty",
@@ -3997,7 +3992,6 @@
39973992
}
39983993
# Failing on stable but not on nightly
39993994
FX_IMPORTER_TOSA_XFAIL_SET |= {
4000-
"InterpolateDynamicModule_sizes_nearest",
40013995
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
40023996
"ElementwiseRreluWithNoiseTrainModule_basic",
40033997
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
@@ -4240,7 +4234,6 @@
42404234
"AtenTriuModule_basic",
42414235
"AtenTriuWithNegDiagonalModule_basic",
42424236
"AtenTriuWithPosDiagonalModule_basic",
4243-
"Aten_CastLongModule_basic",
42444237
"Aten_EmbeddingBagExample_basic",
42454238
"AvgPool1dFloatModule_basic",
42464239
"AvgPool1dIntModule_basic",
@@ -4955,6 +4948,8 @@
49554948
"ToDtypeLayoutCPUModule_basic",
49564949
"ToDtypeLayoutNoneModule_basic",
49574950
"ToDtypeLayoutStridedModule_basic",
4951+
"ToDtypeIntFromFloatModule_basic",
4952+
"ToDtypeFloatFromIntModule_basic",
49584953
"TorchPrimLoopForLikeModule_basic",
49594954
"TorchPrimLoopWhileLikeModule_basic",
49604955
"TraceModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,45 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils):
255255
module.forward(tu.randint(3, 5))
256256

257257

258+
class ToDtypeFloatFromIntModule(torch.nn.Module):
    """E2e test module exercising aten.to.dtype for an int64 -> float32 cast."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.int64, True)])
    def forward(self, x):
        # Cast the dynamically-shaped int64 input to float32 via aten.to.
        return torch.ops.aten.to(x, dtype=torch.float32)
269+
270+
271+
@register_test_case(module_factory=lambda: ToDtypeFloatFromIntModule())
def ToDtypeFloatFromIntModule_basic(module, tu: TestUtils):
    # Small int64 sample spanning negative and positive values.
    sample = torch.randint(low=-5, high=5, size=(2, 2)).to(torch.int64)
    module.forward(sample)
275+
276+
277+
class ToDtypeIntFromFloatModule(torch.nn.Module):
    """E2e test module exercising aten.to.dtype for a float64 -> int64 cast.

    torch performs round-toward-zero for float->int conversion, so this
    module checks that backends reproduce truncation semantics.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.float64, True)])
    def forward(self, x):
        # Cast the dynamically-shaped float64 input to int64 via aten.to.
        return torch.ops.aten.to(x, dtype=torch.int64)
288+
289+
290+
@register_test_case(module_factory=lambda: ToDtypeIntFromFloatModule())
def ToDtypeIntFromFloatModule_basic(module, tu: TestUtils):
    sample = tu.rand(2, 2, low=-5, high=5)
    # Force one entry to carry a .7 fractional part so that
    # round-toward-zero truncation is actually exercised.
    sample[1][1] = tu.randint(1, 1) + 0.7
    module.forward(sample)
295+
296+
258297
class TypeAsSameModule(torch.nn.Module):
259298
def __init__(self):
260299
super().__init__()

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,29 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten
12351235
return %0 : !torch.vtensor<[1,128],si64>
12361236
}
12371237

1238+
// -----
// Verifies that a float->int aten.to.dtype lowers to a round-toward-zero
// sequence (floor/ceil + select on sign) before the final tosa.cast.
// CHECK-LABEL:   func.func @torch.aten.to.dtype$floatToInt(
// CHECK-SAME:        %[[ARG0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
// CHECK:           %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
// CHECK:           %[[INT4:.*]] = torch.constant.int 4
// CHECK:           %[[FALSE:.*]] = torch.constant.bool false
// CHECK:           %[[NONE:.*]] = torch.constant.none
// CHECK:           %[[FLOOR:.*]] = tosa.floor %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK:           %[[CEIL:.*]] = tosa.ceil %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK:           %[[F0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK:           %[[IS_NEG:.*]] = tosa.greater %[[F0]], %[[TENSOR]] : (tensor<f32>, tensor<3x5xf32>) -> tensor<3x5xi1>
// CHECK:           %[[SELECT:.*]] = tosa.select %[[IS_NEG]], %[[CEIL]], %[[FLOOR]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK:           %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<3x5xf32>) -> tensor<3x5xi64>
// CHECK:           %[[RES:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64>
// CHECK:           return %[[RES]] : !torch.vtensor<[3,5],si64>
func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
  %int4 = torch.constant.int 4
  %false = torch.constant.bool false
  %none = torch.constant.none
  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],si64>
  return %0 : !torch.vtensor<[3,5],si64>
}
1260+
12381261
// -----
12391262
// CHECK-LABEL: func.func @torch.aten.gather(
12401263
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,

0 commit comments

Comments
 (0)