Skip to content

Commit 0b49140

Browse files
committed
Added tests for ArrayLike in signatures and fixed wrong types for Refs
1 parent 6771709 commit 0b49140

File tree

7 files changed

+82
-9
lines changed

7 files changed

+82
-9
lines changed

Diff for: include/pybind11/eigen/matrix.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,9 @@ struct eigen_map_caster {
441441
}
442442
}
443443

444-
static constexpr auto name = props::descriptor;
444+
// return_descr forces the use of NDArray instead of ArrayLike in args
445+
// since Ref<...> args can only accept arrays.
446+
static constexpr auto name = return_descr(props::descriptor);
445447

446448
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
447449
// types but not bound arguments). We still provide them (with an explicitly delete) so that

Diff for: include/pybind11/eigen/tensor.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,10 @@ struct type_caster<Eigen::TensorMap<Type, Options>,
505505
std::unique_ptr<MapType> value;
506506

507507
public:
508-
static constexpr auto name = get_tensor_descriptor<Type, true, needs_writeable>::value;
508+
// return_descr forces the use of NDArray instead of ArrayLike since refs can only reference
509+
// arrays
510+
static constexpr auto name
511+
= return_descr(get_tensor_descriptor<Type, true, needs_writeable>::value);
509512
explicit operator MapType *() { return value.get(); }
510513
explicit operator MapType &() { return *value; }
511514
explicit operator MapType &&() && { return std::move(*value); }

Diff for: tests/test_eigen_matrix.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -440,4 +440,8 @@ TEST_SUBMODULE(eigen_matrix, m) {
440440
py::module_::import("numpy").attr("ones")(10);
441441
return v[0](5);
442442
});
443+
m.def("round_trip_vector", [](const Eigen::VectorXf &x) -> Eigen::VectorXf { return x; });
444+
m.def("round_trip_dense", [](const DenseMatrixR &m) -> DenseMatrixR { return m; });
445+
m.def("round_trip_dense_ref",
446+
[](const Eigen::Ref<DenseMatrixR> &m) -> Eigen::Ref<DenseMatrixR> { return m; });
443447
}

Diff for: tests/test_eigen_matrix.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,19 @@ def test_mutator_descriptors():
9595
with pytest.raises(TypeError) as excinfo:
9696
m.fixed_mutator_r(zc)
9797
assert (
98-
'(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]",'
98+
'(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[5, 6]",'
9999
' "flags.writeable", "flags.c_contiguous"]) -> None' in str(excinfo.value)
100100
)
101101
with pytest.raises(TypeError) as excinfo:
102102
m.fixed_mutator_c(zr)
103103
assert (
104-
'(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]",'
104+
'(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[5, 6]",'
105105
' "flags.writeable", "flags.f_contiguous"]) -> None' in str(excinfo.value)
106106
)
107107
with pytest.raises(TypeError) as excinfo:
108108
m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype="float32"))
109109
assert (
110-
'(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]", "flags.writeable"]) -> None'
110+
'(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[5, 6]", "flags.writeable"]) -> None'
111111
in str(excinfo.value)
112112
)
113113
zr.flags.writeable = False
@@ -202,7 +202,7 @@ def test_negative_stride_from_python(msg):
202202
msg(excinfo.value)
203203
== """
204204
double_threer(): incompatible function arguments. The following argument types are supported:
205-
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[1, 3]", "flags.writeable"]) -> None
205+
1. (arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[1, 3]", "flags.writeable"]) -> None
206206
207207
Invoked with: """
208208
+ repr(np.array([5.0, 4.0, 3.0], dtype="float32"))
@@ -214,7 +214,7 @@ def test_negative_stride_from_python(msg):
214214
msg(excinfo.value)
215215
== """
216216
double_threec(): incompatible function arguments. The following argument types are supported:
217-
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[3, 1]", "flags.writeable"]) -> None
217+
1. (arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[3, 1]", "flags.writeable"]) -> None
218218
219219
Invoked with: """
220220
+ repr(np.array([7.0, 4.0, 1.0], dtype="float32"))
@@ -818,3 +818,22 @@ def test_custom_operator_new():
818818
o = m.CustomOperatorNew()
819819
np.testing.assert_allclose(o.a, 0.0)
820820
np.testing.assert_allclose(o.b.diagonal(), 1.0)
821+
822+
823+
def test_arraylike_signature(doc):
824+
assert doc(m.round_trip_vector) == (
825+
'round_trip_vector(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, 1]"])'
826+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, 1]"]'
827+
)
828+
assert doc(m.round_trip_dense) == (
829+
'round_trip_dense(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"])'
830+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]'
831+
)
832+
assert doc(m.round_trip_dense_ref) == (
833+
'round_trip_dense_ref(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"])'
834+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"]'
835+
)
836+
m.round_trip_vector([1.0, 2.0])
837+
m.round_trip_dense([[1.0, 2.0], [3.0, 4.0]])
838+
with pytest.raises(TypeError, match="incompatible function arguments"):
839+
m.round_trip_dense_ref([[1.0, 2.0], [3.0, 4.0]])

