@@ -69,19 +69,14 @@ public static void build(Tensor tensor, String memoryName) throws IllegalArgumen
     {
         if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Byte)
                 || tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Char)) {
-            System.out.println("SSECRET_KEY : BYTE ");
             buildFromTensorByte(tensor, memoryName);
         } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Int)) {
-            System.out.println("SSECRET_KEY : INT ");
             buildFromTensorInt(tensor, memoryName);
         } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Float)) {
-            System.out.println("SSECRET_KEY : FLOAT ");
             buildFromTensorFloat(tensor, memoryName);
         } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Double)) {
-            System.out.println("SSECRET_KEY : SOUBKE ");
             buildFromTensorDouble(tensor, memoryName);
         } else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Long)) {
-            System.out.println("SSECRET_KEY : LONG ");
             buildFromTensorLong(tensor, memoryName);
         } else {
             throw new IllegalArgumentException("Unsupported tensor type: " + tensor.scalar_type());
@@ -98,10 +93,9 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
         long flatSize = 1;
         for (long l : arrayShape) {flatSize *= l;}
         byte[] flat = new byte[(int) flatSize];
-        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize));
+        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize)).order(ByteOrder.LITTLE_ENDIAN);
         tensor.data_ptr_byte().get(flat);
         byteBuffer.put(flat);
-        byteBuffer.rewind();
         shma.getDataBufferNoHeader().put(byteBuffer);
         if (PlatformDetection.isWindows()) shma.close();
     }
@@ -116,11 +110,10 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
         long flatSize = 1;
         for (long l : arrayShape) {flatSize *= l;}
         int[] flat = new int[(int) flatSize];
-        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Integer.BYTES));
+        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Integer.BYTES)).order(ByteOrder.LITTLE_ENDIAN);
         IntBuffer floatBuffer = byteBuffer.asIntBuffer();
         tensor.data_ptr_int().get(flat);
         floatBuffer.put(flat);
-        byteBuffer.rewind();
         shma.getDataBufferNoHeader().put(byteBuffer);
         if (PlatformDetection.isWindows()) shma.close();
     }
@@ -140,10 +133,6 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
         tensor.data_ptr_float().get(flat);
         floatBuffer.put(flat);
         shma.getDataBufferNoHeader().put(byteBuffer);
-        System.out.println("equals " + (shma.getDataBufferNoHeader().get(100) == byteBuffer.get(100)));
-        System.out.println("equals " + (shma.getDataBufferNoHeader().get(500) == byteBuffer.get(500)));
-        System.out.println("equals " + (shma.getDataBufferNoHeader().get(300) == byteBuffer.get(300)));
-        System.out.println("equals " + (shma.getDataBufferNoHeader().get(1000) == byteBuffer.get(1000)));
         if (PlatformDetection.isWindows()) shma.close();
     }
 
@@ -157,11 +146,10 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
         long flatSize = 1;
         for (long l : arrayShape) {flatSize *= l;}
         double[] flat = new double[(int) flatSize];
-        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Double.BYTES));
+        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Double.BYTES)).order(ByteOrder.LITTLE_ENDIAN);
         DoubleBuffer floatBuffer = byteBuffer.asDoubleBuffer();
         tensor.data_ptr_double().get(flat);
         floatBuffer.put(flat);
-        byteBuffer.rewind();
         shma.getDataBufferNoHeader().put(byteBuffer);
         if (PlatformDetection.isWindows()) shma.close();
     }
@@ -176,11 +164,10 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
         long flatSize = 1;
         for (long l : arrayShape) {flatSize *= l;}
         long[] flat = new long[(int) flatSize];
-        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Long.BYTES));
+        ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Long.BYTES)).order(ByteOrder.LITTLE_ENDIAN);
         LongBuffer floatBuffer = byteBuffer.asLongBuffer();
         tensor.data_ptr_long().get(flat);
         floatBuffer.put(flat);
-        byteBuffer.rewind();
         shma.getDataBufferNoHeader().put(byteBuffer);
         if (PlatformDetection.isWindows()) shma.close();
     }
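
The hunks above make two related changes: the direct buffers are given an explicit little-endian byte order before any typed view is created, and the byteBuffer.rewind() calls that preceded the copy into shared memory are dropped. A minimal standalone sketch of that buffer pattern follows; the class name, sample data, and the destination buffer are illustrative stand-ins (the destination plays the role of shma.getDataBufferNoHeader()), not code from the patch.

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.nio.IntBuffer;

    public class LittleEndianCopySketch {
        public static void main(String[] args) {
            int[] flat = {1, 2, 3, 4};

            // New ByteBuffers default to big-endian, so the order is set
            // explicitly before the typed view is created.
            ByteBuffer byteBuffer = ByteBuffer
                    .allocateDirect(flat.length * Integer.BYTES)
                    .order(ByteOrder.LITTLE_ENDIAN);

            // Writing through the IntBuffer view advances the view's own
            // position, not the backing ByteBuffer's, so byteBuffer stays at
            // position 0 and can be copied without a rewind().
            IntBuffer intView = byteBuffer.asIntBuffer();
            intView.put(flat);

            // Hypothetical destination standing in for the shared-memory buffer.
            ByteBuffer destination = ByteBuffer.allocateDirect(byteBuffer.capacity());
            destination.put(byteBuffer);

            System.out.println("bytes copied: " + destination.position()); // prints 16
        }
    }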