Skip to content

Commit 4eb89ff

Browse files
committed
finish adapting to persistent memory
1 parent ee36ef0 commit 4eb89ff

File tree

3 files changed

+42
-99
lines changed

3 files changed

+42
-99
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import io.bioimage.modelrunner.system.PlatformDetection;
3838
import io.bioimage.modelrunner.tensor.Tensor;
3939
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
40+
import io.bioimage.modelrunner.tensorflow.v1.shm.ShmBuilder;
4041
import io.bioimage.modelrunner.tensorflow.v1.tensor.ImgLib2Builder;
4142
import io.bioimage.modelrunner.tensorflow.v1.tensor.TensorBuilder;
4243
import io.bioimage.modelrunner.utils.CommonUtils;
@@ -45,28 +46,17 @@
4546
import net.imglib2.RandomAccessibleInterval;
4647
import net.imglib2.type.NativeType;
4748
import net.imglib2.type.numeric.RealType;
48-
import net.imglib2.type.numeric.real.FloatType;
4949
import net.imglib2.util.Cast;
5050
import net.imglib2.util.Util;
5151

52-
import java.io.BufferedReader;
5352
import java.io.File;
5453
import java.io.IOException;
55-
import java.io.InputStreamReader;
56-
import java.io.RandomAccessFile;
5754
import java.io.UnsupportedEncodingException;
5855
import java.net.URISyntaxException;
5956
import java.net.URL;
6057
import java.net.URLDecoder;
61-
import java.nio.ByteBuffer;
62-
import java.nio.MappedByteBuffer;
63-
import java.nio.channels.FileChannel;
6458
import java.nio.charset.StandardCharsets;
65-
import java.nio.file.Files;
66-
import java.nio.file.Paths;
6759
import java.security.ProtectionDomain;
68-
import java.time.LocalDateTime;
69-
import java.time.format.DateTimeFormatter;
7060
import java.util.ArrayList;
7161
import java.util.HashMap;
7262
import java.util.LinkedHashMap;
@@ -332,7 +322,7 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
332322
for (String ee : inputs) {
333323
Map<String, Object> decoded = Types.decode(ee);
334324
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
335-
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma);
325+
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v1.shm.TensorBuilder.build(shma);
336326
if (PlatformDetection.isWindows()) shma.close();
337327
inTensors.add(inT);
338328
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);

src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/ShmBuilder.java

Lines changed: 28 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,7 @@
2929
import java.util.Arrays;
3030

3131
import org.tensorflow.Tensor;
32-
import org.tensorflow.types.TFloat32;
33-
import org.tensorflow.types.TFloat64;
34-
import org.tensorflow.types.TInt32;
35-
import org.tensorflow.types.TInt64;
36-
import org.tensorflow.types.TUint8;
37-
import org.tensorflow.types.family.TType;
32+
import org.tensorflow.types.UInt8;
3833

3934
import net.imglib2.RandomAccessibleInterval;
4035
import net.imglib2.type.numeric.integer.IntType;
@@ -70,20 +65,20 @@ private ShmBuilder()
7065
* @throws IOException
7166
*/
7267
@SuppressWarnings("unchecked")
73-
public static void build(Tensor<? extends TType> tensor, String memoryName) throws IllegalArgumentException, IOException
68+
public static void build(Tensor<?> tensor, String memoryName) throws IllegalArgumentException, IOException
7469
{
75-
switch (tensor.dataType().name())
70+
switch (tensor.dataType())
7671
{
77-
case TUint8.NAME:
78-
buildFromTensorUByte((Tensor<TUint8>) tensor, memoryName);
79-
case TInt32.NAME:
80-
buildFromTensorInt((Tensor<TInt32>) tensor, memoryName);
81-
case TFloat32.NAME:
82-
buildFromTensorFloat((Tensor<TFloat32>) tensor, memoryName);
83-
case TFloat64.NAME:
84-
buildFromTensorDouble((Tensor<TFloat64>) tensor, memoryName);
85-
case TInt64.NAME:
86-
buildFromTensorLong((Tensor<TInt64>) tensor, memoryName);
72+
case UINT8:
73+
buildFromTensorUByte((Tensor<UInt8>) tensor, memoryName);
74+
case INT32:
75+
buildFromTensorInt((Tensor<Integer>) tensor, memoryName);
76+
case FLOAT:
77+
buildFromTensorFloat((Tensor<Float>) tensor, memoryName);
78+
case DOUBLE:
79+
buildFromTensorDouble((Tensor<Double>) tensor, memoryName);
80+
case INT64:
81+
buildFromTensorLong((Tensor<Long>) tensor, memoryName);
8782
default:
8883
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name());
8984
}
@@ -97,20 +92,15 @@ public static void build(Tensor<? extends TType> tensor, String memoryName) thr
9792
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}.
9893
* @throws IOException
9994
*/
100-
private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryName) throws IOException
95+
private static void buildFromTensorUByte(Tensor<UInt8> tensor, String memoryName) throws IOException
10196
{
102-
long[] arrayShape = tensor.shape().asArray();
97+
long[] arrayShape = tensor.shape();
10398
if (CommonUtils.int32Overflows(arrayShape, 1))
10499
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
105100
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
106101
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
107102
ByteBuffer buff = shma.getDataBuffer();
108-
int totalSize = 1;
109-
for (long i : arrayShape) {totalSize *= i;}
110-
byte[] flatArr = new byte[buff.capacity()];
111-
buff.get(flatArr);
112-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
113-
shma.setBuffer(ByteBuffer.wrap(flatArr));
103+
tensor.writeTo(buff);
114104
if (PlatformDetection.isWindows()) shma.close();
115105
}
116106

