Skip to content

Commit 7191604

Browse files
authored
Make Android Module thread-safe and prevent destruction during inference
Differential Revision: D72273052 Pull Request resolved: #9833
1 parent 1572381 commit 7191604

File tree

2 files changed

+112
-4
lines changed

2 files changed

+112
-4
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java

+62
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import java.io.InputStream;
2626
import java.net.URI;
2727
import java.net.URISyntaxException;
28+
import java.util.concurrent.CountDownLatch;
29+
import java.util.concurrent.atomic.AtomicInteger;
2830
import java.io.IOException;
2931
import java.io.File;
3032
import java.io.FileOutputStream;
@@ -42,6 +44,7 @@ public class ModuleInstrumentationTest {
4244
private static String FORWARD_METHOD = "forward";
4345
private static String NONE_METHOD = "none";
4446
private static int OK = 0x00;
47+
private static int INVALID_STATE = 0x2;
4548
private static int INVALID_ARGUMENT = 0x12;
4649
private static int ACCESS_FAILED = 0x22;
4750

@@ -124,4 +127,63 @@ public void testNonPteFile() throws IOException{
124127
int loadMethod = module.loadMethod(FORWARD_METHOD);
125128
assertEquals(loadMethod, INVALID_ARGUMENT);
126129
}
130+
131+
@Test
132+
public void testLoadOnDestroyedModule() throws IOException{
133+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
134+
135+
module.destroy();
136+
137+
int loadMethod = module.loadMethod(FORWARD_METHOD);
138+
assertEquals(loadMethod, INVALID_STATE);
139+
}
140+
141+
@Test
142+
public void testForwardOnDestroyedModule() throws IOException{
143+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
144+
145+
int loadMethod = module.loadMethod(FORWARD_METHOD);
146+
assertEquals(loadMethod, OK);
147+
148+
module.destroy();
149+
150+
EValue[] results = module.forward();
151+
assertEquals(0, results.length);
152+
}
153+
154+
@Test
155+
public void testForwardFromMultipleThreads() throws InterruptedException, IOException {
156+
Module module = Module.load(getTestFilePath(TEST_FILE_NAME));
157+
158+
int numThreads = 100;
159+
CountDownLatch latch = new CountDownLatch(numThreads);
160+
AtomicInteger completed = new AtomicInteger(0);
161+
162+
Runnable runnable = new Runnable() {
163+
@Override
164+
public void run() {
165+
try {
166+
latch.countDown();
167+
latch.await(5000, java.util.concurrent.TimeUnit.MILLISECONDS);
168+
EValue[] results = module.forward();
169+
assertTrue(results[0].isTensor());
170+
completed.incrementAndGet();
171+
} catch (InterruptedException e) {
172+
173+
}
174+
}
175+
};
176+
177+
Thread[] threads = new Thread[numThreads];
178+
for (int i = 0; i < numThreads; i++) {
179+
threads[i] = new Thread(runnable);
180+
threads[i].start();
181+
}
182+
183+
for (int i = 0; i < numThreads; i++) {
184+
threads[i].join();
185+
}
186+
187+
assertEquals(numThreads, completed.get());
188+
}
127189
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

+50-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88

99
package org.pytorch.executorch;
1010

11+
import android.util.Log;
1112
import com.facebook.soloader.nativeloader.NativeLoader;
1213
import com.facebook.soloader.nativeloader.SystemDelegate;
14+
import java.util.concurrent.locks.Lock;
15+
import java.util.concurrent.locks.ReentrantLock;
1316
import org.pytorch.executorch.annotations.Experimental;
1417

1518
/**
@@ -35,6 +38,9 @@ public class Module {
3538
/** Reference to the NativePeer object of this module. */
3639
private NativePeer mNativePeer;
3740

41+
/** Lock protecting the non-thread safe methods in NativePeer. */
42+
private Lock mLock = new ReentrantLock();
43+
3844
/**
3945
* Loads a serialized ExecuTorch module from the specified path on the disk.
4046
*
@@ -72,7 +78,16 @@ public static Module load(final String modelPath) {
7278
* @return return value from the 'forward' method.
7379
*/
7480
public EValue[] forward(EValue... inputs) {
75-
return mNativePeer.forward(inputs);
81+
try {
82+
mLock.lock();
83+
if (mNativePeer == null) {
84+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
85+
return new EValue[0];
86+
}
87+
return mNativePeer.forward(inputs);
88+
} finally {
89+
mLock.unlock();
90+
}
7691
}
7792

7893
/**
@@ -83,7 +98,16 @@ public EValue[] forward(EValue... inputs) {
8398
* @return return value from the method.
8499
*/
85100
public EValue[] execute(String methodName, EValue... inputs) {
86-
return mNativePeer.execute(methodName, inputs);
101+
try {
102+
mLock.lock();
103+
if (mNativePeer == null) {
104+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
105+
return new EValue[0];
106+
}
107+
return mNativePeer.execute(methodName, inputs);
108+
} finally {
109+
mLock.unlock();
110+
}
87111
}
88112

89113
/**
@@ -96,7 +120,16 @@ public EValue[] execute(String methodName, EValue... inputs) {
96120
* @return the Error code if there was an error loading the method
97121
*/
98122
public int loadMethod(String methodName) {
99-
return mNativePeer.loadMethod(methodName);
123+
try {
124+
mLock.lock();
125+
if (mNativePeer == null) {
126+
Log.e("ExecuTorch", "Attempt to use a destroyed module");
127+
return 0x2; // InvalidState
128+
}
129+
return mNativePeer.loadMethod(methodName);
130+
} finally {
131+
mLock.unlock();
132+
}
100133
}
101134

102135
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
@@ -111,6 +144,19 @@ public String[] readLogBuffer() {
111144
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
112145
*/
113146
public void destroy() {
114-
mNativePeer.resetNative();
147+
if (mLock.tryLock()) {
148+
try {
149+
mNativePeer.resetNative();
150+
} finally {
151+
mNativePeer = null;
152+
mLock.unlock();
153+
}
154+
} else {
155+
mNativePeer = null;
156+
Log.w(
157+
"ExecuTorch",
158+
"Destroy was called while the module was in use. Resources will not be immediately"
159+
+ " released.");
160+
}
115161
}
116162
}

0 commit comments

Comments
 (0)