Skip to content

Commit e0cc2a8

Browse files
committed
improve javadoc
1 parent f53dd18 commit e0cc2a8

File tree

3 files changed

+31
-129
lines changed

3 files changed

+31
-129
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/JavaWorker.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ public class JavaWorker {
2121
private final PytorchInterface pi;
2222

2323
private boolean cancelRequested = false;
24-
24+
25+
/**
26+
* Method in the child process that is in charge of keeping the process open and calling the model load,
27+
* model inference and model closing
28+
* @param args
29+
* args of the parent process
30+
*/
2531
public static void main(String[] args) {
2632

2733
try(Scanner scanner = new Scanner(System.in)){

src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java

Lines changed: 12 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,17 @@
2929
import java.util.Arrays;
3030

3131
import ai.djl.ndarray.NDArray;
32-
import ai.djl.ndarray.types.DataType;
33-
import net.imglib2.RandomAccessibleInterval;
3432
import net.imglib2.type.numeric.integer.IntType;
3533
import net.imglib2.type.numeric.integer.LongType;
3634
import net.imglib2.type.numeric.integer.UnsignedByteType;
3735
import net.imglib2.type.numeric.real.DoubleType;
3836
import net.imglib2.type.numeric.real.FloatType;
3937

4038
/**
41-
* A {@link RandomAccessibleInterval} builder for TensorFlow {@link Tensor} objects.
42-
* Build ImgLib2 objects (backend of {@link io.bioimage.modelrunner.tensor.Tensor})
43-
* from Tensorflow 2 {@link Tensor}
39+
* A utility class that converts {@link NDArray}s into {@link SharedMemoryArray}s for
40+
* interprocessing communication
4441
*
45-
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
42+
* @author Carlos Garcia Lopez de Haro
4643
*/
4744
public final class ShmBuilder
4845
{
@@ -53,17 +50,15 @@ private ShmBuilder()
5350
{
5451
}
5552

56-
/**
57-
* Creates a {@link RandomAccessibleInterval} from a given {@link TType} tensor
58-
*
59-
* @param <T>
60-
* the possible ImgLib2 datatypes of the image
61-
* @param tensor
62-
* The {@link TType} tensor data is read from.
63-
* @throws IllegalArgumentException If the {@link TType} tensor type is not supported.
64-
* @throws IOException
65-
*/
66-
@SuppressWarnings("unchecked")
53+
/**
54+
* Create a {@link SharedMemoryArray} from a {@link NDArray}
55+
* @param tensor
56+
* the tensor to be passed into the other process through the shared memory
57+
* @param memoryName
58+
* the name of the memory region where the tensor is going to be copied
59+
* @throws IllegalArgumentException if the data type of the tensor is not supported
60+
* @throws IOException if there is any error creating the shared memory array
61+
*/
6762
public static void build(NDArray tensor, String memoryName) throws IllegalArgumentException, IOException
6863
{
6964
switch (tensor.getDataType())
@@ -83,14 +78,6 @@ public static void build(NDArray tensor, String memoryName) throws IllegalArgum
8378
}
8479
}
8580

86-
/**
87-
* Builds a {@link RandomAccessibleInterval} from a unsigned byte-typed {@link TUint8} tensor.
88-
*
89-
* @param tensor
90-
* The {@link TUint8} tensor data is read from.
91-
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}.
92-
* @throws IOException
93-
*/
9481
private static void buildFromTensorUByte(NDArray tensor, String memoryName) throws IOException
9582
{
9683
long[] arrayShape = tensor.getShape().getShape();
@@ -102,14 +89,6 @@ private static void buildFromTensorUByte(NDArray tensor, String memoryName) thro
10289
if (PlatformDetection.isWindows()) shma.close();
10390
}
10491

105-
/**
106-
* Builds a {@link RandomAccessibleInterval} from a unsigned int32-typed {@link TInt32} tensor.
107-
*
108-
* @param tensor
109-
* The {@link TInt32} tensor data is read from.
110-
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}.
111-
* @throws IOException
112-
*/
11392
private static void buildFromTensorInt(NDArray tensor, String memoryName) throws IOException
11493
{
11594
long[] arrayShape = tensor.getShape().getShape();
@@ -122,14 +101,6 @@ private static void buildFromTensorInt(NDArray tensor, String memoryName) throws
122101
if (PlatformDetection.isWindows()) shma.close();
123102
}
124103

125-
/**
126-
* Builds a {@link RandomAccessibleInterval} from a unsigned float32-typed {@link TFloat32} tensor.
127-
*
128-
* @param tensor
129-
* The {@link TFloat32} tensor data is read from.
130-
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}.
131-
* @throws IOException
132-
*/
133104
private static void buildFromTensorFloat(NDArray tensor, String memoryName) throws IOException
134105
{
135106
long[] arrayShape = tensor.getShape().getShape();
@@ -142,14 +113,6 @@ private static void buildFromTensorFloat(NDArray tensor, String memoryName) thro
142113
if (PlatformDetection.isWindows()) shma.close();
143114
}
144115

145-
/**
146-
* Builds a {@link RandomAccessibleInterval} from a unsigned float64-typed {@link TFloat64} tensor.
147-
*
148-
* @param tensor
149-
* The {@link TFloat64} tensor data is read from.
150-
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}.
151-
* @throws IOException
152-
*/
153116
private static void buildFromTensorDouble(NDArray tensor, String memoryName) throws IOException
154117
{
155118
long[] arrayShape = tensor.getShape().getShape();
@@ -162,14 +125,6 @@ private static void buildFromTensorDouble(NDArray tensor, String memoryName) thr
162125
if (PlatformDetection.isWindows()) shma.close();
163126
}
164127

165-
/**
166-
* Builds a {@link RandomAccessibleInterval} from a unsigned int64-typed {@link TInt64} tensor.
167-
*
168-
* @param tensor
169-
* The {@link TInt64} tensor data is read from.
170-
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}.
171-
* @throws IOException
172-
*/
173128
private static void buildFromTensorLong(NDArray tensor, String memoryName) throws IOException
174129
{
175130
long[] arrayShape = tensor.getShape().getShape();

src/main/java/io/bioimage/modelrunner/pytorch/shm/TensorBuilder.java

Lines changed: 12 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,6 @@
2323

2424
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
2525
import io.bioimage.modelrunner.utils.CommonUtils;
26-
import net.imglib2.RandomAccessibleInterval;
27-
import net.imglib2.img.Img;
28-
import net.imglib2.type.numeric.integer.IntType;
29-
import net.imglib2.type.numeric.integer.LongType;
30-
import net.imglib2.type.numeric.integer.UnsignedByteType;
31-
import net.imglib2.type.numeric.real.DoubleType;
32-
import net.imglib2.type.numeric.real.FloatType;
3326
import net.imglib2.util.Cast;
3427

3528
import java.nio.ByteBuffer;
@@ -44,10 +37,9 @@
4437
import ai.djl.ndarray.types.Shape;
4538

4639
/**
47-
* A TensorFlow 2 {@link Tensor} builder from {@link Img} and
48-
* {@link io.bioimage.modelrunner.tensor.Tensor} objects.
40+
* Utility class to build Pytorch tensors from shm segments using {@link SharedMemoryArray}
4941
*
50-
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
42+
* @author Carlos Garcia Lopez de Haro
5143
*/
5244
public final class TensorBuilder {
5345

@@ -57,16 +49,15 @@ public final class TensorBuilder {
5749
private TensorBuilder() {}
5850

5951
/**
60-
* Creates {@link TType} instance with the same size and information as the
61-
* given {@link RandomAccessibleInterval}.
52+
* Creates {@link NDArray} instance from a {@link SharedMemoryArray}
6253
*
63-
* @param <T>
64-
* the ImgLib2 data types the {@link RandomAccessibleInterval} can be
6554
* @param array
66-
* the {@link RandomAccessibleInterval} that is going to be converted into
67-
* a {@link TType} tensor
68-
* @return a {@link TType} tensor
69-
* @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval}
55+
* the {@link SharedMemoryArray} that is going to be converted into
56+
* a {@link NDArray} tensor
57+
* @param manager
58+
* DJL manager that controls the creation and destruction of {@link NDArrays}
59+
* @return the Pytorch {@link NDArray} as the one stored in the shared memory segment
60+
* @throws IllegalArgumentException if the type of the {@link SharedMemoryArray}
7061
* is not supported
7162
*/
7263
public static NDArray build(SharedMemoryArray array, NDManager manager) throws IllegalArgumentException
@@ -92,17 +83,7 @@ else if (array.getOriginalDataType().equals("int64")) {
9283
}
9384
}
9485

95-
/**
96-
* Creates a {@link TType} tensor of type {@link TUint8} from an
97-
* {@link RandomAccessibleInterval} of type {@link UnsignedByteType}
98-
*
99-
* @param tensor
100-
* The {@link RandomAccessibleInterval} to fill the tensor with.
101-
* @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data.
102-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
103-
* not compatible
104-
*/
105-
public static NDArray buildUByte(SharedMemoryArray tensor, NDManager manager)
86+
private static NDArray buildUByte(SharedMemoryArray tensor, NDManager manager)
10687
throws IllegalArgumentException
10788
{
10889
long[] ogShape = tensor.getOriginalShape();
@@ -116,17 +97,7 @@ public static NDArray buildUByte(SharedMemoryArray tensor, NDManager manager)
11697
return ndarray;
11798
}
11899

119-
/**
120-
* Creates a {@link TInt32} tensor of type {@link TInt32} from an
121-
* {@link RandomAccessibleInterval} of type {@link IntType}
122-
*
123-
* @param tensor
124-
* The {@link RandomAccessibleInterval} to fill the tensor with.
125-
* @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data.
126-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
127-
* not compatible
128-
*/
129-
public static NDArray buildInt(SharedMemoryArray tensor, NDManager manager)
100+
private static NDArray buildInt(SharedMemoryArray tensor, NDManager manager)
130101
throws IllegalArgumentException
131102
{
132103
long[] ogShape = tensor.getOriginalShape();
@@ -143,16 +114,6 @@ public static NDArray buildInt(SharedMemoryArray tensor, NDManager manager)
143114
return ndarray;
144115
}
145116

146-
/**
147-
* Creates a {@link TInt64} tensor of type {@link TInt64} from an
148-
* {@link RandomAccessibleInterval} of type {@link LongType}
149-
*
150-
* @param tensor
151-
* The {@link RandomAccessibleInterval} to fill the tensor with.
152-
* @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data.
153-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
154-
* not compatible
155-
*/
156117
private static NDArray buildLong(SharedMemoryArray tensor, NDManager manager)
157118
throws IllegalArgumentException
158119
{
@@ -170,17 +131,7 @@ private static NDArray buildLong(SharedMemoryArray tensor, NDManager manager)
170131
return ndarray;
171132
}
172133

173-
/**
174-
* Creates a {@link TFloat32} tensor of type {@link TFloat32} from an
175-
* {@link RandomAccessibleInterval} of type {@link FloatType}
176-
*
177-
* @param tensor
178-
* The {@link RandomAccessibleInterval} to fill the tensor with.
179-
* @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data.
180-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
181-
* not compatible
182-
*/
183-
public static NDArray buildFloat(SharedMemoryArray tensor, NDManager manager)
134+
private static NDArray buildFloat(SharedMemoryArray tensor, NDManager manager)
184135
throws IllegalArgumentException
185136
{
186137
long[] ogShape = tensor.getOriginalShape();
@@ -197,16 +148,6 @@ public static NDArray buildFloat(SharedMemoryArray tensor, NDManager manager)
197148
return ndarray;
198149
}
199150

200-
/**
201-
* Creates a {@link TFloat64} tensor of type {@link TFloat64} from an
202-
* {@link RandomAccessibleInterval} of type {@link DoubleType}
203-
*
204-
* @param tensor
205-
* The {@link RandomAccessibleInterval} to fill the tensor with.
206-
* @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data.
207-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
208-
* not compatible
209-
*/
210151
private static NDArray buildDouble(SharedMemoryArray tensor, NDManager manager)
211152
throws IllegalArgumentException
212153
{

0 commit comments

Comments
 (0)