@@ -122,21 +112,16 @@ private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryNam
122112
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}.
123113
* @throws IOException
124114
*/
125-
private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName) throws IOException
115+
private static void buildFromTensorInt(Tensor<Integer> tensor, String memoryName) throws IOException
126116
{
127-
long[] arrayShape = tensor.shape().asArray();
117+
long[] arrayShape = tensor.shape();
128118
if (CommonUtils.int32Overflows(arrayShape, 4))
129119
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
130120
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
131121

132122
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
133123
ByteBuffer buff = shma.getDataBuffer();
134-
int totalSize = 4;
135-
for (long i : arrayShape) {totalSize *= i;}
136-
byte[] flatArr = new byte[buff.capacity()];
137-
buff.get(flatArr);
138-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
139-
shma.setBuffer(ByteBuffer.wrap(flatArr));
124+
tensor.writeTo(buff);
140125
if (PlatformDetection.isWindows()) shma.close();
141126
}
142127

@@ -148,21 +133,16 @@ private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName)
148133
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}.
149134
* @throws IOException
150135
*/
151-
private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryName) throws IOException
136+
private static void buildFromTensorFloat(Tensor<Float> tensor, String memoryName) throws IOException
152137
{
153-
long[] arrayShape = tensor.shape().asArray();
138+
long[] arrayShape = tensor.shape();
154139
if (CommonUtils.int32Overflows(arrayShape, 4))
155140
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
156141
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
157142

158143
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
159144
ByteBuffer buff = shma.getDataBuffer();
160-
int totalSize = 4;
161-
for (long i : arrayShape) {totalSize *= i;}
162-
byte[] flatArr = new byte[buff.capacity()];
163-
buff.get(flatArr);
164-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
165-
shma.setBuffer(ByteBuffer.wrap(flatArr));
145+
tensor.writeTo(buff);
166146
if (PlatformDetection.isWindows()) shma.close();
167147
}
168148

@@ -174,21 +154,16 @@ private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryN
174154
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}.
175155
* @throws IOException
176156
*/
177-
private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memoryName) throws IOException
157+
private static void buildFromTensorDouble(Tensor<Double> tensor, String memoryName) throws IOException
178158
{
179-
long[] arrayShape = tensor.shape().asArray();
159+
long[] arrayShape = tensor.shape();
180160
if (CommonUtils.int32Overflows(arrayShape, 8))
181161
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
182162
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
183163

184164
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
185165
ByteBuffer buff = shma.getDataBuffer();
186-
int totalSize = 8;
187-
for (long i : arrayShape) {totalSize *= i;}
188-
byte[] flatArr = new byte[buff.capacity()];
189-
buff.get(flatArr);
190-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
191-
shma.setBuffer(ByteBuffer.wrap(flatArr));
166+
tensor.writeTo(buff);
192167
if (PlatformDetection.isWindows()) shma.close();
193168
}
194169

@@ -200,22 +175,17 @@ private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memory
200175
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}.
201176
* @throws IOException
202177
*/
203-
private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName) throws IOException
178+
private static void buildFromTensorLong(Tensor<Long> tensor, String memoryName) throws IOException
204179
{
205-
long[] arrayShape = tensor.shape().asArray();
180+
long[] arrayShape = tensor.shape();
206181
if (CommonUtils.int32Overflows(arrayShape, 8))
207182
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
208183
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
209184

210185

211186
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
212187
ByteBuffer buff = shma.getDataBuffer();
213-
int totalSize = 8;
214-
for (long i : arrayShape) {totalSize *= i;}
215-
byte[] flatArr = new byte[buff.capacity()];
216-
buff.get(flatArr);
217-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
218-
shma.setBuffer(ByteBuffer.wrap(flatArr));
188+
tensor.writeTo(buff);
219189
if (PlatformDetection.isWindows()) shma.close();
220190
}
221191
}

src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/TensorBuilder.java

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,7 @@
4040
import java.util.Arrays;
4141

4242
import org.tensorflow.Tensor;
43-
import org.tensorflow.ndarray.Shape;
44-
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
45-
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
46-
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
47-
import org.tensorflow.ndarray.buffer.IntDataBuffer;
48-
import org.tensorflow.ndarray.buffer.LongDataBuffer;
49-
import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory;
50-
import org.tensorflow.types.TFloat32;
51-
import org.tensorflow.types.TFloat64;
52-
import org.tensorflow.types.TInt32;
53-
import org.tensorflow.types.TInt64;
54-
import org.tensorflow.types.TUint8;
55-
import org.tensorflow.types.family.TType;
43+
import org.tensorflow.types.UInt8;
5644

