diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 4b3610117c..4baeaa57a2 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -14,6 +14,7 @@ #include "pybind11.h" #include +#include PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) PYBIND11_NAMESPACE_BEGIN(detail) @@ -101,8 +102,17 @@ struct type_caster> { if (detail::is_function_record_capsule(c)) { rec = c.get_pointer(); } - while (rec != nullptr) { + const size_t self_offset = rec->is_method ? 1 : 0; + if (rec->nargs != sizeof...(Args) + self_offset) { + rec = rec->next; + // if the overload is not feasible in terms of number of arguments, we + // continue to the next one. If there is no next one, we return false. + if (rec == nullptr) { + return false; + } + continue; + } if (rec->is_stateless && same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { @@ -118,6 +128,38 @@ struct type_caster> { // PYPY segfaults here when passing builtin function like sum. // Raising an fail exception here works to prevent the segfault, but only on gcc. // See PR #1413 for full details + } else { + // Check number of arguments of Python function + auto get_argument_count = [](const handle &obj) -> size_t { + // Faster then `import inspect` and `inspect.signature(obj).parameters` + return obj.attr("co_argcount").cast(); + }; + size_t argCount = 0; + + handle empty; + object codeAttr = getattr(src, "__code__", empty); + + if (codeAttr) { + argCount = get_argument_count(codeAttr); + } else { + object callAttr = getattr(src, "__call__", empty); + + if (callAttr) { + object codeAttr2 = getattr(callAttr, "__code__"); + argCount = get_argument_count(codeAttr2) - 1; // removing the self argument + } else { + // No __code__ or __call__ attribute, this is not a proper Python function + return false; + } + } + // if we are a method, we have to correct the argument count since we are not counting + // the self argument + const size_t self_offset = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; + + argCount -= self_offset; + if (argCount != sizeof...(Args)) { + return false; + } } value = type_caster_std_function_specializations::func_wrapper( diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp index 2fd05dec72..ed55ad7b7e 100644 --- a/tests/test_callbacks.cpp +++ b/tests/test_callbacks.cpp @@ -170,6 +170,12 @@ TEST_SUBMODULE(callbacks, m) { return "argument does NOT match dummy_function. This should never happen!"; }); + // test_cpp_correct_overload_resolution + m.def("dummy_function_overloaded_std_func_arg", + [](const std::function &f) { return 3 * f(3); }); + m.def("dummy_function_overloaded_std_func_arg", + [](const std::function &f) { return 2 * f(3, 4); }); + class AbstractBase { public: // [workaround(intel)] = default does not work here diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index db6d8dece0..c81aee6672 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -103,6 +103,31 @@ def test_cpp_callable_cleanup(): assert alive_counts == [0, 1, 2, 1, 2, 1, 0] +def test_cpp_correct_overload_resolution(): + def f(a): + return a + + class A: + def __call__(self, a): + return a + + assert m.dummy_function_overloaded_std_func_arg(f) == 9 + a = A() + assert m.dummy_function_overloaded_std_func_arg(a) == 9 + assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 + + def f2(a, b): + return a + b + + class B: + def __call__(self, a, b): + return a + b + + assert m.dummy_function_overloaded_std_func_arg(f2) == 14 + assert m.dummy_function_overloaded_std_func_arg(B()) == 14 + assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14 + + def test_cpp_function_roundtrip(): """Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer""" @@ -131,7 +156,10 @@ def test_cpp_function_roundtrip(): m.test_dummy_function(lambda x, y: x + y) assert any( s in str(excinfo.value) - for s in ("missing 1 required positional argument", "takes exactly 2 arguments") + for s in ( + "incompatible function arguments. The following argument types are", + "function test_cpp_function_roundtrip..", + ) ) diff --git a/tests/test_embed/test_interpreter.cpp b/tests/test_embed/test_interpreter.cpp index c6c8a22d98..98df9b19e6 100644 --- a/tests/test_embed/test_interpreter.cpp +++ b/tests/test_embed/test_interpreter.cpp @@ -1,4 +1,5 @@ #include +#include // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to // catch 2.0.1; this should be fixed in the next catch release after 2.0.1). @@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { d["missing"].cast(); } +PYBIND11_EMBEDDED_MODULE(func_module, m) { + m.def("funcOverload", [](const std::function &f) { + return f(2, 3); + }).def("funcOverload", [](const std::function &f) { return f(2); }); +} + TEST_CASE("PYTHONPATH is used to update sys.path") { // The setup for this TEST_CASE is in catch.cpp! auto sys_path = py::str(py::module_::import("sys").attr("path")).cast(); @@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") { py::initialize_interpreter(); } +TEST_CASE("Check the overload resolution from cpp_function objects to std::function") { + auto m = py::module_::import("func_module"); + auto f = std::function([](int x) { return 2 * x; }); + REQUIRE(m.attr("funcOverload")(f).template cast() == 4); + + auto f2 = std::function([](int x, int y) { return 2 * x * y; }); + REQUIRE(m.attr("funcOverload")(f2).template cast() == 12); +} + #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX TEST_CASE("Custom PyConfig") { py::finalize_interpreter();