Skip to content

Commit 8b465e3

Browse files
authored
Merge pull request #544 from Xilinx/bump_to_9a167e2d
[AutoBump] Merge with fixes of 9a167e2 (Jan 10) (151)
2 parents d426c0a + 8b59c18 commit 8b465e3

File tree

3 files changed

+31
-15
lines changed

3 files changed

+31
-15
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5710,11 +5710,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
57105710
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
57115711
}
57125712

5713-
rewriter.replaceOpWithNewOp<tensor::CastOp>(
5714-
op, resultTy,
5715-
// OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
5716-
// op.getType()),
5717-
result);
5713+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);
57185714

57195715
return success();
57205716
}
@@ -6648,11 +6644,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
66486644
tosa::getConstTensor<int32_t>(rewriter, op,
66496645
/*vec=*/{0, 3, 1, 2},
66506646
/*shape=*/{static_cast<int32_t>(4)});
6651-
// SmallVector<int64_t> transposedOutputShape(
6652-
// {transposedResizedOpShape[0], transposedResizedOpShape[3],
6653-
// transposedResizedOpShape[1], transposedResizedOpShape[2]});
6654-
// auto transposedOutputType = RankedTensorType::get(
6655-
// makeShapeLLVMCompatible(transposedOutputShape), inputElemTy);
6647+
66566648
rewriter
66576649
.replaceOpWithNewOp<tosa::TransposeOp>(
66586650
op, getTypeConverter()->convertType(resultType), resizeOpResult,

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
311311
auto isValid = [](Type ty) {
312312
return ty.isInteger(1) || ty.isInteger(8) || ty.isInteger(16) ||
313313
ty.isInteger(32) || ty.isInteger(64) || ty.isBF16() || ty.isF16() ||
314-
ty.isF32() || ty.isF64();
314+
ty.isF32() || ty.isF64() || ty.isFloat8E4M3() || ty.isFloat8E5M2();
315315
};
316316

317317
return success(isValid(src) && isValid(dest));
@@ -324,9 +324,17 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
324324
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
325325
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
326326

327-
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
328-
return rewriter.notifyMatchFailure(
329-
op, "casting to result dtype is invalid or unsupported");
327+
// Temporarily disable checkValidityOfCast as it's currently strictly
328+
// following TOSA spec and might cause many e2e tests to fail. This is because
329+
// even though there are some casting pairs that are not congruent to TOSA
330+
// spec, they are still permissible. TOSA validation should flag these illegal
331+
// constructs in a per-profile manner. This strict validity check will be
332+
// enabled later in a potential `--strict` mode which checks for strict
333+
// casting only when needed (the default value of `--strict` mode will be
334+
// off).
335+
// if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
336+
// return rewriter.notifyMatchFailure(
337+
// op, "casting to result dtype is invalid or unsupported");
330338

331339
if (destElemTy.isInteger(1)) {
332340
auto srcType = dyn_cast<TensorType>(src.getType());

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1821,6 +1821,21 @@
18211821
# Write the TOSA set as a "passing" set as it is very early in development
18221822
# and very few tests work yet.
18231823
TOSA_PASS_SET = {
1824+
"AtenEyeMModuleInt2D_basic",
1825+
"AtenEyeModuleInt2D_basic",
1826+
"ElementwiseWhereScalarOtherStaticModule_basic",
1827+
"FullModuleFalsePinMemory_basic",
1828+
"FullModuleInt2D_basic",
1829+
"MaskedFillScalarFloatValueModule_basic",
1830+
"MaskedFillScalarFloatValueStaticModule_basic",
1831+
"NewFullModuleInt2D_basic",
1832+
"NewFullModuleInt3D_basic",
1833+
"Threshold3dIntModule_basic",
1834+
"TrilIndicesModule_basic",
1835+
"TrilIndicesOfssetGreaterThanRowModule_basic",
1836+
"TriuIndicesNegativeOffsetModule_basic",
1837+
"BmmFloat16Module_basic",
1838+
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
18241839
"Unfold_Module_Rank_4",
18251840
"Unfold_Module_Rank_Zero_basic",
18261841
"Unfold_Module_basic",
@@ -2662,6 +2677,8 @@
26622677
}
26632678
) - {
26642679
### Test failing in make_fx_tosa but not in tosa
2680+
"ElementwiseRreluEvalStaticModule_basic",
2681+
"ElementwiseRreluTrainStaticModule_basic",
26652682
"AdaptiveMaxPool1dDimOneStatic_basic",
26662683
"FloatPowerTensorTensorStaticModule_basic",
26672684
# Dynamic shape, has extra unsupported broadcast ops
@@ -4871,7 +4888,6 @@
48714888
"QuantizedReluUint8_basic",
48724889
"QuantizedSingleLayer_basic",
48734890
"RandIntDtypeModule_basic",
4874-
"RandIntLowDtypeModule_basic",
48754891
"RandIntModule_basic",
48764892
"RandIntPinMemoryModule_basic",
48774893
"RandLikeDtypeModule_basic",

0 commit comments

Comments
 (0)