Skip to content

Commit 87b9bb2

Browse files
author
AlexanderMueller
committed
add argument number dispatch mechanism for std::function casting
1 parent 916778d commit 87b9bb2

File tree

4 files changed

+80
-2
lines changed

4 files changed

+80
-2
lines changed

include/pybind11/functional.h

+40-1
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,17 @@ struct type_caster<std::function<Return(Args...)>> {
5858
if (detail::is_function_record_capsule(c)) {
5959
rec = c.get_pointer<function_record>();
6060
}
61-
6261
while (rec != nullptr) {
62+
const int correctingSelfArgument = rec->is_method ? 1 : 0;
63+
if (rec->nargs - correctingSelfArgument != sizeof...(Args)) {
64+
rec = rec->next;
65+
// if the overload is not feasible in terms of number of arguments, we
66+
// continue to the next one. If there is no next one, we return false.
67+
if (rec == nullptr) {
68+
return false;
69+
}
70+
continue;
71+
}
6372
if (rec->is_stateless
6473
&& same_type(typeid(function_type),
6574
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
@@ -75,6 +84,36 @@ struct type_caster<std::function<Return(Args...)>> {
7584
// PYPY segfaults here when passing builtin function like sum.
7685
// Raising an fail exception here works to prevent the segfault, but only on gcc.
7786
// See PR #1413 for full details
87+
} else {
88+
// Check number of arguments of Python function
89+
auto getArgCount = [&](PyObject *obj) {
90+
// This is faster then doing import inspect and inspect.signature(obj).parameters
91+
auto *t = PyObject_GetAttrString(obj, "__code__");
92+
auto *argCount = PyObject_GetAttrString(t, "co_argcount");
93+
return PyLong_AsLong(argCount);
94+
};
95+
long argCount = -1;
96+
97+
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__code__"))) {
98+
argCount = getArgCount(src.ptr());
99+
} else {
100+
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__call__"))) {
101+
auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__");
102+
argCount = getArgCount(t2) - 1; // we have to remove the self argument
103+
} else {
104+
// No __code__ or __call__ attribute, this is not a proper Python function
105+
return false;
106+
}
107+
}
108+
// if we are a method, we have to correct the argument count since we are not counting
109+
// the self argument
110+
const int correctingSelfArgument
111+
= static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;
112+
113+
argCount -= correctingSelfArgument;
114+
if (argCount != sizeof...(Args)) {
115+
return false;
116+
}
78117
}
79118

80119
// ensure GIL is held during functor destruction

tests/test_callbacks.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ TEST_SUBMODULE(callbacks, m) {
170170
return "argument does NOT match dummy_function. This should never happen!";
171171
});
172172

173+
// test_cpp_correct_overload_resolution
174+
m.def("dummy_function_overloaded_std_func_arg",
175+
[](std::function<int(int)> f) { return 3 * f(3); });
176+
m.def("dummy_function_overloaded_std_func_arg",
177+
[](std::function<int(int, int)> f) { return 2 * f(3, 4); });
178+
173179
class AbstractBase {
174180
public:
175181
// [workaround(intel)] = default does not work here

tests/test_callbacks.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,20 @@ def test_cpp_callable_cleanup():
103103
assert alive_counts == [0, 1, 2, 1, 2, 1, 0]
104104

105105

106+
def test_cpp_correct_overload_resolution():
107+
def f(a):
108+
return a
109+
110+
assert m.dummy_function_overloaded_std_func_arg(f) == 9
111+
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9
112+
113+
def f2(a, b):
114+
return a + b
115+
116+
assert m.dummy_function_overloaded_std_func_arg(f2) == 14
117+
assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14
118+
119+
106120
def test_cpp_function_roundtrip():
107121
"""Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""
108122

@@ -131,7 +145,10 @@ def test_cpp_function_roundtrip():
131145
m.test_dummy_function(lambda x, y: x + y)
132146
assert any(
133147
s in str(excinfo.value)
134-
for s in ("missing 1 required positional argument", "takes exactly 2 arguments")
148+
for s in (
149+
"incompatible function arguments. The following argument types are",
150+
"function test_cpp_function_roundtrip.<locals>.<lambda>",
151+
)
135152
)
136153

137154

tests/test_embed/test_interpreter.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <pybind11/embed.h>
2+
#include <pybind11/functional.h>
23

34
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
45
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
@@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
7879
d["missing"].cast<py::object>();
7980
}
8081

82+
PYBIND11_EMBEDDED_MODULE(func_module, m) {
83+
m.def("funcOverload", [](const std::function<int(int, int)> f) {
84+
return f(2, 3);
85+
}).def("funcOverload", [](const std::function<int(int)> f) { return f(2); });
86+
}
87+
8188
TEST_CASE("PYTHONPATH is used to update sys.path") {
8289
// The setup for this TEST_CASE is in catch.cpp!
8390
auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
@@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") {
171178
py::initialize_interpreter();
172179
}
173180

181+
TEST_CASE("Check the overload resolution from cpp_function objects to std::function") {
182+
auto m = py::module_::import("func_module");
183+
auto f = std::function<int(int)>([](int x) { return 2 * x; });
184+
REQUIRE(m.attr("funcOverload")(f).template cast<int>() == 4);
185+
186+
auto f2 = std::function<int(int, int)>([](int x, int y) { return 2 * x * y; });
187+
REQUIRE(m.attr("funcOverload")(f2).template cast<int>() == 12);
188+
}
189+
174190
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
175191
TEST_CASE("Custom PyConfig") {
176192
py::finalize_interpreter();

0 commit comments

Comments
 (0)