Skip to content

Commit 272193d

Browse files
kurtamohlerpytorchmergebot
authored andcommitted
Move THPStorage definitions out of torch/csrc/generic (pytorch#78032)
Fixes pytorch#77908 Pull Request resolved: pytorch#78032 Approved by: https://github.com/ezyang
1 parent 6a4997e commit 272193d

15 files changed

+724
-706
lines changed

tools/build_variables.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,8 @@ libtorch_python_core_sources = [
843843
"torch/csrc/python_dimname.cpp",
844844
"torch/csrc/Size.cpp",
845845
"torch/csrc/Storage.cpp",
846+
"torch/csrc/StorageMethods.cpp",
847+
"torch/csrc/StorageSharing.cpp",
846848
"torch/csrc/Stream.cpp",
847849
"torch/csrc/TypeInfo.cpp",
848850
"torch/csrc/api/src/python/init.cpp",

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ class ThroughputBenchmark(object):
864864
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
865865
def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...
866866

867-
# Defined in torch/csrc/generic/Storage.cpp
867+
# Defined in torch/csrc/Storage.cpp
868868
${legacy_storage_base_hints}
869869

870870
# TODO: where

torch/csrc/Storage.cpp

Lines changed: 358 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@
1111
#include <torch/csrc/CudaIPCTypes.h>
1212
#include <torch/csrc/Device.h>
1313
#include <torch/csrc/autograd/utils/wrap_outputs.h>
14+
#include <torch/csrc/utils/python_arg_parser.h>
15+
#include <torch/csrc/StorageMethods.h>
16+
#include <torch/csrc/StorageSharing.h>
1417
#include <c10/core/CPUAllocator.h>
1518

1619
#include <fmt/format.h>
17-
18-
// NOLINTNEXTLINE(bugprone-suspicious-include)
19-
#include <torch/csrc/generic/Storage.cpp>
20-
#include <torch/csrc/THGenerateByteType.h>
21-
2220
#include <c10/util/intrusive_ptr.h>
2321

2422
template<>
@@ -27,3 +25,358 @@ void THPPointer<c10::StorageImpl>::free() {
2725
c10::raw::intrusive_ptr::decref(ptr);
2826
}
2927
}
28+
29+
PyObject *THPStorageClass = nullptr;
30+
31+
PyObject * THPStorage_New(c10::intrusive_ptr<c10::StorageImpl> ptr)
32+
{
33+
AT_ASSERT(ptr);
34+
PyTypeObject *type = (PyTypeObject *)THPStorageClass;
35+
PyObject *obj = type->tp_alloc(type, 0);
36+
if (obj) {
37+
((THPStorage *)obj)->cdata = ptr.release();
38+
}
39+
return obj;
40+
}
41+
42+
static void THPStorage_dealloc(THPStorage* self)
43+
{
44+
if (self->cdata) {
45+
c10::raw::intrusive_ptr::decref(self->cdata);
46+
}
47+
Py_TYPE(self)->tp_free((PyObject*)self);
48+
}
49+
50+
static PyObject * THPStorage_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
51+
{
52+
HANDLE_TH_ERRORS
53+
54+
static torch::PythonArgParser parser({
55+
THPStorageStr "(*, int64_t allocator=None, Device device=None)",
56+
THPStorageStr "(int64_t size, *, int64_t allocator=None, Device device=None)",
57+
THPStorageStr "(PyObject* sequence, *, int64_t allocator=None, Device device=None)",
58+
});
59+
torch::ParsedArgs<3> parsed_args;
60+
auto r = parser.parse(args, kwargs, parsed_args);
61+
62+
int64_t allocator_arg_idx = 0;
63+
int64_t device_arg_idx = 1;
64+
65+
if (r.idx > 0) {
66+
allocator_arg_idx = 1;
67+
device_arg_idx = 2;
68+
}
69+
70+
c10::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx);
71+
c10::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx);
72+
73+
TORCH_CHECK(!allocator_opt.has_value() || !device_opt.has_value(),
74+
THPStorageStr, "(): only one or neither of 'allocator' or 'device' can ",
75+
"be given, but not both");
76+
77+
THPStoragePtr self((THPStorage *)type->tp_alloc(type, 0));
78+
THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
79+
c10::Allocator* allocator = nullptr;
80+
at::OptionalDeviceGuard device_guard;
81+
82+
if (allocator_opt.has_value()) {
83+
allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
84+
} else if (device_opt.has_value()) {
85+
at::Device device = device_opt.value();
86+
if (device.type() == at::kCPU) {
87+
allocator = c10::GetDefaultCPUAllocator();
88+
#ifdef USE_CUDA
89+
} else if (device.type() == at::kCUDA) {
90+
at::globalContext().lazyInitCUDA();
91+
allocator = c10::cuda::CUDACachingAllocator::get();
92+
#endif
93+
} else if (device.type() == at::DeviceType::Meta) {
94+
allocator = c10::GetAllocator(device.type());
95+
} else {
96+
TORCH_CHECK(false,
97+
THPStorageStr, "(): Storage device not recognized: ", device.type());
98+
}
99+
device_guard.reset_device(device);
100+
} else {
101+
allocator = c10::GetDefaultCPUAllocator();
102+
}
103+
104+
// torch.Storage(*, ...)
105+
if (r.idx == 0) {
106+
self->cdata = c10::make_intrusive<at::StorageImpl>(
107+
c10::StorageImpl::use_byte_size_t(),
108+
0,
109+
allocator,
110+
/*resizable=*/true).release();
111+
return (PyObject*)self.release();
112+
113+
// torch.Storage(size, *, ...)
114+
} else if (r.idx == 1) {
115+
int64_t size = r.toInt64(0);
116+
self->cdata = c10::make_intrusive<at::StorageImpl>(
117+
c10::StorageImpl::use_byte_size_t(),
118+
size,
119+
allocator,
120+
/*resizable=*/true).release();
121+
return (PyObject*)self.release();
122+
123+
// torch.Storage(sequence, *, ...)
124+
} else if (r.idx == 2) {
125+
PyObject *sequence = r.pyobject(0);
126+
Py_ssize_t length = PySequence_Length(sequence);
127+
TORCH_CHECK(PySequence_Check(sequence),
128+
THPStorageStr, "(): Expected a sequence type, but got ",
129+
THPUtils_typename(sequence));
130+
TORCH_CHECK(length >= 0,
131+
THPStorageStr, "(): Could not obtain the length of sequence of type ",
132+
THPUtils_typename(sequence));
133+
self->cdata = c10::make_intrusive<at::StorageImpl>(
134+
c10::StorageImpl::use_byte_size_t(),
135+
length,
136+
allocator,
137+
/*resizable=*/true)
138+
.release();
139+
THPObjectPtr item;
140+
try {
141+
for (Py_ssize_t i = 0; i < length; i++) {
142+
item = PySequence_GetItem(sequence, i);
143+
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
144+
uint8_t value = THPByteUtils_unpackReal(item.get());
145+
if (allocator == c10::GetDefaultCPUAllocator()) {
146+
self->cdata->unsafe_data<uint8_t>()[i] = value;
147+
} else {
148+
// TODO: this might be slow - consider batched updates?
149+
storage_set(
150+
at::unsafeStorageFromTH(self->cdata, /*retain=*/true),
151+
i,
152+
value);
153+
}
154+
}
155+
} catch (const std::exception &e) {
156+
THPUtils_setError(THPStorageStr
157+
"(): tried to construct a storage from a sequence (%s), "
158+
"but one of the items was of type %s instead of %s",
159+
THPUtils_typename(sequence),
160+
THPUtils_typename(item.get()),
161+
THPUtils_typeTraits<uint8_t>::python_type_str);
162+
return nullptr;
163+
}
164+
return (PyObject*)self.release();
165+
}
166+
Py_RETURN_NONE;
167+
END_HANDLE_TH_ERRORS
168+
}
169+
170+
static Py_ssize_t THPStorage_length(THPStorage *self)
171+
{
172+
HANDLE_TH_ERRORS
173+
return self->cdata->nbytes() / sizeof(uint8_t);
174+
END_HANDLE_TH_ERRORS_RET(-1)
175+
}
176+
177+
static PyObject * THPStorage_get(THPStorage *self, PyObject *index)
178+
{
179+
HANDLE_TH_ERRORS
180+
/* Integer index */
181+
if (THPUtils_checkLong(index)) {
182+
int64_t nindex = THPUtils_unpackLong(index);
183+
if (nindex < 0)
184+
nindex += (self->cdata->nbytes() / sizeof(uint8_t));
185+
if (nindex < 0 || nindex >= static_cast<int64_t>(self->cdata->nbytes() / sizeof(uint8_t))) {
186+
PyErr_SetString(PyExc_IndexError, fmt::format(
187+
"index {} out of range for storage of size {}",
188+
nindex, self->cdata->nbytes() / sizeof(uint8_t)));
189+
return nullptr;
190+
}
191+
uint8_t value = storage_get(at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex);
192+
return THPByteUtils_newReal(value);
193+
/* Slice index */
194+
} else if (PySlice_Check(index)) {
195+
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
196+
Py_ssize_t start, stop, slicelength, step;
197+
int64_t len = self->cdata->nbytes() / sizeof(uint8_t);
198+
if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
199+
return nullptr;
200+
if (step != 1) {
201+
THPUtils_setError("Trying to slice with a step of %lld, but only a step of "
202+
"1 is supported", (long long)step);
203+
return nullptr;
204+
}
205+
206+
uint8_t *data = self->cdata->data<uint8_t>();
207+
208+
at::StorageImpl* old_storage = self->cdata;
209+
c10::raw::intrusive_ptr::incref(old_storage);
210+
auto new_storage = c10::make_intrusive<at::StorageImpl>(
211+
c10::StorageImpl::use_byte_size_t(),
212+
#ifdef THQUANTIZED
213+
slicelength * sizeof(quantized_t),
214+
#else
215+
slicelength * sizeof(uint8_t),
216+
#endif
217+
at::DataPtr(
218+
static_cast<void*>(data + start),
219+
old_storage,
220+
[](void* s) {
221+
c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
222+
},
223+
old_storage->device()),
224+
old_storage->allocator(),
225+
/* resizable */ false);
226+
227+
PyObject *_ret = THPStorage_New(std::move(new_storage));
228+
return _ret;
229+
}
230+
PyErr_Format(PyExc_TypeError, "can't index a " THPStorageStr " with %s",
231+
THPUtils_typename(index));
232+
return nullptr;
233+
END_HANDLE_TH_ERRORS
234+
}
235+
236+
static int THPStorage_set(THPStorage *self, PyObject *index, PyObject *value)
237+
{
238+
HANDLE_TH_ERRORS
239+
if (!THPByteUtils_checkReal(value)) {
240+
THPUtils_setError("can only set storage content with a %s, but got "
241+
"%s instead", THPUtils_typeTraits<uint8_t>::python_type_str,
242+
THPUtils_typename(value));
243+
return -1;
244+
}
245+
246+
uint8_t rvalue = THPByteUtils_unpackReal(value);
247+
if (THPUtils_checkLong(index)) {
248+
int64_t nindex = THPUtils_unpackLong(index);
249+
storage_set(
250+
at::unsafeStorageFromTH(self->cdata, /*retain=*/true),
251+
nindex,
252+
rvalue);
253+
return 0;
254+
} else if (PySlice_Check(index)) {
255+
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
256+
Py_ssize_t start, stop, slicelength, step;
257+
int64_t len = self->cdata->nbytes() / sizeof(uint8_t);
258+
if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
259+
return -1;
260+
if (step != 1) {
261+
THPUtils_setError("Trying to slice with a step of %lld, but only a step of "
262+
"1 is supported", (long long)step);
263+
return 0;
264+
}
265+
// TODO: check the bounds only once
266+
// TODO: fill?
267+
for (;start < stop; start++)
268+
storage_set(
269+
at::unsafeStorageFromTH(self->cdata, /*retain=*/true),
270+
start,
271+
rvalue);
272+
return 0;
273+
}
274+
THPUtils_setError("can't index a " THPStorageStr " with %s",
275+
THPUtils_typename(index));
276+
return -1;
277+
END_HANDLE_TH_ERRORS_RET(-1)
278+
}
279+
280+
static PyMappingMethods THPStorage_mappingmethods = {
281+
(lenfunc)THPStorage_length,
282+
(binaryfunc)THPStorage_get,
283+
(objobjargproc)THPStorage_set
284+
};
285+
286+
// TODO: implement equality
287+
PyTypeObject THPStorageType = {
288+
PyVarObject_HEAD_INIT(nullptr, 0)
289+
"torch._C." THPStorageBaseStr, /* tp_name */
290+
sizeof(THPStorage), /* tp_basicsize */
291+
0, /* tp_itemsize */
292+
(destructor)THPStorage_dealloc, /* tp_dealloc */
293+
0, /* tp_vectorcall_offset */
294+
nullptr, /* tp_getattr */
295+
nullptr, /* tp_setattr */
296+
nullptr, /* tp_reserved */
297+
nullptr, /* tp_repr */
298+
nullptr, /* tp_as_number */
299+
nullptr, /* tp_as_sequence */
300+
&THPStorage_mappingmethods, /* tp_as_mapping */
301+
nullptr, /* tp_hash */
302+
nullptr, /* tp_call */
303+
nullptr, /* tp_str */
304+
nullptr, /* tp_getattro */
305+
nullptr, /* tp_setattro */
306+
nullptr, /* tp_as_buffer */
307+
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
308+
nullptr, /* tp_doc */
309+
nullptr, /* tp_traverse */
310+
nullptr, /* tp_clear */
311+
nullptr, /* tp_richcompare */
312+
0, /* tp_weaklistoffset */
313+
nullptr, /* tp_iter */
314+
nullptr, /* tp_iternext */
315+
nullptr, /* will be assigned in init */ /* tp_methods */
316+
nullptr, /* will be assigned in init */ /* tp_members */
317+
nullptr, /* tp_getset */
318+
nullptr, /* tp_base */
319+
nullptr, /* tp_dict */
320+
nullptr, /* tp_descr_get */
321+
nullptr, /* tp_descr_set */
322+
0, /* tp_dictoffset */
323+
nullptr, /* tp_init */
324+
nullptr, /* tp_alloc */
325+
THPStorage_pynew, /* tp_new */
326+
};
327+
328+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
329+
static struct PyMemberDef THPStorage_members[] = {
330+
{(char*)"_cdata", T_ULONGLONG, offsetof(THPStorage, cdata), READONLY, nullptr},
331+
{nullptr}
332+
};
333+
334+
static PyObject * THPStorage_device(THPStorage* self, void *unused) {
335+
HANDLE_TH_ERRORS
336+
return THPDevice_New(self->cdata->device());
337+
END_HANDLE_TH_ERRORS
338+
}
339+
340+
static PyObject * THPStorage_dtype(THPStorage *self, void *unused)
341+
{
342+
HANDLE_TH_ERRORS
343+
return torch::autograd::utils::wrap(
344+
torch::getTHPDtype(at::typeMetaToScalarType(
345+
#ifdef THQUANTIZED
346+
caffe2::TypeMeta::Make<quantized_t>()
347+
#else
348+
caffe2::TypeMeta::Make<uint8_t>()
349+
#endif
350+
)));
351+
END_HANDLE_TH_ERRORS
352+
}
353+
354+
typedef PyObject *(*getter)(PyObject *, void *);
355+
356+
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
357+
static struct PyGetSetDef THPStorage_properties[] = {
358+
{"device", (getter)THPStorage_device, nullptr, nullptr, nullptr},
359+
{nullptr}
360+
};
361+
362+
bool THPStorage_init(PyObject *module)
363+
{
364+
static std::vector<PyMethodDef> methods;
365+
THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
366+
THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());
367+
368+
THPStorageType.tp_methods = methods.data();
369+
THPStorageType.tp_members = THPStorage_members;
370+
THPStorageType.tp_getset = THPStorage_properties;
371+
if (PyType_Ready(&THPStorageType) < 0)
372+
return false;
373+
Py_INCREF(&THPStorageType);
374+
PyModule_AddObject(module, THPStorageBaseStr, (PyObject *)&THPStorageType);
375+
return true;
376+
}
377+
378+
void THPStorage_postInit(PyObject *module)
379+
{
380+
THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage");
381+
if (!THPStorageClass) throw python_error();
382+
}

0 commit comments

Comments
 (0)