From 0235533fdaace87738c782278b4c64cd5fe9e128 Mon Sep 17 00:00:00 2001 From: AlexanderMueller Date: Sat, 3 Aug 2024 10:06:13 +0200 Subject: [PATCH 1/3] add argument number dispatch mechanism for std::function casting --- include/pybind11/functional.h | 41 ++++++++++++++++++++++++++- tests/test_callbacks.cpp | 6 ++++ tests/test_callbacks.py | 19 ++++++++++++- tests/test_embed/test_interpreter.cpp | 16 +++++++++++ 4 files changed, 80 insertions(+), 2 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 4b3610117c..fde99a4ae2 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -101,8 +101,17 @@ struct type_caster> { if (detail::is_function_record_capsule(c)) { rec = c.get_pointer(); } - while (rec != nullptr) { + const int correctingSelfArgument = rec->is_method ? 1 : 0; + if (rec->nargs - correctingSelfArgument != sizeof...(Args)) { + rec = rec->next; + // if the overload is not feasible in terms of number of arguments, we + // continue to the next one. If there is no next one, we return false. + if (rec == nullptr) { + return false; + } + continue; + } if (rec->is_stateless && same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { @@ -118,6 +127,36 @@ struct type_caster> { // PYPY segfaults here when passing builtin function like sum. // Raising an fail exception here works to prevent the segfault, but only on gcc. // See PR #1413 for full details + } else { + // Check number of arguments of Python function + auto getArgCount = [&](PyObject *obj) { + // This is faster then doing import inspect and inspect.signature(obj).parameters + auto *t = PyObject_GetAttrString(obj, "__code__"); + auto *argCount = PyObject_GetAttrString(t, "co_argcount"); + return PyLong_AsLong(argCount); + }; + long argCount = -1; + + if (static_cast(PyObject_HasAttrString(src.ptr(), "__code__"))) { + argCount = getArgCount(src.ptr()); + } else { + if (static_cast(PyObject_HasAttrString(src.ptr(), "__call__"))) { + auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__"); + argCount = getArgCount(t2) - 1; // we have to remove the self argument + } else { + // No __code__ or __call__ attribute, this is not a proper Python function + return false; + } + } + // if we are a method, we have to correct the argument count since we are not counting + // the self argument + const int correctingSelfArgument + = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; + + argCount -= correctingSelfArgument; + if (argCount != sizeof...(Args)) { + return false; + } } value = type_caster_std_function_specializations::func_wrapper( diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index 2fd05dec72..ed55ad7b7e 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -170,6 +170,12 @@ TEST_SUBMODULE(callbacks, m) { return "argument does NOT match dummy_function. This should never happen!"; }); + // test_cpp_correct_overload_resolution + m.def("dummy_function_overloaded_std_func_arg", + [](const std::function &f) { return 3 * f(3); }); + m.def("dummy_function_overloaded_std_func_arg", + [](const std::function &f) { return 2 * f(3, 4); }); + class AbstractBase { public: // [workaround(intel)] = default does not work here diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index db6d8dece0..82b03fac1f 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -103,6 +103,20 @@ def test_cpp_callable_cleanup(): assert alive_counts == [0, 1, 2, 1, 2, 1, 0] +def test_cpp_correct_overload_resolution(): + def f(a): + return a + + assert m.dummy_function_overloaded_std_func_arg(f) == 9 + assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 + + def f2(a, b): + return a + b + + assert m.dummy_function_overloaded_std_func_arg(f2) == 14 + assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14 + + def test_cpp_function_roundtrip(): """Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer""" @@ -131,7 +145,10 @@ def test_cpp_function_roundtrip(): m.test_dummy_function(lambda x, y: x + y) assert any( s in str(excinfo.value) - for s in ("missing 1 required positional argument", "takes exactly 2 arguments") + for s in ( + "incompatible function arguments. The following argument types are", + "function test_cpp_function_roundtrip..", + ) ) diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp index c6c8a22d98..98df9b19e6 100644 --- a/tests/test_embed/test_interpreter.cpp +++ b/tests/test_embed/test_interpreter.cpp @@ -1,4 +1,5 @@ #include +#include // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to // catch 2.0.1; this should be fixed in the next catch release after 2.0.1). @@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { d["missing"].cast(); } +PYBIND11_EMBEDDED_MODULE(func_module, m) { + m.def("funcOverload", [](const std::function &f) { + return f(2, 3); + }).def("funcOverload", [](const std::function &f) { return f(2); }); +} + TEST_CASE("PYTHONPATH is used to update sys.path") { // The setup for this TEST_CASE is in catch.cpp! auto sys_path = py::str(py::module_::import("sys").attr("path")).cast(); @@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") { py::initialize_interpreter(); } +TEST_CASE("Check the overload resolution from cpp_function objects to std::function") { + auto m = py::module_::import("func_module"); + auto f = std::function([](int x) { return 2 * x; }); + REQUIRE(m.attr("funcOverload")(f).template cast() == 4); + + auto f2 = std::function([](int x, int y) { return 2 * x * y; }); + REQUIRE(m.attr("funcOverload")(f2).template cast() == 12); +} + #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX TEST_CASE("Custom PyConfig") { py::finalize_interpreter(); From d21cee39e8d4870fd4f9048145bec41f37b49fdd Mon Sep 17 00:00:00 2001 From: AlexanderMueller Date: Tue, 20 Aug 2024 11:49:16 +0200 Subject: [PATCH 2/3] changes from review --- include/pybind11/functional.h | 35 +++++++++++++++++++---------------- tests/test_callbacks.py | 10 ++++++++++ 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index fde99a4ae2..8a8c32c0ec 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -102,8 +102,8 @@ struct type_caster> { rec = c.get_pointer(); } while (rec != nullptr) { - const int correctingSelfArgument = rec->is_method ? 1 : 0; - if (rec->nargs - correctingSelfArgument != sizeof...(Args)) { + const size_t self_offset = rec->is_method ? 1 : 0; + if (rec->nargs != sizeof...(Args) + self_offset) { rec = rec->next; // if the overload is not feasible in terms of number of arguments, we // continue to the next one. If there is no next one, we return false. @@ -129,20 +129,24 @@ struct type_caster> { // See PR #1413 for full details } else { // Check number of arguments of Python function - auto getArgCount = [&](PyObject *obj) { - // This is faster then doing import inspect and inspect.signature(obj).parameters - auto *t = PyObject_GetAttrString(obj, "__code__"); - auto *argCount = PyObject_GetAttrString(t, "co_argcount"); - return PyLong_AsLong(argCount); + auto argCountFromFuncCode = [&](handle &obj) { + // This is faster then doing import inspect and + // inspect.signature(obj).parameters + + object argCount = obj.attr("co_argcount"); + return argCount.template cast(); }; - long argCount = -1; + size_t argCount = 0; - if (static_cast(PyObject_HasAttrString(src.ptr(), "__code__"))) { - argCount = getArgCount(src.ptr()); + handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__"); + if (codeAttr) { + argCount = argCountFromFuncCode(codeAttr); } else { - if (static_cast(PyObject_HasAttrString(src.ptr(), "__call__"))) { - auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__"); - argCount = getArgCount(t2) - 1; // we have to remove the self argument + handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__"); + if (callAttr) { + handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__"); + argCount = argCountFromFuncCode(codeAttr2) + - 1; // we have to remove the self argument } else { // No __code__ or __call__ attribute, this is not a proper Python function return false; @@ -150,10 +154,9 @@ struct type_caster> { } // if we are a method, we have to correct the argument count since we are not counting // the self argument - const int correctingSelfArgument - = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; + const size_t self_offset = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; - argCount -= correctingSelfArgument; + argCount -= self_offset; if (argCount != sizeof...(Args)) { return false; } diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 82b03fac1f..d2afbc2ca3 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -107,13 +107,23 @@ def test_cpp_correct_overload_resolution(): def f(a): return a + class A: + def __call__(self, a): + return a + assert m.dummy_function_overloaded_std_func_arg(f) == 9 + assert m.dummy_function_overloaded_std_func_arg(A()) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 def f2(a, b): return a + b + class B: + def __call__(self, a, b): + return a + b + assert m.dummy_function_overloaded_std_func_arg(f2) == 14 + assert m.dummy_function_overloaded_std_func_arg(B()) == 14 assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14 From e0be5dbd48e414f9a2bfcd4a9e4729edc7af3bd9 Mon Sep 17 00:00:00 2001 From: AlexanderMueller Date: Tue, 20 Aug 2024 17:09:24 +0200 Subject: [PATCH 3/3] test fix --- include/pybind11/functional.h | 24 ++++++++++++------------ tests/test_callbacks.py | 3 ++- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 8a8c32c0ec..4baeaa57a2 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -14,6 +14,7 @@ #include "pybind11.h" #include +#include PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) @@ -129,24 +130,23 @@ struct type_caster> { // See PR #1413 for full details } else { // Check number of arguments of Python function - auto argCountFromFuncCode = [&](handle &obj) { - // This is faster then doing import inspect and - // inspect.signature(obj).parameters - - object argCount = obj.attr("co_argcount"); - return argCount.template cast(); + auto get_argument_count = [](const handle &obj) -> size_t { + // Faster then `import inspect` and `inspect.signature(obj).parameters` + return obj.attr("co_argcount").cast(); }; size_t argCount = 0; - handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__"); + handle empty; + object codeAttr = getattr(src, "__code__", empty); + if (codeAttr) { - argCount = argCountFromFuncCode(codeAttr); + argCount = get_argument_count(codeAttr); } else { - handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__"); + object callAttr = getattr(src, "__call__", empty); + if (callAttr) { - handle codeAttr2 = PyObject_GetAttrString(callAttr.ptr(), "__code__"); - argCount = argCountFromFuncCode(codeAttr2) - - 1; // we have to remove the self argument + object codeAttr2 = getattr(callAttr, "__code__"); + argCount = get_argument_count(codeAttr2) - 1; // removing the self argument } else { // No __code__ or __call__ attribute, this is not a proper Python function return false; diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index d2afbc2ca3..c81aee6672 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -112,7 +112,8 @@ def __call__(self, a): return a assert m.dummy_function_overloaded_std_func_arg(f) == 9 - assert m.dummy_function_overloaded_std_func_arg(A()) == 9 + a = A() + assert m.dummy_function_overloaded_std_func_arg(a) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 def f2(a, b):