Skip to content

Commit 2187294

Browse files
author
AlexanderMueller
committed
changes from review
1 parent 0235533 commit 2187294

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

include/pybind11/functional.h

+19-16
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ struct type_caster<std::function<Return(Args...)>> {
102102
rec = c.get_pointer<function_record>();
103103
}
104104
while (rec != nullptr) {
105-
const int correctingSelfArgument = rec->is_method ? 1 : 0;
106-
if (rec->nargs - correctingSelfArgument != sizeof...(Args)) {
105+
const size_t self_offset = rec->is_method ? 1 : 0;
106+
if (rec->nargs != sizeof...(Args) + self_offset) {
107107
rec = rec->next;
108108
// if the overload is not feasible in terms of number of arguments, we
109109
// continue to the next one. If there is no next one, we return false.
@@ -129,31 +129,34 @@ struct type_caster<std::function<Return(Args...)>> {
129129
// See PR #1413 for full details
130130
} else {
131131
// Check number of arguments of Python function
132-
auto getArgCount = [&](PyObject *obj) {
133-
// This is faster then doing import inspect and inspect.signature(obj).parameters
134-
auto *t = PyObject_GetAttrString(obj, "__code__");
135-
auto *argCount = PyObject_GetAttrString(t, "co_argcount");
136-
return PyLong_AsLong(argCount);
132+
auto argCountFromFuncCode = [&](handle &obj) {
133+
// This is faster then doing import inspect and
134+
// inspect.signature(obj).parameters
135+
136+
object argCount = obj.attr("co_argcount");
137+
return argCount.template cast<size_t>();
137138
};
138-
long argCount = -1;
139+
size_t argCount = -1;
139140

140-
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__code__"))) {
141-
argCount = getArgCount(src.ptr());
141+
handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__");
142+
if (codeAttr) {
143+
argCount = argCountFromFuncCode(codeAttr);
142144
} else {
143-
if (static_cast<bool>(PyObject_HasAttrString(src.ptr(), "__call__"))) {
144-
auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__");
145-
argCount = getArgCount(t2) - 1; // we have to remove the self argument
145+
handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__");
146+
if (callAttr) {
147+
handle codeAttr2 = callAttr.attr("__code__");
148+
argCount = argCountFromFuncCode(codeAttr2)
149+
- 1; // we have to remove the self argument
146150
} else {
147151
// No __code__ or __call__ attribute, this is not a proper Python function
148152
return false;
149153
}
150154
}
151155
// if we are a method, we have to correct the argument count since we are not counting
152156
// the self argument
153-
const int correctingSelfArgument
154-
= static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;
157+
const size_t self_offset = static_cast<bool>(PyMethod_Check(src.ptr())) ? 1 : 0;
155158

156-
argCount -= correctingSelfArgument;
159+
argCount -= self_offset;
157160
if (argCount != sizeof...(Args)) {
158161
return false;
159162
}

tests/test_callbacks.py

+10
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,23 @@ def test_cpp_correct_overload_resolution():
107107
def f(a):
108108
return a
109109

110+
class A:
111+
def __call__(self, a):
112+
return a
113+
110114
assert m.dummy_function_overloaded_std_func_arg(f) == 9
115+
assert m.dummy_function_overloaded_std_func_arg(A()) == 9
111116
assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9
112117

113118
def f2(a, b):
114119
return a + b
115120

121+
class B:
122+
def __call__(self, a, b):
123+
return a + b
124+
116125
assert m.dummy_function_overloaded_std_func_arg(f2) == 14
126+
assert m.dummy_function_overloaded_std_func_arg(B()) == 14
117127
assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14
118128

119129

0 commit comments

Comments
 (0)