Skip to content

Commit 9f5cca3

Browse files
authored
Merge pull request #546 from Xilinx/bump_to_040aec90
[AutoBump] Merge with fixes of 040aec9 (Jan 14) (153)
2 parents 7dbb08a + c654228 commit 9f5cca3

File tree

5 files changed

+21
-36
lines changed

5 files changed

+21
-36
lines changed

lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
5959
matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
6060
ConversionPatternRewriter &rewriter) const override {
6161
Location loc = op.getLoc();
62+
63+
// Check for global seed and create if it doesn't exist.
64+
auto module = op->getParentOfType<ModuleOp>();
65+
OpBuilder b(module.getBodyRegion());
66+
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
67+
return failure();
68+
6269
// Generate sequence for getting the next seed with LCG step:
6370
// nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
6471
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
@@ -115,11 +122,6 @@ class ConvertTorchConversionToMLProgram
115122
typeConverter.addConversion([](Type type) { return type; });
116123
TorchConversion::setupBackendTypeConversion(target, typeConverter);
117124

118-
auto module = getOperation();
119-
OpBuilder b(module.getBodyRegion());
120-
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
121-
signalPassFailure();
122-
123125
RewritePatternSet patterns(context);
124126
target.addIllegalOp<GetNextSeedOp>();
125127
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -304,19 +304,6 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
304304
return const_op.getResult();
305305
}
306306

307-
static LogicalResult checkValidityOfCast(Type src, Type dest) {
308-
if (src == dest)
309-
return success();
310-
311-
auto isValid = [](Type ty) {
312-
return ty.isInteger(1) || ty.isInteger(8) || ty.isInteger(16) ||
313-
ty.isInteger(32) || ty.isInteger(64) || ty.isBF16() || ty.isF16() ||
314-
ty.isF32() || ty.isF64() || ty.isFloat8E4M3() || ty.isFloat8E5M2();
315-
};
316-
317-
return success(isValid(src) && isValid(dest));
318-
}
319-
320307
// Template specialization for float
321308
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
322309
Value src, Type destType, Value &result) {

test/Conversion/TorchConversionToMLProgram/basic.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,16 @@ module {
1717
return %seed : i64
1818
}
1919
}
20+
21+
// -----
22+
23+
module {
24+
func.func @no_seed_needed(%arg0: tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> {
25+
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
26+
return %0 : !torch.vtensor<[2,3],f32>
27+
}
28+
}
29+
30+
// CHECK-NOT: ml_program.global
31+
// CHECK-LABEL: @no_seed_needed
32+
// CHECK-NEXT: torch_c.from_builtin_tensor

test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ module {
1111
func.func private @f7() -> i64
1212
}
1313

14-
// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
14+
// CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
1515
// CHECK-NOT: @global_seed

test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-mlprogram.mlir

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)