5745
/**
5846
* A TensorFlow 2 {@link Tensor} builder from {@link Img} and
@@ -80,7 +68,7 @@ private TensorBuilder() {}
8068
* @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval}
8169
* is not supported
8270
*/
83-
public static Tensor<? extends TType> build(SharedMemoryArray array) throws IllegalArgumentException
71+
public static Tensor<?> build(SharedMemoryArray array) throws IllegalArgumentException
8472
{
8573
// Create an Icy sequence of the same type of the tensor
8674
if (array.getOriginalDataType().equals("uint8")) {
@@ -113,7 +101,7 @@ else if (array.getOriginalDataType().equals("int64")) {
113101
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
114102
* not compatible
115103
*/
116-
public static Tensor<TUint8> buildUByte(SharedMemoryArray tensor)
104+
public static Tensor<UInt8> buildUByte(SharedMemoryArray tensor)
117105
throws IllegalArgumentException
118106
{
119107
long[] ogShape = tensor.getOriginalShape();
@@ -123,8 +111,7 @@ public static Tensor<TUint8> buildUByte(SharedMemoryArray tensor)
123111
if (!tensor.isNumpyFormat())
124112
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
125113
ByteBuffer buff = tensor.getDataBufferNoHeader();
126-
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(buff.array(), false);
127-
Tensor<TUint8> ndarray = Tensor.of(TUint8.DTYPE, Shape.of(ogShape), dataBuffer);
114+
Tensor<UInt8> ndarray = Tensor.create(UInt8.class, ogShape, buff);
128115
return ndarray;
129116
}
130117

@@ -138,7 +125,7 @@ public static Tensor<TUint8> buildUByte(SharedMemoryArray tensor)
138125
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
139126
* not compatible
140127
*/
141-
public static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
128+
public static Tensor<Integer> buildInt(SharedMemoryArray tensor)
142129
throws IllegalArgumentException
143130
{
144131
long[] ogShape = tensor.getOriginalShape();
@@ -151,8 +138,7 @@ public static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
151138
IntBuffer intBuff = buff.asIntBuffer();
152139
int[] intArray = new int[intBuff.capacity()];
153140
intBuff.get(intArray);
154-
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false);
155-
Tensor<TInt32> ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer);
141+
Tensor<Integer> ndarray = Tensor.create(ogShape, intBuff);
156142
return ndarray;
157143
}
158144

@@ -166,7 +152,7 @@ public static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
166152
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
167153
* not compatible
168154
*/
169-
private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
155+
private static Tensor<Long> buildLong(SharedMemoryArray tensor)
170156
throws IllegalArgumentException
171157
{
172158
long[] ogShape = tensor.getOriginalShape();
@@ -179,8 +165,7 @@ private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
179165
LongBuffer longBuff = buff.asLongBuffer();
180166
long[] longArray = new long[longBuff.capacity()];
181167
longBuff.get(longArray);
182-
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false);
183-
Tensor<TInt64> ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer);
168+
Tensor<Long> ndarray = Tensor.create(ogShape, longBuff);
184169
return ndarray;
185170
}
186171

@@ -194,7 +179,7 @@ private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
194179
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
195180
* not compatible
196181
*/
197-
public static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
182+
public static Tensor<Float> buildFloat(SharedMemoryArray tensor)
198183
throws IllegalArgumentException
199184
{
200185
long[] ogShape = tensor.getOriginalShape();
@@ -207,8 +192,7 @@ public static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
207192
FloatBuffer floatBuff = buff.asFloatBuffer();
208193
float[] floatArray = new float[floatBuff.capacity()];
209194
floatBuff.get(floatArray);
210-
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false);
211-
Tensor<TFloat32> ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
195+
Tensor<Float> ndarray = Tensor.create(ogShape, floatBuff);
212196
return ndarray;
213197
}
214198

@@ -222,7 +206,7 @@ public static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
222206
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
223207
* not compatible
224208
*/
225-
private static Tensor<TFloat64> buildDouble(SharedMemoryArray tensor)
209+
private static Tensor<Double> buildDouble(SharedMemoryArray tensor)
226210
throws IllegalArgumentException
227211
{
228212
long[] ogShape = tensor.getOriginalShape();
@@ -235,8 +219,7 @@ private static Tensor<TFloat64> buildDouble(SharedMemoryArray tensor)
235219
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
236220
double[] doubleArray = new double[doubleBuff.capacity()];
237221
doubleBuff.get(doubleArray);
238-
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false);
239-
Tensor<TFloat64> ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
222+
Tensor<Double> ndarray = Tensor.create(ogShape, doubleBuff);
240223
return ndarray;
241224
}
242225
}

0 commit comments

Comments
 (0)