Skip to content

Commit 05a8f47

Browse files
bsoyluogluDannyYuyang-quic
authored andcommitted
Move ExecutorchRuntime to xplat
Differential Revision: D70826477 Pull Request resolved: pytorch#9281
1 parent 34971fb commit 05a8f47

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
@_implementationOnly import ExecutorchRuntimeBridge
10+
@_implementationOnly import ExecutorchRuntimeValueSupport
11+
import Foundation
12+
import ModelRunnerDataKit
13+
14+
public class ExecutorchRuntime: ModelRuntime {
15+
private let engine: ExecutorchRuntimeEngine
16+
public init(modelPath: String, modelMethodName: String) throws {
17+
self.engine = try ExecutorchRuntimeEngine(modelPath: modelPath, modelMethodName: modelMethodName)
18+
}
19+
public func infer(input: [ModelRuntimeValue]) throws -> [ModelRuntimeValue] {
20+
let modelInput = input.compactMap { $0.value as? ExecutorchRuntimeValue }
21+
// Not all values were of type ExecutorchRuntimeValue
22+
guard input.count == modelInput.count else {
23+
throw ModelRuntimeError.unsupportedInputType
24+
}
25+
return try engine.infer(input: modelInput).compactMap { ModelRuntimeValue(innerValue: $0) }
26+
}
27+
28+
public func getModelValueFactory() -> ModelRuntimeValueFactory {
29+
return ExecutorchRuntimeValueSupport()
30+
}
31+
public func getModelTensorFactory() -> ModelRuntimeTensorValueFactory {
32+
return ExecutorchRuntimeValueSupport()
33+
}
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
@testable import ExecutorchRuntime
10+
import ExecutorchRuntimeValueSupport
11+
import XCTest
12+
13+
class ExecutorchRuntimeTests: XCTestCase {
14+
func testRuntimeWithAddPTE() throws {
15+
let bundle = Bundle(for: type(of: self))
16+
let modelPath = try XCTUnwrap(bundle.path(forResource: "add", ofType: "pte"))
17+
let runtime = try XCTUnwrap(ExecutorchRuntime(modelPath: modelPath, modelMethodName: "forward"))
18+
19+
let tensorInput = try XCTUnwrap(runtime.getModelTensorFactory().createFloatTensor(value: [2.0], shape: [1]))
20+
let input = try runtime.getModelValueFactory().createTensor(value: tensorInput)
21+
22+
let output = try XCTUnwrap(runtime.infer(input: [input, input]))
23+
24+
let tensorOutput = try output.first?.tensorValue().floatRepresentation()
25+
XCTAssertEqual(tensorOutput?.floatArray.count, 1)
26+
XCTAssertEqual(tensorOutput?.shape.count, 1)
27+
XCTAssertEqual(tensorOutput?.shape.first, 1)
28+
XCTAssertEqual(tensorOutput?.floatArray.first, 4.0)
29+
}
30+
}

0 commit comments

Comments
 (0)