29
29
import java .util .Arrays ;
30
30
31
31
import org .tensorflow .Tensor ;
32
- import org .tensorflow .types .TFloat32 ;
33
- import org .tensorflow .types .TFloat64 ;
34
- import org .tensorflow .types .TInt32 ;
35
- import org .tensorflow .types .TInt64 ;
36
- import org .tensorflow .types .TUint8 ;
37
- import org .tensorflow .types .family .TType ;
32
+ import org .tensorflow .types .UInt8 ;
38
33
39
34
import net .imglib2 .RandomAccessibleInterval ;
40
35
import net .imglib2 .type .numeric .integer .IntType ;
@@ -70,20 +65,20 @@ private ShmBuilder()
70
65
* @throws IOException
71
66
*/
72
67
@ SuppressWarnings ("unchecked" )
73
- public static void build (Tensor <? extends TType > tensor , String memoryName ) throws IllegalArgumentException , IOException
68
+ public static void build (Tensor <?> tensor , String memoryName ) throws IllegalArgumentException , IOException
74
69
{
75
- switch (tensor .dataType (). name () )
70
+ switch (tensor .dataType ())
76
71
{
77
- case TUint8 . NAME :
78
- buildFromTensorUByte ((Tensor <TUint8 >) tensor , memoryName );
79
- case TInt32 . NAME :
80
- buildFromTensorInt ((Tensor <TInt32 >) tensor , memoryName );
81
- case TFloat32 . NAME :
82
- buildFromTensorFloat ((Tensor <TFloat32 >) tensor , memoryName );
83
- case TFloat64 . NAME :
84
- buildFromTensorDouble ((Tensor <TFloat64 >) tensor , memoryName );
85
- case TInt64 . NAME :
86
- buildFromTensorLong ((Tensor <TInt64 >) tensor , memoryName );
72
+ case UINT8 :
73
+ buildFromTensorUByte ((Tensor <UInt8 >) tensor , memoryName );
74
+ case INT32 :
75
+ buildFromTensorInt ((Tensor <Integer >) tensor , memoryName );
76
+ case FLOAT :
77
+ buildFromTensorFloat ((Tensor <Float >) tensor , memoryName );
78
+ case DOUBLE :
79
+ buildFromTensorDouble ((Tensor <Double >) tensor , memoryName );
80
+ case INT64 :
81
+ buildFromTensorLong ((Tensor <Long >) tensor , memoryName );
87
82
default :
88
83
throw new IllegalArgumentException ("Unsupported tensor type: " + tensor .dataType ().name ());
89
84
}
@@ -97,20 +92,15 @@ public static void build(Tensor<? extends TType> tensor, String memoryName) thr
97
92
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link UnsignedByteType}.
98
93
* @throws IOException
99
94
*/
100
- private static void buildFromTensorUByte (Tensor <TUint8 > tensor , String memoryName ) throws IOException
95
+ private static void buildFromTensorUByte (Tensor <UInt8 > tensor , String memoryName ) throws IOException
101
96
{
102
- long [] arrayShape = tensor .shape (). asArray () ;
97
+ long [] arrayShape = tensor .shape ();
103
98
if (CommonUtils .int32Overflows (arrayShape , 1 ))
104
99
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
105
100
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
106
101
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
107
102
ByteBuffer buff = shma .getDataBuffer ();
108
- int totalSize = 1 ;
109
- for (long i : arrayShape ) {totalSize *= i ;}
110
- byte [] flatArr = new byte [buff .capacity ()];
111
- buff .get (flatArr );
112
- tensor .rawData ().read (flatArr , flatArr .length - totalSize , totalSize );
113
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
103
+ tensor .writeTo (buff );
114
104
if (PlatformDetection .isWindows ()) shma .close ();
115
105
}
116
106
@@ -122,21 +112,16 @@ private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryNam
122
112
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link IntType}.
123
113
* @throws IOException
124
114
*/
125
- private static void buildFromTensorInt (Tensor <TInt32 > tensor , String memoryName ) throws IOException
115
+ private static void buildFromTensorInt (Tensor <Integer > tensor , String memoryName ) throws IOException
126
116
{
127
- long [] arrayShape = tensor .shape (). asArray () ;
117
+ long [] arrayShape = tensor .shape ();
128
118
if (CommonUtils .int32Overflows (arrayShape , 4 ))
129
119
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
130
120
+ " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
131
121
132
122
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new IntType (), false , true );
133
123
ByteBuffer buff = shma .getDataBuffer ();
134
- int totalSize = 4 ;
135
- for (long i : arrayShape ) {totalSize *= i ;}
136
- byte [] flatArr = new byte [buff .capacity ()];
137
- buff .get (flatArr );
138
- tensor .rawData ().read (flatArr , flatArr .length - totalSize , totalSize );
139
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
124
+ tensor .writeTo (buff );
140
125
if (PlatformDetection .isWindows ()) shma .close ();
141
126
}
142
127
@@ -148,21 +133,16 @@ private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName)
148
133
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link FloatType}.
149
134
* @throws IOException
150
135
*/
151
- private static void buildFromTensorFloat (Tensor <TFloat32 > tensor , String memoryName ) throws IOException
136
+ private static void buildFromTensorFloat (Tensor <Float > tensor , String memoryName ) throws IOException
152
137
{
153
- long [] arrayShape = tensor .shape (). asArray () ;
138
+ long [] arrayShape = tensor .shape ();
154
139
if (CommonUtils .int32Overflows (arrayShape , 4 ))
155
140
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
156
141
+ " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
157
142
158
143
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new FloatType (), false , true );
159
144
ByteBuffer buff = shma .getDataBuffer ();
160
- int totalSize = 4 ;
161
- for (long i : arrayShape ) {totalSize *= i ;}
162
- byte [] flatArr = new byte [buff .capacity ()];
163
- buff .get (flatArr );
164
- tensor .rawData ().read (flatArr , flatArr .length - totalSize , totalSize );
165
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
145
+ tensor .writeTo (buff );
166
146
if (PlatformDetection .isWindows ()) shma .close ();
167
147
}
168
148
@@ -174,21 +154,16 @@ private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryN
174
154
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link DoubleType}.
175
155
* @throws IOException
176
156
*/
177
- private static void buildFromTensorDouble (Tensor <TFloat64 > tensor , String memoryName ) throws IOException
157
+ private static void buildFromTensorDouble (Tensor <Double > tensor , String memoryName ) throws IOException
178
158
{
179
- long [] arrayShape = tensor .shape (). asArray () ;
159
+ long [] arrayShape = tensor .shape ();
180
160
if (CommonUtils .int32Overflows (arrayShape , 8 ))
181
161
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
182
162
+ " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
183
163
184
164
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new DoubleType (), false , true );
185
165
ByteBuffer buff = shma .getDataBuffer ();
186
- int totalSize = 8 ;
187
- for (long i : arrayShape ) {totalSize *= i ;}
188
- byte [] flatArr = new byte [buff .capacity ()];
189
- buff .get (flatArr );
190
- tensor .rawData ().read (flatArr , flatArr .length - totalSize , totalSize );
191
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
166
+ tensor .writeTo (buff );
192
167
if (PlatformDetection .isWindows ()) shma .close ();
193
168
}
194
169
@@ -200,22 +175,17 @@ private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memory
200
175
* @return The {@link RandomAccessibleInterval} built from the tensor, of type {@link LongType}.
201
176
* @throws IOException
202
177
*/
203
- private static void buildFromTensorLong (Tensor <TInt64 > tensor , String memoryName ) throws IOException
178
+ private static void buildFromTensorLong (Tensor <Long > tensor , String memoryName ) throws IOException
204
179
{
205
- long [] arrayShape = tensor .shape (). asArray () ;
180
+ long [] arrayShape = tensor .shape ();
206
181
if (CommonUtils .int32Overflows (arrayShape , 8 ))
207
182
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
208
183
+ " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
209
184
210
185
211
186
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new LongType (), false , true );
212
187
ByteBuffer buff = shma .getDataBuffer ();
213
- int totalSize = 8 ;
214
- for (long i : arrayShape ) {totalSize *= i ;}
215
- byte [] flatArr = new byte [buff .capacity ()];
216
- buff .get (flatArr );
217
- tensor .rawData ().read (flatArr , flatArr .length - totalSize , totalSize );
218
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
188
+ tensor .writeTo (buff );
219
189
if (PlatformDetection .isWindows ()) shma .close ();
220
190
}
221
191
}
0 commit comments