1111#include < torch/csrc/CudaIPCTypes.h>
1212#include < torch/csrc/Device.h>
1313#include < torch/csrc/autograd/utils/wrap_outputs.h>
14+ #include < torch/csrc/utils/python_arg_parser.h>
15+ #include < torch/csrc/StorageMethods.h>
16+ #include < torch/csrc/StorageSharing.h>
1417#include < c10/core/CPUAllocator.h>
1518
1619#include < fmt/format.h>
17-
18- // NOLINTNEXTLINE(bugprone-suspicious-include)
19- #include < torch/csrc/generic/Storage.cpp>
20- #include < torch/csrc/THGenerateByteType.h>
21-
2220#include < c10/util/intrusive_ptr.h>
2321
2422template <>
@@ -27,3 +25,358 @@ void THPPointer<c10::StorageImpl>::free() {
2725 c10::raw::intrusive_ptr::decref (ptr);
2826 }
2927}
28+
29+ PyObject *THPStorageClass = nullptr ;
30+
31+ PyObject * THPStorage_New (c10::intrusive_ptr<c10::StorageImpl> ptr)
32+ {
33+ AT_ASSERT (ptr);
34+ PyTypeObject *type = (PyTypeObject *)THPStorageClass;
35+ PyObject *obj = type->tp_alloc (type, 0 );
36+ if (obj) {
37+ ((THPStorage *)obj)->cdata = ptr.release ();
38+ }
39+ return obj;
40+ }
41+
42+ static void THPStorage_dealloc (THPStorage* self)
43+ {
44+ if (self->cdata ) {
45+ c10::raw::intrusive_ptr::decref (self->cdata );
46+ }
47+ Py_TYPE (self)->tp_free ((PyObject*)self);
48+ }
49+
50+ static PyObject * THPStorage_pynew (PyTypeObject *type, PyObject *args, PyObject *kwargs)
51+ {
52+ HANDLE_TH_ERRORS
53+
54+ static torch::PythonArgParser parser ({
55+ THPStorageStr " (*, int64_t allocator=None, Device device=None)" ,
56+ THPStorageStr " (int64_t size, *, int64_t allocator=None, Device device=None)" ,
57+ THPStorageStr " (PyObject* sequence, *, int64_t allocator=None, Device device=None)" ,
58+ });
59+ torch::ParsedArgs<3 > parsed_args;
60+ auto r = parser.parse (args, kwargs, parsed_args);
61+
62+ int64_t allocator_arg_idx = 0 ;
63+ int64_t device_arg_idx = 1 ;
64+
65+ if (r.idx > 0 ) {
66+ allocator_arg_idx = 1 ;
67+ device_arg_idx = 2 ;
68+ }
69+
70+ c10::optional<int64_t > allocator_opt = r.toInt64Optional (allocator_arg_idx);
71+ c10::optional<at::Device> device_opt = r.deviceOptional (device_arg_idx);
72+
73+ TORCH_CHECK (!allocator_opt.has_value () || !device_opt.has_value (),
74+ THPStorageStr, " (): only one or neither of 'allocator' or 'device' can " ,
75+ " be given, but not both" );
76+
77+ THPStoragePtr self ((THPStorage *)type->tp_alloc (type, 0 ));
78+ THPUtils_assert (self, " failed to allocate a " THPStorageStr " object" );
79+ c10::Allocator* allocator = nullptr ;
80+ at::OptionalDeviceGuard device_guard;
81+
82+ if (allocator_opt.has_value ()) {
83+ allocator = reinterpret_cast <c10::Allocator*>(allocator_opt.value ());
84+ } else if (device_opt.has_value ()) {
85+ at::Device device = device_opt.value ();
86+ if (device.type () == at::kCPU ) {
87+ allocator = c10::GetDefaultCPUAllocator ();
88+ #ifdef USE_CUDA
89+ } else if (device.type () == at::kCUDA ) {
90+ at::globalContext ().lazyInitCUDA ();
91+ allocator = c10::cuda::CUDACachingAllocator::get ();
92+ #endif
93+ } else if (device.type () == at::DeviceType::Meta) {
94+ allocator = c10::GetAllocator (device.type ());
95+ } else {
96+ TORCH_CHECK (false ,
97+ THPStorageStr, " (): Storage device not recognized: " , device.type ());
98+ }
99+ device_guard.reset_device (device);
100+ } else {
101+ allocator = c10::GetDefaultCPUAllocator ();
102+ }
103+
104+ // torch.Storage(*, ...)
105+ if (r.idx == 0 ) {
106+ self->cdata = c10::make_intrusive<at::StorageImpl>(
107+ c10::StorageImpl::use_byte_size_t (),
108+ 0 ,
109+ allocator,
110+ /* resizable=*/ true ).release ();
111+ return (PyObject*)self.release ();
112+
113+ // torch.Storage(size, *, ...)
114+ } else if (r.idx == 1 ) {
115+ int64_t size = r.toInt64 (0 );
116+ self->cdata = c10::make_intrusive<at::StorageImpl>(
117+ c10::StorageImpl::use_byte_size_t (),
118+ size,
119+ allocator,
120+ /* resizable=*/ true ).release ();
121+ return (PyObject*)self.release ();
122+
123+ // torch.Storage(sequence, *, ...)
124+ } else if (r.idx == 2 ) {
125+ PyObject *sequence = r.pyobject (0 );
126+ Py_ssize_t length = PySequence_Length (sequence);
127+ TORCH_CHECK (PySequence_Check (sequence),
128+ THPStorageStr, " (): Expected a sequence type, but got " ,
129+ THPUtils_typename (sequence));
130+ TORCH_CHECK (length >= 0 ,
131+ THPStorageStr, " (): Could not obtain the length of sequence of type " ,
132+ THPUtils_typename (sequence));
133+ self->cdata = c10::make_intrusive<at::StorageImpl>(
134+ c10::StorageImpl::use_byte_size_t (),
135+ length,
136+ allocator,
137+ /* resizable=*/ true )
138+ .release ();
139+ THPObjectPtr item;
140+ try {
141+ for (Py_ssize_t i = 0 ; i < length; i++) {
142+ item = PySequence_GetItem (sequence, i);
143+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
144+ uint8_t value = THPByteUtils_unpackReal (item.get ());
145+ if (allocator == c10::GetDefaultCPUAllocator ()) {
146+ self->cdata ->unsafe_data <uint8_t >()[i] = value;
147+ } else {
148+ // TODO: this might be slow - consider batched updates?
149+ storage_set (
150+ at::unsafeStorageFromTH (self->cdata , /* retain=*/ true ),
151+ i,
152+ value);
153+ }
154+ }
155+ } catch (const std::exception &e) {
156+ THPUtils_setError (THPStorageStr
157+ " (): tried to construct a storage from a sequence (%s), "
158+ " but one of the items was of type %s instead of %s" ,
159+ THPUtils_typename (sequence),
160+ THPUtils_typename (item.get ()),
161+ THPUtils_typeTraits<uint8_t >::python_type_str);
162+ return nullptr ;
163+ }
164+ return (PyObject*)self.release ();
165+ }
166+ Py_RETURN_NONE;
167+ END_HANDLE_TH_ERRORS
168+ }
169+
170+ static Py_ssize_t THPStorage_length (THPStorage *self)
171+ {
172+ HANDLE_TH_ERRORS
173+ return self->cdata ->nbytes () / sizeof (uint8_t );
174+ END_HANDLE_TH_ERRORS_RET (-1 )
175+ }
176+
177+ static PyObject * THPStorage_get (THPStorage *self, PyObject *index)
178+ {
179+ HANDLE_TH_ERRORS
180+ /* Integer index */
181+ if (THPUtils_checkLong (index)) {
182+ int64_t nindex = THPUtils_unpackLong (index);
183+ if (nindex < 0 )
184+ nindex += (self->cdata ->nbytes () / sizeof (uint8_t ));
185+ if (nindex < 0 || nindex >= static_cast <int64_t >(self->cdata ->nbytes () / sizeof (uint8_t ))) {
186+ PyErr_SetString (PyExc_IndexError, fmt::format (
187+ " index {} out of range for storage of size {}" ,
188+ nindex, self->cdata ->nbytes () / sizeof (uint8_t )));
189+ return nullptr ;
190+ }
191+ uint8_t value = storage_get (at::unsafeStorageFromTH (self->cdata , /* retain=*/ true ), nindex);
192+ return THPByteUtils_newReal (value);
193+ /* Slice index */
194+ } else if (PySlice_Check (index)) {
195+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
196+ Py_ssize_t start, stop, slicelength, step;
197+ int64_t len = self->cdata ->nbytes () / sizeof (uint8_t );
198+ if (!THPUtils_parseSlice (index, len, &start, &stop, &step, &slicelength))
199+ return nullptr ;
200+ if (step != 1 ) {
201+ THPUtils_setError (" Trying to slice with a step of %lld, but only a step of "
202+ " 1 is supported" , (long long )step);
203+ return nullptr ;
204+ }
205+
206+ uint8_t *data = self->cdata ->data <uint8_t >();
207+
208+ at::StorageImpl* old_storage = self->cdata ;
209+ c10::raw::intrusive_ptr::incref (old_storage);
210+ auto new_storage = c10::make_intrusive<at::StorageImpl>(
211+ c10::StorageImpl::use_byte_size_t (),
212+ #ifdef THQUANTIZED
213+ slicelength * sizeof (quantized_t ),
214+ #else
215+ slicelength * sizeof (uint8_t ),
216+ #endif
217+ at::DataPtr (
218+ static_cast <void *>(data + start),
219+ old_storage,
220+ [](void * s) {
221+ c10::raw::intrusive_ptr::decref (static_cast <at::StorageImpl*>(s));
222+ },
223+ old_storage->device ()),
224+ old_storage->allocator (),
225+ /* resizable */ false );
226+
227+ PyObject *_ret = THPStorage_New (std::move (new_storage));
228+ return _ret;
229+ }
230+ PyErr_Format (PyExc_TypeError, " can't index a " THPStorageStr " with %s" ,
231+ THPUtils_typename (index));
232+ return nullptr ;
233+ END_HANDLE_TH_ERRORS
234+ }
235+
236+ static int THPStorage_set (THPStorage *self, PyObject *index, PyObject *value)
237+ {
238+ HANDLE_TH_ERRORS
239+ if (!THPByteUtils_checkReal (value)) {
240+ THPUtils_setError (" can only set storage content with a %s, but got "
241+ " %s instead" , THPUtils_typeTraits<uint8_t >::python_type_str,
242+ THPUtils_typename (value));
243+ return -1 ;
244+ }
245+
246+ uint8_t rvalue = THPByteUtils_unpackReal (value);
247+ if (THPUtils_checkLong (index)) {
248+ int64_t nindex = THPUtils_unpackLong (index);
249+ storage_set (
250+ at::unsafeStorageFromTH (self->cdata , /* retain=*/ true ),
251+ nindex,
252+ rvalue);
253+ return 0 ;
254+ } else if (PySlice_Check (index)) {
255+ // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
256+ Py_ssize_t start, stop, slicelength, step;
257+ int64_t len = self->cdata ->nbytes () / sizeof (uint8_t );
258+ if (!THPUtils_parseSlice (index, len, &start, &stop, &step, &slicelength))
259+ return -1 ;
260+ if (step != 1 ) {
261+ THPUtils_setError (" Trying to slice with a step of %lld, but only a step of "
262+ " 1 is supported" , (long long )step);
263+ return 0 ;
264+ }
265+ // TODO: check the bounds only once
266+ // TODO: fill?
267+ for (;start < stop; start++)
268+ storage_set (
269+ at::unsafeStorageFromTH (self->cdata , /* retain=*/ true ),
270+ start,
271+ rvalue);
272+ return 0 ;
273+ }
274+ THPUtils_setError (" can't index a " THPStorageStr " with %s" ,
275+ THPUtils_typename (index));
276+ return -1 ;
277+ END_HANDLE_TH_ERRORS_RET (-1 )
278+ }
279+
280+ static PyMappingMethods THPStorage_mappingmethods = {
281+ (lenfunc)THPStorage_length,
282+ (binaryfunc)THPStorage_get,
283+ (objobjargproc)THPStorage_set
284+ };
285+
286+ // TODO: implement equality
287+ PyTypeObject THPStorageType = {
288+ PyVarObject_HEAD_INIT (nullptr , 0 )
289+ " torch._C." THPStorageBaseStr, /* tp_name */
290+ sizeof (THPStorage), /* tp_basicsize */
291+ 0 , /* tp_itemsize */
292+ (destructor)THPStorage_dealloc, /* tp_dealloc */
293+ 0 , /* tp_vectorcall_offset */
294+ nullptr , /* tp_getattr */
295+ nullptr , /* tp_setattr */
296+ nullptr , /* tp_reserved */
297+ nullptr , /* tp_repr */
298+ nullptr , /* tp_as_number */
299+ nullptr , /* tp_as_sequence */
300+ &THPStorage_mappingmethods, /* tp_as_mapping */
301+ nullptr , /* tp_hash */
302+ nullptr , /* tp_call */
303+ nullptr , /* tp_str */
304+ nullptr , /* tp_getattro */
305+ nullptr , /* tp_setattro */
306+ nullptr , /* tp_as_buffer */
307+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
308+ nullptr , /* tp_doc */
309+ nullptr , /* tp_traverse */
310+ nullptr , /* tp_clear */
311+ nullptr , /* tp_richcompare */
312+ 0 , /* tp_weaklistoffset */
313+ nullptr , /* tp_iter */
314+ nullptr , /* tp_iternext */
315+ nullptr , /* will be assigned in init */ /* tp_methods */
316+ nullptr , /* will be assigned in init */ /* tp_members */
317+ nullptr , /* tp_getset */
318+ nullptr , /* tp_base */
319+ nullptr , /* tp_dict */
320+ nullptr , /* tp_descr_get */
321+ nullptr , /* tp_descr_set */
322+ 0 , /* tp_dictoffset */
323+ nullptr , /* tp_init */
324+ nullptr , /* tp_alloc */
325+ THPStorage_pynew, /* tp_new */
326+ };
327+
328+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
329+ static struct PyMemberDef THPStorage_members[] = {
330+ {(char *)" _cdata" , T_ULONGLONG, offsetof (THPStorage, cdata), READONLY, nullptr },
331+ {nullptr }
332+ };
333+
334+ static PyObject * THPStorage_device (THPStorage* self, void *unused) {
335+ HANDLE_TH_ERRORS
336+ return THPDevice_New (self->cdata ->device ());
337+ END_HANDLE_TH_ERRORS
338+ }
339+
340+ static PyObject * THPStorage_dtype (THPStorage *self, void *unused)
341+ {
342+ HANDLE_TH_ERRORS
343+ return torch::autograd::utils::wrap (
344+ torch::getTHPDtype (at::typeMetaToScalarType (
345+ #ifdef THQUANTIZED
346+ caffe2::TypeMeta::Make<quantized_t >()
347+ #else
348+ caffe2::TypeMeta::Make<uint8_t >()
349+ #endif
350+ )));
351+ END_HANDLE_TH_ERRORS
352+ }
353+
354+ typedef PyObject *(*getter)(PyObject *, void *);
355+
356+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
357+ static struct PyGetSetDef THPStorage_properties[] = {
358+ {" device" , (getter)THPStorage_device, nullptr , nullptr , nullptr },
359+ {nullptr }
360+ };
361+
362+ bool THPStorage_init (PyObject *module )
363+ {
364+ static std::vector<PyMethodDef> methods;
365+ THPUtils_addPyMethodDefs (methods, THPStorage_getMethods ());
366+ THPUtils_addPyMethodDefs (methods, THPStorage_getSharingMethods ());
367+
368+ THPStorageType.tp_methods = methods.data ();
369+ THPStorageType.tp_members = THPStorage_members;
370+ THPStorageType.tp_getset = THPStorage_properties;
371+ if (PyType_Ready (&THPStorageType) < 0 )
372+ return false ;
373+ Py_INCREF (&THPStorageType);
374+ PyModule_AddObject (module , THPStorageBaseStr, (PyObject *)&THPStorageType);
375+ return true ;
376+ }
377+
378+ void THPStorage_postInit (PyObject *module )
379+ {
380+ THPStorageClass = PyObject_GetAttrString (module , " _UntypedStorage" );
381+ if (!THPStorageClass) throw python_error ();
382+ }
0 commit comments