Diff for: tests/test_eigen_tensor.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -285,10 +285,32 @@ def test_doc_string(m, doc):
285285

286286
order_flag = f'"flags.{m.needed_options.lower()}_contiguous"'
287287
assert doc(m.round_trip_view_tensor) == (
288-
f'round_trip_view_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, "[?, ?, ?]", "flags.writeable", {order_flag}])'
288+
f'round_trip_view_tensor(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}])'
289289
f' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}]'
290290
)
291291
assert doc(m.round_trip_const_view_tensor) == (
292-
f'round_trip_const_view_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, "[?, ?, ?]", {order_flag}])'
292+
f'round_trip_const_view_tensor(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", {order_flag}])'
293293
' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
294294
)
295+
296+
297+
@pytest.mark.parametrize("m", submodules)
298+
def test_arraylike_signature(m, doc):
299+
order_flag = f'"flags.{m.needed_options.lower()}_contiguous"'
300+
assert doc(m.round_trip_tensor) == (
301+
'round_trip_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, "[?, ?, ?]"])'
302+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
303+
)
304+
assert doc(m.round_trip_tensor_noconvert) == (
305+
'round_trip_tensor_noconvert(tensor: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"])'
306+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
307+
)
308+
assert doc(m.round_trip_view_tensor) == (
309+
f'round_trip_view_tensor(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}])'
310+
f' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}]'
311+
)
312+
m.round_trip_tensor(tensor_ref.tolist())
313+
with pytest.raises(TypeError, match="incompatible function arguments"):
314+
m.round_trip_tensor_noconvert(tensor_ref.tolist())
315+
with pytest.raises(TypeError, match="incompatible function arguments"):
316+
m.round_trip_view_tensor(tensor_ref.tolist())

Diff for: tests/test_numpy_array.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -586,4 +586,13 @@ TEST_SUBMODULE(numpy_array, sm) {
586586
sm.def("return_array_pyobject_ptr_from_list", return_array_from_list<PyObject *>);
587587
sm.def("return_array_handle_from_list", return_array_from_list<py::handle>);
588588
sm.def("return_array_object_from_list", return_array_from_list<py::object>);
589+
590+
sm.def(
591+
"round_trip_array_t",
592+
[](const py::array_t<float> &x) -> py::array_t<float> { return x; },
593+
py::arg("x"));
594+
sm.def(
595+
"round_trip_array_t_noconvert",
596+
[](const py::array_t<float> &x) -> py::array_t<float> { return x; },
597+
py::arg("x").noconvert());
589598
}

Diff for: tests/test_numpy_array.py

+14
Original file line numberDiff line numberDiff line change
@@ -687,3 +687,17 @@ def test_return_array_object_cpp_loop(return_array, unwrap):
687687
assert isinstance(arr_from_list, np.ndarray)
688688
assert arr_from_list.dtype == np.dtype("O")
689689
assert unwrap(arr_from_list) == [6, "seven", -8.0]
690+
691+
692+
def test_arraylike_signature(doc):
693+
assert (
694+
doc(m.round_trip_array_t)
695+
== "round_trip_array_t(x: typing.Annotated[numpy.typing.ArrayLike, numpy.float32]) -> numpy.typing.NDArray[numpy.float32]"
696+
)
697+
assert (
698+
doc(m.round_trip_array_t_noconvert)
699+
== "round_trip_array_t_noconvert(x: numpy.typing.NDArray[numpy.float32]) -> numpy.typing.NDArray[numpy.float32]"
700+
)
701+
m.round_trip_array_t([1, 2, 3])
702+
with pytest.raises(TypeError, match="incompatible function arguments"):
703+
m.round_trip_array_t_noconvert([1, 2, 3])

0 commit comments

Comments
 (0)