Skip to content

Commit 9f36b54

Browse files
committed
Fix buffer protocol implementation
According to the buffer protocol, `ndim` is a _required_ field [1], and should always be set correctly. Additionally, `shape` should be set if flags includes `PyBUF_ND` or higher [2]. The current implementation only set those fields if flags was `PyBUF_STRIDES`. [1] https://docs.python.org/3/c-api/buffer.html#request-independent-fields [2] https://docs.python.org/3/c-api/buffer.html#shape-strides-suboffsets
1 parent f7e14e9 commit 9f36b54

File tree

3 files changed

+89
-3
lines changed

3 files changed

+89
-3
lines changed

Diff for: include/pybind11/detail/class.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -602,9 +602,9 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
602602
return -1;
603603
}
604604
view->obj = obj;
605-
view->ndim = 1;
606605
view->internal = info;
607606
view->buf = info->ptr;
607+
view->ndim = (int) info->ndim;
608608
view->itemsize = info->itemsize;
609609
view->len = view->itemsize;
610610
for (auto s : info->shape) {
@@ -614,10 +614,11 @@ extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int fla
614614
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
615615
view->format = const_cast<char *>(info->format.c_str());
616616
}
617+
if ((flags & PyBUF_ND) == PyBUF_ND) {
618+
view->shape = info->shape.data();
619+
}
617620
if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
618-
view->ndim = (int) info->ndim;
619621
view->strides = info->strides.data();
620-
view->shape = info->shape.data();
621622
}
622623
Py_INCREF(view->obj);
623624
return 0;

Diff for: tests/test_buffers.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -268,4 +268,52 @@ TEST_SUBMODULE(buffers, m) {
268268
});
269269

270270
m.def("get_buffer_info", [](const py::buffer &buffer) { return buffer.request(); });
271+
272+
// Expose Py_buffer for testing.
273+
m.attr("PyBUF_SIMPLE") = PyBUF_SIMPLE;
274+
m.attr("PyBUF_ND") = PyBUF_ND;
275+
m.attr("PyBUF_STRIDES") = PyBUF_STRIDES;
276+
m.attr("PyBUF_INDIRECT") = PyBUF_INDIRECT;
277+
278+
m.def("get_py_buffer", [](const py::object &object, int flags) {
279+
Py_buffer buffer;
280+
memset(&buffer, 0, sizeof(Py_buffer));
281+
if (PyObject_GetBuffer(object.ptr(), &buffer, flags) == -1) {
282+
throw py::error_already_set();
283+
}
284+
285+
auto SimpleNamespace = py::module_::import("types").attr("SimpleNamespace");
286+
py::object result = SimpleNamespace("len"_a = buffer.len,
287+
"readonly"_a = buffer.readonly,
288+
"itemsize"_a = buffer.itemsize,
289+
"format"_a = buffer.format,
290+
"ndim"_a = buffer.ndim,
291+
"shape"_a = py::none(),
292+
"strides"_a = py::none(),
293+
"suboffsets"_a = py::none());
294+
if (buffer.shape != nullptr) {
295+
py::list l;
296+
for (auto i = 0; i < buffer.ndim; i++) {
297+
l.append(buffer.shape[i]);
298+
}
299+
py::setattr(result, "shape", l);
300+
}
301+
if (buffer.strides != nullptr) {
302+
py::list l;
303+
for (auto i = 0; i < buffer.ndim; i++) {
304+
l.append(buffer.strides[i]);
305+
}
306+
py::setattr(result, "strides", l);
307+
}
308+
if (buffer.suboffsets != nullptr) {
309+
py::list l;
310+
for (auto i = 0; i < buffer.ndim; i++) {
311+
l.append(buffer.suboffsets[i]);
312+
}
313+
py::setattr(result, "suboffsets", l);
314+
}
315+
316+
PyBuffer_Release(&buffer);
317+
return result;
318+
});
271319
}

Diff for: tests/test_buffers.py

+37
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,40 @@ def test_buffer_exception():
239239
memoryview(m.BrokenMatrix(1, 1))
240240
assert isinstance(excinfo.value.__cause__, RuntimeError)
241241
assert "for context" in str(excinfo.value.__cause__)
242+
243+
244+
def test_to_pybuffer():
245+
mat = m.Matrix(5, 4)
246+
247+
info = m.get_py_buffer(mat, m.PyBUF_SIMPLE)
248+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
249+
assert info.len == mat.rows() * mat.cols() * info.itemsize
250+
assert info.ndim == 2
251+
assert info.shape is None
252+
assert info.strides is None
253+
assert info.suboffsets is None
254+
assert not info.readonly
255+
info = m.get_py_buffer(mat, m.PyBUF_ND)
256+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
257+
assert info.len == mat.rows() * mat.cols() * info.itemsize
258+
assert info.ndim == 2
259+
assert info.shape == [5, 4]
260+
assert info.strides is None
261+
assert info.suboffsets is None
262+
assert not info.readonly
263+
info = m.get_py_buffer(mat, m.PyBUF_STRIDES)
264+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
265+
assert info.len == mat.rows() * mat.cols() * info.itemsize
266+
assert info.ndim == 2
267+
assert info.shape == [5, 4]
268+
assert info.strides == [4 * info.itemsize, info.itemsize]
269+
assert info.suboffsets is None
270+
assert not info.readonly
271+
info = m.get_py_buffer(mat, m.PyBUF_INDIRECT)
272+
assert info.itemsize == ctypes.sizeof(ctypes.c_float)
273+
assert info.len == mat.rows() * mat.cols() * info.itemsize
274+
assert info.ndim == 2
275+
assert info.shape == [5, 4]
276+
assert info.strides == [4 * info.itemsize, info.itemsize]
277+
assert info.suboffsets is None # Should be filled in here, but we don't use it.
278+
assert not info.readonly

0 commit comments

Comments
 (0)