Skip to content

Commit 49a6f22

Browse files
committed
Fix buffer protocol implementation
According to the buffer protocol, `ndim` is a _required_ field [1], and should always be set correctly. Additionally, `shape` should be set if flags includes `PyBUF_ND` or higher [2]. The current implementation only set those fields if flags was `PyBUF_STRIDES`. [1] https://docs.python.org/3/c-api/buffer.html#request-independent-fields [2] https://docs.python.org/3/c-api/buffer.html#shape-strides-suboffsets
1 parent af67e87 commit 49a6f22

File tree

3 files changed

+95
-3
lines changed

3 files changed

+95
-3
lines changed

include/pybind11/detail/class.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -602,9 +602,9 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
602602
return -1;
603603
}
604604
view->obj = obj;
605-
view->ndim = 1;
606605
view->internal = info;
607606
view->buf = info->ptr;
607+
view->ndim = (int) info->ndim;
608608
view->itemsize = info->itemsize;
609609
view->len = view->itemsize;
610610
for (auto s : info->shape) {
@@ -614,10 +614,11 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
614614
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
615615
view->format = const_cast<char *>(info->format.c_str());
616616
}
617+
if ((flags & PyBUF_ND) == PyBUF_ND) {
618+
view->shape = info->shape.data();
619+
}
617620
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
618-
view->ndim = (int) info->ndim;
619621
view->strides = info->strides.data();
620-
view->shape = info->shape.data();
621622
}
622623
Py_INCREF(view->obj);
623624
return 0;

tests/test_buffers.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,4 +268,58 @@ TEST_SUBMODULE(buffers, m) {
268268
});
269269

270270
m.def("get_buffer_info", [](const py::buffer &buffer) { return buffer.request(); });
271+
272+
// Expose Py_buffer for testing.
273+
py::class_<Py_buffer>(m, "Py_buffer")
274+
.def_readonly("len", &Py_buffer::len)
275+
.def_readonly("readonly", &Py_buffer::readonly)
276+
.def_readonly("itemsize", &Py_buffer::itemsize)
277+
.def_readonly("format", &Py_buffer::format)
278+
.def_readonly("ndim", &Py_buffer::ndim)
279+
.def_property_readonly("shape",
280+
[](const Py_buffer &buffer) -> py::object {
281+
if (buffer.shape == nullptr) {
282+
return py::none();
283+
}
284+
py::list l;
285+
for (auto i = 0; i < buffer.ndim; i++) {
286+
l.append(buffer.shape[i]);
287+
}
288+
return l;
289+
})
290+
.def_property_readonly("strides",
291+
[](const Py_buffer &buffer) -> py::object {
292+
if (buffer.strides == nullptr) {
293+
return py::none();
294+
}
295+
py::list l;
296+
for (auto i = 0; i < buffer.ndim; i++) {
297+
l.append(buffer.strides[i]);
298+
}
299+
return l;
300+
})
301+
.def_property_readonly("suboffsets", [](const Py_buffer &buffer) -> py::object {
302+
if (buffer.suboffsets == nullptr) {
303+
return py::none();
304+
}
305+
py::list l;
306+
for (auto i = 0; i < buffer.ndim; i++) {
307+
l.append(buffer.suboffsets[i]);
308+
}
309+
return l;
310+
});
311+
m.attr("PyBUF_SIMPLE") = PyBUF_SIMPLE;
312+
m.attr("PyBUF_ND") = PyBUF_ND;
313+
m.attr("PyBUF_STRIDES") = PyBUF_STRIDES;
314+
m.attr("PyBUF_INDIRECT") = PyBUF_INDIRECT;
315+
316+
m.def("get_py_buffer", [](const py::object &object, int flags) {
317+
Py_buffer buffer;
318+
memset(&buffer, 0, sizeof(Py_buffer));
319+
if (PyObject_GetBuffer(object.ptr(), &buffer, flags) == -1) {
320+
throw py::error_already_set();
321+
}
322+
// TODO: This leaks...
323+
return buffer;
324+
});
271325
}

tests/test_buffers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,40 @@ def test_buffer_exception():
239239
memoryview(m.BrokenMatrix(1, 1))
240240
assert isinstance(excinfo.value.__cause__, RuntimeError)
241241
assert "for context" in str(excinfo.value.__cause__)
242+
243+
244+
def test_to_pybuffer():
245+
mat = m.Matrix(5, 4)
246+
247+
info = m.get_py_buffer(mat, m.PyBUF_SIMPLE)
248+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
249+
assert info.len == mat.rows() * mat.cols() * info.itemsize
250+
assert info.ndim == 2
251+
assert info.shape is None
252+
assert info.strides is None
253+
assert info.suboffsets is None
254+
assert not info.readonly
255+
info = m.get_py_buffer(mat, m.PyBUF_ND)
256+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
257+
assert info.len == mat.rows() * mat.cols() * info.itemsize
258+
assert info.ndim == 2
259+
assert info.shape == [5, 4]
260+
assert info.strides is None
261+
assert info.suboffsets is None
262+
assert not info.readonly
263+
info = m.get_py_buffer(mat, m.PyBUF_STRIDES)
264+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
265+
assert info.len == mat.rows() * mat.cols() * info.itemsize
266+
assert info.ndim == 2
267+
assert info.shape == [5, 4]
268+
assert info.strides == [4 * info.itemsize, info.itemsize]
269+
assert info.suboffsets is None
270+
assert not info.readonly
271+
info = m.get_py_buffer(mat, m.PyBUF_INDIRECT)
272+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
273+
assert info.len == mat.rows() * mat.cols() * info.itemsize
274+
assert info.ndim == 2
275+
assert info.shape == [5, 4]
276+
assert info.strides == [4 * info.itemsize, info.itemsize]
277+
assert info.suboffsets is None # Should be filled in here, but we don't use it.
278+
assert not info.readonly

0 commit comments

Comments
 (0)