20
20
*/
21
21
package io .bioimage .modelrunner .pytorch .javacpp .shm ;
22
22
23
- import io .bioimage .modelrunner .pytorch .javacpp .tensor .ImgLib2Builder ;
24
23
import io .bioimage .modelrunner .system .PlatformDetection ;
25
24
import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
26
25
import io .bioimage .modelrunner .utils .CommonUtils ;
27
26
28
27
import java .io .IOException ;
29
28
import java .nio .ByteBuffer ;
29
+ import java .nio .DoubleBuffer ;
30
30
import java .nio .FloatBuffer ;
31
+ import java .nio .IntBuffer ;
32
+ import java .nio .LongBuffer ;
31
33
import java .util .Arrays ;
32
34
33
35
import org .bytedeco .pytorch .Tensor ;
34
36
35
37
import net .imglib2 .type .numeric .integer .IntType ;
36
38
import net .imglib2 .type .numeric .integer .LongType ;
37
- import net .imglib2 .RandomAccessibleInterval ;
38
39
import net .imglib2 .type .numeric .integer .ByteType ;
39
40
import net .imglib2 .type .numeric .real .DoubleType ;
40
41
import net .imglib2 .type .numeric .real .FloatType ;
@@ -88,7 +89,14 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
88
89
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
89
90
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
90
91
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new ByteType (), false , true );
91
- shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
92
+ long flatSize = 1 ;
93
+ for (long l : arrayShape ) {flatSize *= l ;}
94
+ byte [] flat = new byte [(int ) flatSize ];
95
+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize ));
96
+ tensor .data_ptr_byte ().get (flat );
97
+ byteBuffer .put (flat );
98
+ byteBuffer .rewind ();
99
+ shma .getDataBufferNoHeader ().put (byteBuffer );
92
100
if (PlatformDetection .isWindows ()) shma .close ();
93
101
}
94
102
@@ -99,8 +107,15 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
99
107
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
100
108
+ " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
101
109
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new IntType (), false , true );
102
- RandomAccessibleInterval <?> rai = shma .getSharedRAI ();
103
- rai = ImgLib2Builder .build (tensor );
110
+ long flatSize = 1 ;
111
+ for (long l : arrayShape ) {flatSize *= l ;}
112
+ int [] flat = new int [(int ) flatSize ];
113
+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Integer .BYTES ));
114
+ IntBuffer floatBuffer = byteBuffer .asIntBuffer ();
115
+ tensor .data_ptr_int ().get (flat );
116
+ floatBuffer .put (flat );
117
+ byteBuffer .rewind ();
118
+ shma .getDataBufferNoHeader ().put (byteBuffer );
104
119
if (PlatformDetection .isWindows ()) shma .close ();
105
120
}
106
121
@@ -130,7 +145,15 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
130
145
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
131
146
+ " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
132
147
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new DoubleType (), false , true );
133
- shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
148
+ long flatSize = 1 ;
149
+ for (long l : arrayShape ) {flatSize *= l ;}
150
+ double [] flat = new double [(int ) flatSize ];
151
+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Double .BYTES ));
152
+ DoubleBuffer floatBuffer = byteBuffer .asDoubleBuffer ();
153
+ tensor .data_ptr_double ().get (flat );
154
+ floatBuffer .put (flat );
155
+ byteBuffer .rewind ();
156
+ shma .getDataBufferNoHeader ().put (byteBuffer );
134
157
if (PlatformDetection .isWindows ()) shma .close ();
135
158
}
136
159
@@ -141,7 +164,15 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
141
164
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
142
165
+ " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
143
166
SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new LongType (), false , true );
144
- shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
167
+ long flatSize = 1 ;
168
+ for (long l : arrayShape ) {flatSize *= l ;}
169
+ long [] flat = new long [(int ) flatSize ];
170
+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Long .BYTES ));
171
+ LongBuffer floatBuffer = byteBuffer .asLongBuffer ();
172
+ tensor .data_ptr_long ().get (flat );
173
+ floatBuffer .put (flat );
174
+ byteBuffer .rewind ();
175
+ shma .getDataBufferNoHeader ().put (byteBuffer );
145
176
if (PlatformDetection .isWindows ()) shma .close ();
146
177
}
147
178
}
0 commit comments