29
29
30
30
import org .bytedeco .pytorch .Tensor ;
31
31
32
- import net .imglib2 .type .numeric .integer .UnsignedByteType ;
32
+ import net .imglib2 .type .numeric .integer .IntType ;
33
+ import net .imglib2 .type .numeric .integer .LongType ;
34
+ import net .imglib2 .type .numeric .integer .ByteType ;
35
+ import net .imglib2 .type .numeric .real .DoubleType ;
36
+ import net .imglib2 .type .numeric .real .FloatType ;
33
37
34
38
/**
35
39
* A utility class that converts {@link Tensor}s into {@link SharedMemoryArray}s for
@@ -79,7 +83,7 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
79
83
if (CommonUtils .int32Overflows (arrayShape , 1 ))
80
84
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
81
85
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
82
- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
86
+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new ByteType (), false , true );
83
87
shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
84
88
if (PlatformDetection .isWindows ()) shma .close ();
85
89
}
@@ -90,7 +94,7 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
90
94
if (CommonUtils .int32Overflows (arrayShape , 4 ))
91
95
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
92
96
+ " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
93
- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
97
+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new IntType (), false , true );
94
98
shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
95
99
if (PlatformDetection .isWindows ()) shma .close ();
96
100
}
@@ -101,7 +105,7 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
101
105
if (CommonUtils .int32Overflows (arrayShape , 4 ))
102
106
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
103
107
+ " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
104
- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
108
+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new FloatType (), false , true );
105
109
shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
106
110
if (PlatformDetection .isWindows ()) shma .close ();
107
111
}
@@ -112,7 +116,7 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
112
116
if (CommonUtils .int32Overflows (arrayShape , 8 ))
113
117
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
114
118
+ " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
115
- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
119
+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new DoubleType (), false , true );
116
120
shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
117
121
if (PlatformDetection .isWindows ()) shma .close ();
118
122
}
@@ -123,7 +127,7 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
123
127
if (CommonUtils .int32Overflows (arrayShape , 8 ))
124
128
throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
125
129
+ " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
126
- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
130
+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new LongType (), false , true );
127
131
shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
128
132
if (PlatformDetection .isWindows ()) shma .close ();
129
133
}
0 commit comments