Skip to content

Commit 0235533

Browse files
author
AlexanderMueller
committed
add argument number dispatch mechanism for std::function casting
1 parent a1d0091 commit 0235533

File tree

4 files changed

+80
-2
lines changed

4 files changed

+80
-2
lines changed

include/pybind11/functional.h

+40-1
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,17 @@ struct type_caster<std::function<Return(Args...)>> {
101101
if (detail::is_function_record_capsule(c)) {
102102
rec = c.get_pointer<function_record>();
103103
}
104-
105104
while (rec != nullptr) {
105+
const int correctingSelfArgument = rec->is_method ? 1 : 0;
106+
if (rec->nargs - correctingSelfArgument != sizeof...(Args)) {
107+
rec = rec->next;
108+
// if the overload is not feasible in terms of number of arguments, we
109+
// continue to the next one. If there is no next one, we return false.
110+
if (rec == nullptr) {
111+
return false;
112+
}
113+
continue;
114+
}
106115
if (rec->is_stateless
107116
&& same_type(typeid(function_type),
108117
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
@@ -118,6 +127,36 @@ struct type_caster<std::function<Return(Args...)>> {
118127
// PYPY segfaults here when passing builtin function like sum.
119128
// Raising an fail exception here works to prevent the segfault, but only on gcc.
120129
// See PR #1413 for full details
130+
} else {
131+
// Check number of arguments of Python function
132+
auto getArgCount = [&](PyObject *obj) {
133+
// This is faster then doing import inspect and inspect.signature(obj).parameters
134+
auto *t = PyObject_GetAttrString(obj, "__code__");
135+
auto *argCount = PyObject_GetAttrString(t, "co_argcount");
136+
return PyLong_AsLong(argCount);
137+
};
138+
long argCount = -1;
139+
140+
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__code__"))) {
141+
argCount = getArgCount(src.ptr());
142+
} else {
143+
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__call__"))) {
144+
auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__");
145+
argCount = getArgCount(t2) - 1; // we have to remove the self argument
146+
} else {
147+
// No __code__ or __call__ attribute, this is not a proper Python function
148+
return false;
149+
}
150+
}
151+
// if we are a method, we have to correct the argument count since we are not counting
152+
// the self argument
153+
const int correctingSelfArgument
154+
= static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;
155+
156+
argCount -= correctingSelfArgument;
157+
if (argCount != sizeof...(Args)) {
158+
return false;
159+
}
121160
}
122161

123162
value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(

tests/test_callbacks.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ TEST_SUBMODULE(callbacks, m) {
170170
return "argument does NOT match dummy_function. This should never happen!";
171171
});
172172

173+
// test_cpp_correct_overload_resolution
174+
m.def("dummy_function_overloaded_std_func_arg",
175+
[](const std::function<int(int)> &f) { return 3 * f(3); });
176+
m.def("dummy_function_overloaded_std_func_arg",
177+
[](const std::function<int(int, int)> &f) { return 2 * f(3, 4); });
178+
173179
class AbstractBase {
174180
public:
175181
// [workaround(intel)] = default does not work here

tests/test_callbacks.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,20 @@ def test_cpp_callable_cleanup():
103103
assert alive_counts == [0, 1, 2, 1, 2, 1, 0]
104104

105105

106+
def test_cpp_correct_overload_resolution():
107+
def f(a):
108+
return a
109+
110+
assert m.dummy_function_overloaded_std_func_arg(f) == 9
111+
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9
112+
113+
def f2(a, b):
114+
return a + b
115+
116+
assert m.dummy_function_overloaded_std_func_arg(f2) == 14
117+
assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14
118+
119+
106120
def test_cpp_function_roundtrip():
107121
"""Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""
108122

@@ -131,7 +145,10 @@ def test_cpp_function_roundtrip():
131145
m.test_dummy_function(lambda x, y: x + y)
132146
assert any(
133147
s in str(excinfo.value)
134-
for s in ("missing 1 required positional argument", "takes exactly 2 arguments")
148+
for s in (
149+
"incompatible function arguments. The following argument types are",
150+
"function test_cpp_function_roundtrip.<locals>.<lambda>",
151+
)
135152
)
136153

137154

tests/test_embed/test_interpreter.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <pybind11/embed.h>
2+
#include <pybind11/functional.h>
23

34
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
45
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
@@ -78,6 +79,12 @@ PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
7879
d["missing"].cast<py::object>();
7980
}
8081

82+
PYBIND11_EMBEDDED_MODULE(func_module, m) {
83+
m.def("funcOverload", [](const std::function<int(int, int)> &f) {
84+
return f(2, 3);
85+
}).def("funcOverload", [](const std::function<int(int)> &f) { return f(2); });
86+
}
87+
8188
TEST_CASE("PYTHONPATH is used to update sys.path") {
8289
// The setup for this TEST_CASE is in catch.cpp!
8390
auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
@@ -171,6 +178,15 @@ TEST_CASE("There can be only one interpreter") {
171178
py::initialize_interpreter();
172179
}
173180

181+
TEST_CASE("Check the overload resolution from cpp_function objects to std::function") {
182+
auto m = py::module_::import("func_module");
183+
auto f = std::function<int(int)>([](int x) { return 2 * x; });
184+
REQUIRE(m.attr("funcOverload")(f).template cast<int>() == 4);
185+
186+
auto f2 = std::function<int(int, int)>([](int x, int y) { return 2 * x * y; });
187+
REQUIRE(m.attr("funcOverload")(f2).template cast<int>() == 12);
188+
}
189+
174190
#if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
175191
TEST_CASE("Custom PyConfig") {
176192
py::finalize_interpreter();

0 commit comments

Comments
 (0)