23
23
24
24
import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
25
25
import io .bioimage .modelrunner .utils .CommonUtils ;
26
- import net .imglib2 .RandomAccessibleInterval ;
27
- import net .imglib2 .img .Img ;
28
- import net .imglib2 .type .numeric .integer .IntType ;
29
- import net .imglib2 .type .numeric .integer .LongType ;
30
- import net .imglib2 .type .numeric .integer .UnsignedByteType ;
31
- import net .imglib2 .type .numeric .real .DoubleType ;
32
- import net .imglib2 .type .numeric .real .FloatType ;
33
26
import net .imglib2 .util .Cast ;
34
27
35
28
import java .nio .ByteBuffer ;
44
37
import ai .djl .ndarray .types .Shape ;
45
38
46
39
/**
47
- * A TensorFlow 2 {@link Tensor} builder from {@link Img} and
48
- * {@link io.bioimage.modelrunner.tensor.Tensor} objects.
40
+ * Utility class to build Pytorch tensors from shm segments using {@link SharedMemoryArray}
49
41
*
50
- * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
42
+ * @author Carlos Garcia Lopez de Haro
51
43
*/
52
44
public final class TensorBuilder {
53
45
@@ -57,16 +49,15 @@ public final class TensorBuilder {
57
49
private TensorBuilder () {}
58
50
59
51
/**
60
- * Creates {@link TType} instance with the same size and information as the
61
- * given {@link RandomAccessibleInterval}.
52
+ * Creates {@link NDArray} instance from a {@link SharedMemoryArray}
62
53
*
63
- * @param <T>
64
- * the ImgLib2 data types the {@link RandomAccessibleInterval} can be
65
54
* @param array
66
- * the {@link RandomAccessibleInterval} that is going to be converted into
67
- * a {@link TType} tensor
68
- * @return a {@link TType} tensor
69
- * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval}
55
+ * the {@link SharedMemoryArray} that is going to be converted into
56
+ * a {@link NDArray} tensor
57
+ * @param manager
58
+ * DJL manager that controls the creation and destruction of {@link NDArrays}
59
+ * @return the Pytorch {@link NDArray} as the one stored in the shared memory segment
60
+ * @throws IllegalArgumentException if the type of the {@link SharedMemoryArray}
70
61
* is not supported
71
62
*/
72
63
public static NDArray build (SharedMemoryArray array , NDManager manager ) throws IllegalArgumentException
@@ -92,17 +83,7 @@ else if (array.getOriginalDataType().equals("int64")) {
92
83
}
93
84
}
94
85
95
- /**
96
- * Creates a {@link TType} tensor of type {@link TUint8} from an
97
- * {@link RandomAccessibleInterval} of type {@link UnsignedByteType}
98
- *
99
- * @param tensor
100
- * The {@link RandomAccessibleInterval} to fill the tensor with.
101
- * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data.
102
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
103
- * not compatible
104
- */
105
- public static NDArray buildUByte (SharedMemoryArray tensor , NDManager manager )
86
+ private static NDArray buildUByte (SharedMemoryArray tensor , NDManager manager )
106
87
throws IllegalArgumentException
107
88
{
108
89
long [] ogShape = tensor .getOriginalShape ();
@@ -116,17 +97,7 @@ public static NDArray buildUByte(SharedMemoryArray tensor, NDManager manager)
116
97
return ndarray ;
117
98
}
118
99
119
- /**
120
- * Creates a {@link TInt32} tensor of type {@link TInt32} from an
121
- * {@link RandomAccessibleInterval} of type {@link IntType}
122
- *
123
- * @param tensor
124
- * The {@link RandomAccessibleInterval} to fill the tensor with.
125
- * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data.
126
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
127
- * not compatible
128
- */
129
- public static NDArray buildInt (SharedMemoryArray tensor , NDManager manager )
100
+ private static NDArray buildInt (SharedMemoryArray tensor , NDManager manager )
130
101
throws IllegalArgumentException
131
102
{
132
103
long [] ogShape = tensor .getOriginalShape ();
@@ -143,16 +114,6 @@ public static NDArray buildInt(SharedMemoryArray tensor, NDManager manager)
143
114
return ndarray ;
144
115
}
145
116
146
- /**
147
- * Creates a {@link TInt64} tensor of type {@link TInt64} from an
148
- * {@link RandomAccessibleInterval} of type {@link LongType}
149
- *
150
- * @param tensor
151
- * The {@link RandomAccessibleInterval} to fill the tensor with.
152
- * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data.
153
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
154
- * not compatible
155
- */
156
117
private static NDArray buildLong (SharedMemoryArray tensor , NDManager manager )
157
118
throws IllegalArgumentException
158
119
{
@@ -170,17 +131,7 @@ private static NDArray buildLong(SharedMemoryArray tensor, NDManager manager)
170
131
return ndarray ;
171
132
}
172
133
173
- /**
174
- * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an
175
- * {@link RandomAccessibleInterval} of type {@link FloatType}
176
- *
177
- * @param tensor
178
- * The {@link RandomAccessibleInterval} to fill the tensor with.
179
- * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data.
180
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
181
- * not compatible
182
- */
183
- public static NDArray buildFloat (SharedMemoryArray tensor , NDManager manager )
134
+ private static NDArray buildFloat (SharedMemoryArray tensor , NDManager manager )
184
135
throws IllegalArgumentException
185
136
{
186
137
long [] ogShape = tensor .getOriginalShape ();
@@ -197,16 +148,6 @@ public static NDArray buildFloat(SharedMemoryArray tensor, NDManager manager)
197
148
return ndarray ;
198
149
}
199
150
200
- /**
201
- * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an
202
- * {@link RandomAccessibleInterval} of type {@link DoubleType}
203
- *
204
- * @param tensor
205
- * The {@link RandomAccessibleInterval} to fill the tensor with.
206
- * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data.
207
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
208
- * not compatible
209
- */
210
151
private static NDArray buildDouble (SharedMemoryArray tensor , NDManager manager )
211
152
throws IllegalArgumentException
212
153
{
0 commit comments