File tree 2 files changed +64
-0
lines changed
extension/apple/ExecutorchRuntime/ExecutorchRuntime
2 files changed +64
-0
lines changed Original file line number Diff line number Diff line change
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ @_implementationOnly import ExecutorchRuntimeBridge
10
+ @_implementationOnly import ExecutorchRuntimeValueSupport
11
+ import Foundation
12
+ import ModelRunnerDataKit
13
+
14
+ public class ExecutorchRuntime : ModelRuntime {
15
+ private let engine : ExecutorchRuntimeEngine
16
+ public init ( modelPath: String , modelMethodName: String ) throws {
17
+ self . engine = try ExecutorchRuntimeEngine ( modelPath: modelPath, modelMethodName: modelMethodName)
18
+ }
19
+ public func infer( input: [ ModelRuntimeValue ] ) throws -> [ ModelRuntimeValue ] {
20
+ let modelInput = input. compactMap { $0. value as? ExecutorchRuntimeValue }
21
+ // Not all values were of type ExecutorchRuntimeValue
22
+ guard input. count == modelInput. count else {
23
+ throw ModelRuntimeError . unsupportedInputType
24
+ }
25
+ return try engine. infer ( input: modelInput) . compactMap { ModelRuntimeValue ( innerValue: $0) }
26
+ }
27
+
28
+ public func getModelValueFactory( ) -> ModelRuntimeValueFactory {
29
+ return ExecutorchRuntimeValueSupport ( )
30
+ }
31
+ public func getModelTensorFactory( ) -> ModelRuntimeTensorValueFactory {
32
+ return ExecutorchRuntimeValueSupport ( )
33
+ }
34
+ }
Original file line number Diff line number Diff line change
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ @testable import ExecutorchRuntime
10
+ import ExecutorchRuntimeValueSupport
11
+ import XCTest
12
+
13
+ class ExecutorchRuntimeTests : XCTestCase {
14
+ func testRuntimeWithAddPTE( ) throws {
15
+ let bundle = Bundle ( for: type ( of: self ) )
16
+ let modelPath = try XCTUnwrap ( bundle. path ( forResource: " add " , ofType: " pte " ) )
17
+ let runtime = try XCTUnwrap ( ExecutorchRuntime ( modelPath: modelPath, modelMethodName: " forward " ) )
18
+
19
+ let tensorInput = try XCTUnwrap ( runtime. getModelTensorFactory ( ) . createFloatTensor ( value: [ 2.0 ] , shape: [ 1 ] ) )
20
+ let input = try runtime. getModelValueFactory ( ) . createTensor ( value: tensorInput)
21
+
22
+ let output = try XCTUnwrap ( runtime. infer ( input: [ input, input] ) )
23
+
24
+ let tensorOutput = try output. first? . tensorValue ( ) . floatRepresentation ( )
25
+ XCTAssertEqual ( tensorOutput? . floatArray. count, 1 )
26
+ XCTAssertEqual ( tensorOutput? . shape. count, 1 )
27
+ XCTAssertEqual ( tensorOutput? . shape. first, 1 )
28
+ XCTAssertEqual ( tensorOutput? . floatArray. first, 4.0 )
29
+ }
30
+ }
You can’t perform that action at this time.
0 commit comments