28
28
import java .nio .ByteBuffer ;
29
29
import java .util .Arrays ;
30
30
31
- import org .tensorflow .Tensor ;
32
31
import org .tensorflow .types .TFloat32 ;
33
32
import org .tensorflow .types .TFloat64 ;
34
33
import org .tensorflow .types .TInt32 ;
@@ -101,13 +100,8 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
101
100
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
102
101
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
103
102
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
104
- ByteBuffer buff = shma .getDataBuffer ();
105
- int totalSize = 1 ;
106
- for (long i : arrayShape ) {totalSize *= i ;}
107
- byte [] flatArr = new byte [buff .capacity ()];
108
- buff .get (flatArr );
109
- tensor .asRawTensor ().data ().read (flatArr , flatArr .length - totalSize , totalSize );
110
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
103
+ ByteBuffer buff = shma .getDataBufferNoHeader ();
104
+ tensor .asRawTensor ().data ().read (buff .array (), 0 , buff .capacity ());
111
105
if (PlatformDetection .isWindows ()) shma .close ();
112
106
}
113
107
@@ -119,13 +113,8 @@ private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws
119
113
+ " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
120
114
121
115
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new IntType (), false , true );
122
- ByteBuffer buff = shma .getDataBuffer ();
123
- int totalSize = 4 ;
124
- for (long i : arrayShape ) {totalSize *= i ;}
125
- byte [] flatArr = new byte [buff .capacity ()];
126
- buff .get (flatArr );
127
- tensor .asRawTensor ().data ().read (flatArr , flatArr .length - totalSize , totalSize );
128
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
116
+ ByteBuffer buff = shma .getDataBufferNoHeader ();
117
+ tensor .asRawTensor ().data ().read (buff .array (), 0 , buff .capacity ());
129
118
if (PlatformDetection .isWindows ()) shma .close ();
130
119
}
131
120
@@ -137,13 +126,8 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr
137
126
+ " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
138
127
139
128
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new FloatType (), false , true );
140
- ByteBuffer buff = shma .getDataBuffer ();
141
- int totalSize = 4 ;
142
- for (long i : arrayShape ) {totalSize *= i ;}
143
- byte [] flatArr = new byte [buff .capacity ()];
144
- buff .get (flatArr );
145
- tensor .asRawTensor ().data ().read (flatArr , flatArr .length - totalSize , totalSize );
146
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
129
+ ByteBuffer buff = shma .getDataBufferNoHeader ();
130
+ tensor .asRawTensor ().data ().read (buff .array (), 0 , buff .capacity ());
147
131
if (PlatformDetection .isWindows ()) shma .close ();
148
132
}
149
133
@@ -155,13 +139,8 @@ private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) th
155
139
+ " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
156
140
157
141
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new DoubleType (), false , true );
158
- ByteBuffer buff = shma .getDataBuffer ();
159
- int totalSize = 8 ;
160
- for (long i : arrayShape ) {totalSize *= i ;}
161
- byte [] flatArr = new byte [buff .capacity ()];
162
- buff .get (flatArr );
163
- tensor .asRawTensor ().data ().read (flatArr , flatArr .length - totalSize , totalSize );
164
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
142
+ ByteBuffer buff = shma .getDataBufferNoHeader ();
143
+ tensor .asRawTensor ().data ().read (buff .array (), 0 , buff .capacity ());
165
144
if (PlatformDetection .isWindows ()) shma .close ();
166
145
}
167
146
@@ -174,13 +153,8 @@ private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws
174
153
175
154
176
155
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new LongType (), false , true );
177
- ByteBuffer buff = shma .getDataBuffer ();
178
- int totalSize = 8 ;
179
- for (long i : arrayShape ) {totalSize *= i ;}
180
- byte [] flatArr = new byte [buff .capacity ()];
181
- buff .get (flatArr );
182
- tensor .asRawTensor ().data ().read (flatArr , flatArr .length - totalSize , totalSize );
183
- shma .setBuffer (ByteBuffer .wrap (flatArr ));
156
+ ByteBuffer buff = shma .getDataBufferNoHeader ();
157
+ tensor .asRawTensor ().data ().read (buff .array (), 0 , buff .capacity ());
184
158
if (PlatformDetection .isWindows ()) shma .close ();
185
159
}
186
160
}
0 commit comments