26
26
import net .imglib2 .util .Cast ;
27
27
28
28
import java .nio .ByteBuffer ;
29
- import java .nio .DoubleBuffer ;
30
- import java .nio .FloatBuffer ;
31
- import java .nio .IntBuffer ;
32
- import java .nio .LongBuffer ;
33
29
import java .util .Arrays ;
34
30
35
31
import org .tensorflow .Tensor ;
@@ -102,7 +98,10 @@ private static TUint8 buildUByte(SharedMemoryArray tensor)
102
98
if (!tensor .isNumpyFormat ())
103
99
throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
104
100
ByteBuffer buff = tensor .getDataBufferNoHeader ();
105
- ByteDataBuffer dataBuffer = RawDataBufferFactory .create (buff .array (), false );
101
+ byte [] flat = new byte [buff .capacity ()];
102
+ buff .get (flat );
103
+ buff .rewind ();
104
+ ByteDataBuffer dataBuffer = RawDataBufferFactory .create (flat , false );
106
105
TUint8 ndarray = Tensor .of (TUint8 .class , Shape .of (ogShape ), dataBuffer );
107
106
return ndarray ;
108
107
}
@@ -117,8 +116,10 @@ private static TInt32 buildInt(SharedMemoryArray tensor)
117
116
if (!tensor .isNumpyFormat ())
118
117
throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
119
118
ByteBuffer buff = tensor .getDataBufferNoHeader ();
120
- IntBuffer intBuff = buff .asIntBuffer ();
121
- IntDataBuffer dataBuffer = RawDataBufferFactory .create (intBuff .array (), false );
119
+ int [] flat = new int [buff .capacity () / 4 ];
120
+ buff .asIntBuffer ().get (flat );
121
+ buff .rewind ();
122
+ IntDataBuffer dataBuffer = RawDataBufferFactory .create (flat , false );
122
123
TInt32 ndarray = TInt32 .tensorOf (Shape .of (ogShape ),
123
124
dataBuffer );
124
125
return ndarray ;
@@ -134,8 +135,10 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
134
135
if (!tensor .isNumpyFormat ())
135
136
throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
136
137
ByteBuffer buff = tensor .getDataBufferNoHeader ();
137
- LongBuffer longBuff = buff .asLongBuffer ();
138
- LongDataBuffer dataBuffer = RawDataBufferFactory .create (longBuff .array (), false );
138
+ long [] flat = new long [buff .capacity () / 8 ];
139
+ buff .asLongBuffer ().get (flat );
140
+ buff .rewind ();
141
+ LongDataBuffer dataBuffer = RawDataBufferFactory .create (flat , false );
139
142
TInt64 ndarray = TInt64 .tensorOf (Shape .of (ogShape ),
140
143
dataBuffer );
141
144
return ndarray ;
@@ -151,8 +154,10 @@ private static TFloat32 buildFloat(SharedMemoryArray tensor)
151
154
if (!tensor .isNumpyFormat ())
152
155
throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
153
156
ByteBuffer buff = tensor .getDataBufferNoHeader ();
154
- FloatBuffer floatBuff = buff .asFloatBuffer ();
155
- FloatDataBuffer dataBuffer = RawDataBufferFactory .create (floatBuff .array (), false );
157
+ float [] flat = new float [buff .capacity () / 4 ];
158
+ buff .asFloatBuffer ().get (flat );
159
+ buff .rewind ();
160
+ FloatDataBuffer dataBuffer = RawDataBufferFactory .create (flat , false );
156
161
TFloat32 ndarray = TFloat32 .tensorOf (Shape .of (ogShape ), dataBuffer );
157
162
return ndarray ;
158
163
}
@@ -167,8 +172,10 @@ private static TFloat64 buildDouble(SharedMemoryArray tensor)
167
172
if (!tensor .isNumpyFormat ())
168
173
throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
169
174
ByteBuffer buff = tensor .getDataBufferNoHeader ();
170
- DoubleBuffer floatBuff = buff .asDoubleBuffer ();
171
- DoubleDataBuffer dataBuffer = RawDataBufferFactory .create (floatBuff .array (), false );
175
+ double [] flat = new double [buff .capacity () / 8 ];
176
+ buff .asDoubleBuffer ().get (flat );
177
+ buff .rewind ();
178
+ DoubleDataBuffer dataBuffer = RawDataBufferFactory .create (flat , false );
172
179
TFloat64 ndarray = TFloat64 .tensorOf (Shape .of (ogShape ), dataBuffer );
173
180
return ndarray ;
174
181
}
0 commit comments