Skip to content

Commit

Permalink
Merge pull request #546 from Xilinx/bump_to_040aec90
Browse files Browse the repository at this point in the history
[AutoBump] Merge with fixes of 040aec9 (Jan 14) (153)
  • Loading branch information
mgehre-amd authored Feb 14, 2025
2 parents 7dbb08a + c654228 commit 9f5cca3
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

// Check for global seed and create if it doesn't exist.
auto module = op->getParentOfType<ModuleOp>();
OpBuilder b(module.getBodyRegion());
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
return failure();

// Generate sequence for getting the next seed with LCG step:
// nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
Expand Down Expand Up @@ -115,11 +122,6 @@ class ConvertTorchConversionToMLProgram
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

auto module = getOperation();
OpBuilder b(module.getBodyRegion());
if (failed(getOrCreateGlobalVariableForSeed(b, module)))
signalPassFailure();

RewritePatternSet patterns(context);
target.addIllegalOp<GetNextSeedOp>();
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
Expand Down
13 changes: 0 additions & 13 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,6 @@ std::optional<Value> getConstTensor<double>(PatternRewriter &rewriter,
return const_op.getResult();
}

static LogicalResult checkValidityOfCast(Type src, Type dest) {
if (src == dest)
return success();

auto isValid = [](Type ty) {
return ty.isInteger(1) || ty.isInteger(8) || ty.isInteger(16) ||
ty.isInteger(32) || ty.isInteger(64) || ty.isBF16() || ty.isF16() ||
ty.isF32() || ty.isF64() || ty.isFloat8E4M3() || ty.isFloat8E5M2();
};

return success(isValid(src) && isValid(dest));
}

// Template specialization for float
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Value src, Type destType, Value &result) {
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/TorchConversionToMLProgram/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ module {
return %seed : i64
}
}

// -----

module {
func.func @no_seed_needed(%arg0: tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
return %0 : !torch.vtensor<[2,3],f32>
}
}

// CHECK-NOT: ml_program.global
// CHECK-LABEL: @no_seed_needed
// CHECK-NEXT: torch_c.from_builtin_tensor
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ module {
func.func private @f7() -> i64
}

// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-NOT: @global_seed

This file was deleted.

0 comments on commit 9f5cca3

Please sign in to comment.