Skip to content

Commit 6825604

Browse files
committed
correct shm tensor conversion and handle crashes
1 parent 05d1675 commit 6825604

File tree

3 files changed

+52
-21
lines changed

3 files changed

+52
-21
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/Tensorflow2Interface.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,11 @@ private void launchModelLoadOnProcess() throws IOException, InterruptedException
198198
throw new RuntimeException();
199199
else if (task.status == TaskStatus.FAILED)
200200
throw new RuntimeException();
201-
else if (task.status == TaskStatus.CRASHED)
201+
else if (task.status == TaskStatus.CRASHED) {
202+
this.runner.close();
203+
runner = null;
202204
throw new RuntimeException();
205+
}
203206
}
204207

205208
/**
@@ -360,8 +363,11 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
360363
throw new RuntimeException();
361364
else if (task.status == TaskStatus.FAILED)
362365
throw new RuntimeException();
363-
else if (task.status == TaskStatus.CRASHED)
366+
else if (task.status == TaskStatus.CRASHED) {
367+
this.runner.close();
368+
runner = null;
364369
throw new RuntimeException();
370+
}
365371
for (int i = 0; i < outputTensors.size(); i ++) {
366372
String name = (String) Types.decode(encOuts.get(i)).get(MEM_NAME_KEY);
367373
SharedMemoryArray shm = shmaOutputList.stream()
@@ -491,8 +497,11 @@ public void closeModel() {
491497
throw new RuntimeException();
492498
else if (task.status == TaskStatus.FAILED)
493499
throw new RuntimeException();
494-
else if (task.status == TaskStatus.CRASHED)
500+
else if (task.status == TaskStatus.CRASHED) {
501+
this.runner.close();
502+
runner = null;
495503
throw new RuntimeException();
504+
}
496505
this.runner.close();
497506
this.runner = null;
498507
return;

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/shm/ShmBuilder.java

+20-5
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
101101
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
102102
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
103103
ByteBuffer buff = shma.getDataBufferNoHeader();
104-
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
104+
byte[] flat = new byte[buff.capacity()];
105+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
106+
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
107+
buff = buff2;
105108
if (PlatformDetection.isWindows()) shma.close();
106109
}
107110

@@ -114,7 +117,10 @@ private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws
114117

115118
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
116119
ByteBuffer buff = shma.getDataBufferNoHeader();
117-
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
120+
byte[] flat = new byte[buff.capacity()];
121+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
122+
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
123+
buff = buff2;
118124
if (PlatformDetection.isWindows()) shma.close();
119125
}
120126

@@ -127,7 +133,10 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr
127133

128134
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
129135
ByteBuffer buff = shma.getDataBufferNoHeader();
130-
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
136+
byte[] flat = new byte[buff.capacity()];
137+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
138+
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
139+
buff = buff2;
131140
if (PlatformDetection.isWindows()) shma.close();
132141
}
133142

@@ -140,7 +149,10 @@ private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) th
140149

141150
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
142151
ByteBuffer buff = shma.getDataBufferNoHeader();
143-
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
152+
byte[] flat = new byte[buff.capacity()];
153+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
154+
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
155+
buff = buff2;
144156
if (PlatformDetection.isWindows()) shma.close();
145157
}
146158

@@ -154,7 +166,10 @@ private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws
154166

155167
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
156168
ByteBuffer buff = shma.getDataBufferNoHeader();
157-
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
169+
byte[] flat = new byte[buff.capacity()];
170+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
171+
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
172+
buff = buff2;
158173
if (PlatformDetection.isWindows()) shma.close();
159174
}
160175
}

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/shm/TensorBuilder.java

+20-13
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@
2626
import net.imglib2.util.Cast;
2727

2828
import java.nio.ByteBuffer;
29-
import java.nio.DoubleBuffer;
30-
import java.nio.FloatBuffer;
31-
import java.nio.IntBuffer;
32-
import java.nio.LongBuffer;
3329
import java.util.Arrays;
3430

3531
import org.tensorflow.Tensor;
@@ -102,7 +98,10 @@ private static TUint8 buildUByte(SharedMemoryArray tensor)
10298
if (!tensor.isNumpyFormat())
10399
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
104100
ByteBuffer buff = tensor.getDataBufferNoHeader();
105-
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(buff.array(), false);
101+
byte[] flat = new byte[buff.capacity()];
102+
buff.get(flat);
103+
buff.rewind();
104+
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
106105
TUint8 ndarray = Tensor.of(TUint8.class, Shape.of(ogShape), dataBuffer);
107106
return ndarray;
108107
}
@@ -117,8 +116,10 @@ private static TInt32 buildInt(SharedMemoryArray tensor)
117116
if (!tensor.isNumpyFormat())
118117
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
119118
ByteBuffer buff = tensor.getDataBufferNoHeader();
120-
IntBuffer intBuff = buff.asIntBuffer();
121-
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intBuff.array(), false);
119+
int[] flat = new int[buff.capacity() / 4];
120+
buff.asIntBuffer().get(flat);
121+
buff.rewind();
122+
IntDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
122123
TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape),
123124
dataBuffer);
124125
return ndarray;
@@ -134,8 +135,10 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
134135
if (!tensor.isNumpyFormat())
135136
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
136137
ByteBuffer buff = tensor.getDataBufferNoHeader();
137-
LongBuffer longBuff = buff.asLongBuffer();
138-
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longBuff.array(), false);
138+
long[] flat = new long[buff.capacity() / 8];
139+
buff.asLongBuffer().get(flat);
140+
buff.rewind();
141+
LongDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
139142
TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape),
140143
dataBuffer);
141144
return ndarray;
@@ -151,8 +154,10 @@ private static TFloat32 buildFloat(SharedMemoryArray tensor)
151154
if (!tensor.isNumpyFormat())
152155
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
153156
ByteBuffer buff = tensor.getDataBufferNoHeader();
154-
FloatBuffer floatBuff = buff.asFloatBuffer();
155-
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
157+
float[] flat = new float[buff.capacity() / 4];
158+
buff.asFloatBuffer().get(flat);
159+
buff.rewind();
160+
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
156161
TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
157162
return ndarray;
158163
}
@@ -167,8 +172,10 @@ private static TFloat64 buildDouble(SharedMemoryArray tensor)
167172
if (!tensor.isNumpyFormat())
168173
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
169174
ByteBuffer buff = tensor.getDataBufferNoHeader();
170-
DoubleBuffer floatBuff = buff.asDoubleBuffer();
171-
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
175+
double[] flat = new double[buff.capacity() / 8];
176+
buff.asDoubleBuffer().get(flat);
177+
buff.rewind();
178+
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
172179
TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
173180
return ndarray;
174181
}

0 commit comments

Comments
 (0)