23
23
24
24
import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
25
25
import io .bioimage .modelrunner .utils .CommonUtils ;
26
- import net .imglib2 .RandomAccessibleInterval ;
27
- import net .imglib2 .img .Img ;
28
- import net .imglib2 .type .numeric .integer .IntType ;
29
- import net .imglib2 .type .numeric .integer .LongType ;
30
- import net .imglib2 .type .numeric .integer .UnsignedByteType ;
31
- import net .imglib2 .type .numeric .real .DoubleType ;
32
- import net .imglib2 .type .numeric .real .FloatType ;
33
26
import net .imglib2 .util .Cast ;
34
27
35
28
import java .nio .ByteBuffer ;
55
48
import org .tensorflow .types .family .TType ;
56
49
57
50
/**
58
- * A TensorFlow 2 {@link Tensor} builder from {@link Img} and
59
- * {@link io.bioimage.modelrunner.tensor.Tensor} objects.
51
+ * Utility class to build Tensorflow tensors from shm segments using {@link SharedMemoryArray}
60
52
*
61
- * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
53
+ * @author Carlos Garcia Lopez de Haro
62
54
*/
63
55
public final class TensorBuilder {
64
56
@@ -68,16 +60,13 @@ public final class TensorBuilder {
68
60
private TensorBuilder () {}
69
61
70
62
/**
71
- * Creates {@link TType} instance with the same size and information as the
72
- * given {@link RandomAccessibleInterval}.
63
+ * Creates {@link TType} instance from a {@link SharedMemoryArray}
73
64
*
74
- * @param <T>
75
- * the ImgLib2 data types the {@link RandomAccessibleInterval} can be
76
65
* @param array
77
- * the {@link RandomAccessibleInterval } that is going to be converted into
66
+ * the {@link SharedMemoryArray } that is going to be converted into
78
67
* a {@link TType} tensor
79
- * @return a {@link TType} tensor
80
- * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval }
68
+ * @return the Tensorflow {@link TType} as the one stored in the shared memory segment
69
+ * @throws IllegalArgumentException if the type of the {@link SharedMemoryArray }
81
70
* is not supported
82
71
*/
83
72
public static TType build (SharedMemoryArray array ) throws IllegalArgumentException
@@ -103,17 +92,7 @@ else if (array.getOriginalDataType().equals("int64")) {
103
92
}
104
93
}
105
94
106
- /**
107
- * Creates a {@link TType} tensor of type {@link TUint8} from an
108
- * {@link RandomAccessibleInterval} of type {@link UnsignedByteType}
109
- *
110
- * @param tensor
111
- * The {@link RandomAccessibleInterval} to fill the tensor with.
112
- * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data.
113
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
114
- * not compatible
115
- */
116
- public static TUint8 buildUByte (SharedMemoryArray tensor )
95
+ private static TUint8 buildUByte (SharedMemoryArray tensor )
117
96
throws IllegalArgumentException
118
97
{
119
98
long [] ogShape = tensor .getOriginalShape ();
@@ -128,17 +107,7 @@ public static TUint8 buildUByte(SharedMemoryArray tensor)
128
107
return ndarray ;
129
108
}
130
109
131
- /**
132
- * Creates a {@link TInt32} tensor of type {@link TInt32} from an
133
- * {@link RandomAccessibleInterval} of type {@link IntType}
134
- *
135
- * @param tensor
136
- * The {@link RandomAccessibleInterval} to fill the tensor with.
137
- * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data.
138
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
139
- * not compatible
140
- */
141
- public static TInt32 buildInt (SharedMemoryArray tensor )
110
+ private static TInt32 buildInt (SharedMemoryArray tensor )
142
111
throws IllegalArgumentException
143
112
{
144
113
long [] ogShape = tensor .getOriginalShape ();
@@ -157,16 +126,6 @@ public static TInt32 buildInt(SharedMemoryArray tensor)
157
126
return ndarray ;
158
127
}
159
128
160
- /**
161
- * Creates a {@link TInt64} tensor of type {@link TInt64} from an
162
- * {@link RandomAccessibleInterval} of type {@link LongType}
163
- *
164
- * @param tensor
165
- * The {@link RandomAccessibleInterval} to fill the tensor with.
166
- * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data.
167
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
168
- * not compatible
169
- */
170
129
private static TInt64 buildLong (SharedMemoryArray tensor )
171
130
throws IllegalArgumentException
172
131
{
@@ -186,17 +145,7 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
186
145
return ndarray ;
187
146
}
188
147
189
- /**
190
- * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an
191
- * {@link RandomAccessibleInterval} of type {@link FloatType}
192
- *
193
- * @param tensor
194
- * The {@link RandomAccessibleInterval} to fill the tensor with.
195
- * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data.
196
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
197
- * not compatible
198
- */
199
- public static TFloat32 buildFloat (SharedMemoryArray tensor )
148
+ private static TFloat32 buildFloat (SharedMemoryArray tensor )
200
149
throws IllegalArgumentException
201
150
{
202
151
long [] ogShape = tensor .getOriginalShape ();
@@ -214,16 +163,6 @@ public static TFloat32 buildFloat(SharedMemoryArray tensor)
214
163
return ndarray ;
215
164
}
216
165
217
- /**
218
- * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an
219
- * {@link RandomAccessibleInterval} of type {@link DoubleType}
220
- *
221
- * @param tensor
222
- * The {@link RandomAccessibleInterval} to fill the tensor with.
223
- * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data.
224
- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
225
- * not compatible
226
- */
227
166
private static TFloat64 buildDouble (SharedMemoryArray tensor )
228
167
throws IllegalArgumentException
229
168
{
0 commit comments