Commit 0d6eb20

zou3519 authored and facebook-github-bot committed
Expose torch.empty(sizes, *, names, ...) to Python (pytorch#21648)
Summary: Pull Request resolved: pytorch#21648
ghimport-source-id: 583f155
Differential Revision: D15804482
Pulled By: zou3519
fbshipit-source-id: f86520dda479100be2a752e4db8a902167413a83
1 parent 7108218 commit 0d6eb20

11 files changed, +281 -4 lines changed
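As a quick orientation before the per-file diffs: based on the tests added in test/test_namedtensor.py below, the surface this commit exposes looks roughly like the following sketch. It assumes a build compiled with NAMEDTENSOR_ENABLED; the names keyword and the Tensor.names attribute are exactly what the new bindings wire up.

import torch

# Unnamed dimensions report None (a wildcard) for their name.
x = torch.empty(2, 3)
print(x.names)                                    # (None, None)

# names is a new keyword-only argument on the factory.
x = torch.empty(2, 3, names=('N', 'C'))
print(x.names)                                    # ('N', 'C')

# None can be mixed in to leave individual dimensions unnamed.
x = torch.empty(1, 2, 3, names=('N', None, 'D'))
print(x.names)                                    # ('N', None, 'D')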

test/test_namedtensor.py  (+89)

@@ -1,6 +1,8 @@
 import unittest
 from common_utils import TestCase, run_tests
+from common_cuda import TEST_CUDA
 import torch
+import sys


 def namedtensor_enabled():
@@ -10,11 +12,98 @@ def namedtensor_enabled():
     unittest.skipIf(not namedtensor_enabled(),
                     'PyTorch not compiled with namedtensor support')

+def pass_name_to_python_arg_parser(name):
+    x = torch.empty(2, names=(name,))
+
+
 class TestNamedTensor(TestCase):
     @skipIfNamedTensorDisabled
     def test_trivial(self):
         pass

+    def _test_factory(self, factory, device):
+        x = factory([], device=device)
+        self.assertEqual(x.names, ())
+
+        x = factory(1, 2, 3, device=device)
+        self.assertEqual(x.names, (None, None, None))
+
+        x = factory(1, 2, 3, names=None, device=device)
+        self.assertEqual(x.names, (None, None, None))
+
+        x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device)
+        self.assertEqual(x.names, ('N', 'T', 'D'))
+
+        x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
+        self.assertEqual(x.names, ('N', None, 'D'))
+
+        with self.assertRaisesRegex(RuntimeError,
+                                    'must contain alphabetical characters and/or underscore'):
+            x = factory(2, names=('?',), device=device)
+
+        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
+            x = factory(2, 1, names=('N',), device=device)
+
+        with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'):
+            x = factory(2, 1, names='N', device=device)
+
+
+    @skipIfNamedTensorDisabled
+    def test_empty(self):
+        self._test_factory(torch.empty, 'cpu')
+
+    @skipIfNamedTensorDisabled
+    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
+    def test_empty_cuda(self):
+        self._test_factory(torch.empty, 'cuda')
+
+    @skipIfNamedTensorDisabled
+    def test_using_seen_interned_string_doesnt_bump_refcount(self):
+        def see_name():
+            seen_name = 'N'
+            pass_name_to_python_arg_parser(seen_name)
+
+        see_name()
+        seen_name = 'N'
+        old_refcnt = sys.getrefcount(seen_name)
+
+        pass_name_to_python_arg_parser(seen_name)
+
+        new_refcnt = sys.getrefcount(seen_name)
+        self.assertEqual(new_refcnt, old_refcnt)
+
+    @skipIfNamedTensorDisabled
+    def test_using_unseen_interned_string_bumps_refcount_permanently(self):
+        # Please don't use this as a name in a different test.
+        unseen_name = 'abcdefghi'
+        old_refcnt = sys.getrefcount(unseen_name)
+
+        pass_name_to_python_arg_parser(unseen_name)
+
+        new_refcnt = sys.getrefcount(unseen_name)
+        self.assertEqual(new_refcnt, old_refcnt + 1)
+
+    @skipIfNamedTensorDisabled
+    def test_using_unseen_uninterned_string_refcounts(self):
+        # Please don't use this as a name in a different test.
+        # non-compile-time constants are not interned
+        unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl'])
+        interned_unseen_name = 'abcdefghijkl'
+        self.assertFalse(unseen_name is interned_unseen_name)
+
+        old_uninterned_refcnt = sys.getrefcount(unseen_name)
+        old_interned_refcnt = sys.getrefcount(interned_unseen_name)
+
+        pass_name_to_python_arg_parser(unseen_name)
+
+        new_uninterned_refcnt = sys.getrefcount(unseen_name)
+        new_interned_refcnt = sys.getrefcount(interned_unseen_name)
+
+        # Internally, PyTorch should not hold a reference to the uninterned string
+        self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)
+
+        # Instead, we should hold a new reference to the interned version.
+        self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)

 if __name__ == '__main__':
     run_tests()
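The three refcount tests above lean on CPython's string-interning rules rather than on anything PyTorch-specific. For readers unfamiliar with those rules, here is a standalone sketch (plain CPython, no torch import) of the behavior the tests rely on:

import sys

# Identifier-like string literals are interned at compile time, so every
# occurrence of 'N' in the source refers to one shared object.
a = 'N'
b = 'N'
assert a is b

# Strings built at runtime are not interned automatically.
built = ''.join(['abc', 'def'])
literal = 'abcdef'
assert built is not literal

# sys.intern returns the canonical interned object for a given value.
assert sys.intern(built) is literal

# sys.getrefcount reports the current number of references (plus one for the
# temporary reference created by the call itself); the tests compare this
# count before and after handing a name to the argument parser.
print(sys.getrefcount(literal))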

tools/autograd/gen_python_functions.py  (+16 -3)

@@ -42,7 +42,6 @@
     'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
     'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
     'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
-    'empty(IntArrayRef, DimnameList?, TensorOptions)',
 ]

 PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
@@ -163,7 +162,6 @@
     .pinned_memory(${pin_memory});
 """)

-
 def should_generate_python_binding(declaration):
     name = declaration['name']
     for pattern in SKIP_PYTHON_BINDINGS:
@@ -290,6 +288,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         'const Type &': 'scalartype',
         'const THPLayout &': 'layout',
         'const Device &': 'device',
+        'c10::optional<DimnameList>': 'toDimnameListOptional',
         'c10::optional<ScalarType>': 'scalartypeOptional',
         'c10::optional<Scalar>': 'scalarOptional',
         'c10::optional<int64_t>': 'toInt64Optional',
@@ -346,6 +345,19 @@ def get_type_args(args):
     if type_args and len(outputs) > 1:
         raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

+    def unpack_variable(name, unpack_expr, typename):
+        # optional<ArrayRef<T>> are special. The PythonArgParser returns an
+        # optional<vector<T>>, which cannot be implictly converted to
+        # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
+        if typename == 'c10::optional<DimnameList>':
+            result = """\
+auto __{name} = {expr};
+c10::optional<{typ}> {name} = __{name} ? c10::make_optional({typ}(__{name}.value())) : c10::nullopt;
+""".format(name=name, expr=unpack_expr, typ='DimnameList')
+            return [line.strip() for line in result.split('\n')]
+
+        return ['auto {} = {};'.format(name, unpack_expr)]
+
     def parse_arg(arg, arg_index, unpack_args=False):
         name = arg['name']
         typename = arg['type']
@@ -365,7 +377,7 @@ def parse_arg(arg, arg_index, unpack_args=False):
         expr = 'r.{}({})'.format(unpack, arg_index)

         if unpack_args:
-            body.append('auto {} = {};'.format(name, expr))
+            body.extend(unpack_variable(name, expr, typename))
             expr = name

         dispatch_type = typename
@@ -633,6 +645,7 @@ def get_python_binding_arguments(declaration):
             'simple_type': 'bool',
         }
         python_binding_arguments.append(requires_grad_arg)
+
     return python_binding_arguments

 def emit_namedtuple_return_type_def(declaration, next_index):
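For concreteness, here is a standalone, slightly simplified copy of the unpack_variable helper added above (reproduced outside its real nesting inside create_python_bindings, purely to show what it emits). The expression strings passed to it here, such as r.toDimnameListOptional(1), are illustrative placeholders of the kind registered in the unpack-method map earlier in this file:

def unpack_variable(name, unpack_expr, typename):
    # optional<ArrayRef<T>> is special: the parser hands back an
    # optional<std::vector<T>>, so the generated C++ unwraps the optional and
    # rewraps the vector as an ArrayRef view.
    if typename == 'c10::optional<DimnameList>':
        result = """\
auto __{name} = {expr};
c10::optional<{typ}> {name} = __{name} ? c10::make_optional({typ}(__{name}.value())) : c10::nullopt;
""".format(name=name, expr=unpack_expr, typ='DimnameList')
        return [line.strip() for line in result.split('\n')]
    return ['auto {} = {};'.format(name, unpack_expr)]

# For the optional DimnameList case this yields the unwrap/rewrap pair
# (plus a trailing empty string from the final split):
#   auto __names = r.toDimnameListOptional(1);
#   c10::optional<DimnameList> names = __names ? ... : c10::nullopt;
for line in unpack_variable('names', 'r.toDimnameListOptional(1)',
                            'c10::optional<DimnameList>'):
    print(line)

# Any other type falls back to the original single-line form:
print(unpack_variable('dim', 'r.toInt64(0)', 'int64_t'))
# ['auto dim = r.toInt64(0);']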

tools/build_variables.py  (+1)

@@ -170,6 +170,7 @@ def add_torch_libs():
         "torch/csrc/MemoryFormat.cpp",
         "torch/csrc/Module.cpp",
         "torch/csrc/PtrWrapper.cpp",
+        "torch/csrc/python_dimname.cpp",
         "torch/csrc/Size.cpp",
         "torch/csrc/Storage.cpp",
         "torch/csrc/TypeInfo.cpp",

tools/pyi/gen_pyi.py  (+1)

@@ -126,6 +126,7 @@ def type_to_python(typename, size=None):
         'void*': '_int',  # data_ptr
         'void': 'None',
         'std::string': 'str',
+        'DimnameList': 'List[Union[str, None]]',
     }[typename]

     return typename

torch/CMakeLists.txt  (+1)

@@ -51,6 +51,7 @@ set(TORCH_PYTHON_SRCS
     ${TORCH_SRC_DIR}/csrc/Generator.cpp
     ${TORCH_SRC_DIR}/csrc/Layout.cpp
     ${TORCH_SRC_DIR}/csrc/MemoryFormat.cpp
+    ${TORCH_SRC_DIR}/csrc/python_dimname.cpp
     ${TORCH_SRC_DIR}/csrc/Module.cpp
     ${TORCH_SRC_DIR}/csrc/PtrWrapper.cpp
     ${TORCH_SRC_DIR}/csrc/Size.cpp

torch/csrc/autograd/python_variable.cpp  (+34)

@@ -311,6 +311,37 @@ PyObject *THPVariable_get_ndim(THPVariable *self)
   END_HANDLE_TH_ERRORS
 }

+#ifdef NAMEDTENSOR_ENABLED
+PyObject *THPVariable_get_names(THPVariable *self)
+{
+  HANDLE_TH_ERRORS
+  // The long-term plan is to return a list of (python) torch.Dimname.
+  // However, for now, return a list of string.
+  size_t size = self->cdata.dim();
+  THPObjectPtr tuple(PyTuple_New(size));
+  if (!tuple) throw python_error();
+
+  if (!self->cdata.is_named()) {
+    for (size_t i = 0; i < size; ++i) {
+      PyTuple_SET_ITEM(tuple.get(), i, Py_None);
+    }
+    return tuple.release();
+  }
+
+  const auto dimnames = self->cdata.names().value();
+  for (size_t i = 0; i < size; ++i) {
+    PyObject* str = Py_None;
+    if (dimnames[i].type() != at::NameType::WILDCARD) {
+      str = THPUtils_packString(dimnames[i].name().toUnqualString());
+      if (!str) throw python_error();
+    }
+    PyTuple_SET_ITEM(tuple.get(), i, str);
+  }
+  return tuple.release();
+  END_HANDLE_TH_ERRORS
+}
+#endif
+
 int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj)
 {
   HANDLE_TH_ERRORS
@@ -452,6 +483,9 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
   {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
   {"ndim", (getter)THPVariable_get_ndim, nullptr, nullptr, nullptr},
+#ifdef NAMEDTENSOR_ENABLED
+  {"names", (getter)THPVariable_get_names, nullptr, nullptr, nullptr},
+#endif
   {nullptr}
 };

torch/csrc/python_dimname.cpp  (+83)

@@ -0,0 +1,83 @@
+#ifdef NAMEDTENSOR_ENABLED
+#include <torch/csrc/python_dimname.h>
+#include <torch/csrc/Exceptions.h>
+#include <torch/csrc/utils/python_strings.h>
+#include <c10/util/flat_hash_map.h>
+
+namespace torch {
+
+struct InternedStringsTable {
+  InternedStringsTable() = default;
+  ~InternedStringsTable();
+  InternedStringsTable(const InternedStringsTable &) = delete;
+  InternedStringsTable& operator =(InternedStringsTable const&) = delete;
+  InternedStringsTable(InternedStringsTable&&) = delete;
+  InternedStringsTable& operator=(InternedStringsTable&&) = delete;
+
+  at::optional<at::Dimname> lookup(PyObject* obj);
+  // Precondition: obj is an interned python string.
+  void addMapping(PyObject* obj, at::Dimname dimname);
+ private:
+  ska::flat_hash_map<PyObject*,at::Dimname> py_interned_string_to_dimname_;
+};
+
+InternedStringsTable kPyInternedStringToDimname;
+
+InternedStringsTable::~InternedStringsTable() {
+  for (auto it = py_interned_string_to_dimname_.begin();
+       it != py_interned_string_to_dimname_.end(); ++it) {
+    // See Note [References to python interned strings]
+    Py_DECREF(it->first);
+  }
+}
+
+at::optional<at::Dimname> InternedStringsTable::lookup(PyObject* obj) {
+  auto it = py_interned_string_to_dimname_.find(obj);
+  if (it == py_interned_string_to_dimname_.end()) {
+    return at::nullopt;
+  }
+  return it->second;
+}
+
+
+void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) {
+  // Note [References to python interned strings]
+  // If a Python interned string has no references to it, then it gets
+  // deallocated, invalidating this mapping. Let's immortalize the string by
+  // holding a refcount to it and releasing it in the destructor
+  Py_INCREF(obj);
+  py_interned_string_to_dimname_.emplace(obj, dimname);
+}
+
+} // namespace torch
+
+at::Dimname THPDimname_parse(PyObject* obj) {
+  if (obj == Py_None) {
+    return at::Dimname::wildcard();
+  }
+
+  if (!THPUtils_checkString(obj)) {
+    throw torch::TypeError("expected None or string for Dimname but got %s", Py_TYPE(obj)->tp_name);
+  }
+
+  if (!THPUtils_isInterned(obj)) {
+    // internStringInPlace decrefs obj and increfs the result. Because we're
+    // not actually returning the result to the user, we need to undo these.
+    // See https://docs.python.org/3/c-api/unicode.html#c.PyUnicode_InternInPlace
+    Py_INCREF(obj);
+    THPUtils_internStringInPlace(&obj);
+    Py_DECREF(obj);
+  }
+
+  auto maybeDimname = torch::kPyInternedStringToDimname.lookup(obj);
+  if (maybeDimname) {
+    return *maybeDimname;
+  }
+
+  const auto name = THPUtils_unpackString(obj);
+  auto dimname = at::Dimname::fromSymbol(at::Symbol::dimname(name));
+  torch::kPyInternedStringToDimname.addMapping(obj, dimname);
+  return dimname;
+}
+
+#endif
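The Note [References to python interned strings] above is exactly what the refcount tests in test_namedtensor.py observe from Python: the first time a given interned name reaches THPDimname_parse, the table takes one permanent reference to it; later uses hit the cache and take none. A hedged sketch of that observable behavior, assuming a NAMEDTENSOR_ENABLED build and a name literal (here 'H') that has not been parsed earlier in the process:

import sys
import torch

name = 'H'                      # identifier-like literal -> interned by CPython
before = sys.getrefcount(name)

torch.empty(2, names=(name,))   # first use: the table caches the mapping and
                                # holds one extra reference to the string
after_first = sys.getrefcount(name)

torch.empty(2, names=(name,))   # later uses are served from the cache
after_second = sys.getrefcount(name)

print(before, after_first, after_second)
# expected: after_first == before + 1 and after_second == after_first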

torch/csrc/python_dimname.h  (+8)

@@ -0,0 +1,8 @@
+#pragma once
+#ifdef NAMEDTENSOR_ENABLED
+#include <torch/csrc/python_headers.h>
+#include <ATen/Dimname.h>
+
+at::Dimname THPDimname_parse(PyObject* obj);
+
+#endif

torch/csrc/utils/python_arg_parser.cpp  (+6)

@@ -32,6 +32,8 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"MemoryFormat", ParameterType::MEMORY_FORMAT},
   {"Device", ParameterType::DEVICE},
   {"std::string", ParameterType::STRING},
+  {"Dimname", ParameterType::DIMNAME},
+  {"DimnameList", ParameterType::DIMNAME_LIST},
 };

 // Default arg name translations for compatibility with NumPy.
@@ -157,6 +159,7 @@ bool FunctionParameter::check(PyObject* obj) {
       }
       return false;
     }
+    case ParameterType::DIMNAME_LIST:
     case ParameterType::TENSOR_LIST: return six::isTuple(obj) || PyList_Check(obj);
     case ParameterType::INT_LIST: {
       if (PyTuple_Check(obj) || PyList_Check(obj)) {
@@ -196,6 +199,9 @@ std::string FunctionParameter::type_name() const {
     case ParameterType::MEMORY_FORMAT: return "torch.memory_format";
     case ParameterType::DEVICE: return "torch.device";
     case ParameterType::STRING: return "str";
+#ifdef NAMEDTENSOR_ENABLED
+    case ParameterType::DIMNAME_LIST: return "tuple of names";
+#endif
     default: throw std::runtime_error("unknown parameter type");
   }
 }
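One practical consequence of the parser changes above: DIMNAME_LIST shares the tuple-or-list check with TENSOR_LIST, and type_name() reports "tuple of names" in argument-mismatch messages. A small sketch of what that implies at the Python call site, again assuming a NAMEDTENSOR_ENABLED build:

import torch

# Both a tuple and a list pass FunctionParameter::check for DIMNAME_LIST.
a = torch.empty(2, 3, names=('N', 'C'))
b = torch.empty(2, 3, names=['N', 'C'])

# A bare string is neither, so overload resolution fails; the tests in
# test_namedtensor.py assert that the error mentions
# "invalid combination of arguments".
try:
    torch.empty(2, 3, names='N')
except TypeError as err:
    print(err)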
