Skip to content

Commit 2415242

Browse files
InvincibleRMCpre-commit-ci[bot]timohl
authored
feat(types) Numpy.typing.NDArray (#5212)
* tests passing * lint * add comment * remove empty tuple[()] * test io_name * style: pre-commit fixes * remove accidental > Signed-off-by: Michael Carlstrom <[email protected]> * try T * make both const_name Signed-off-by: Michael Carlstrom <[email protected]> * try and treat as string * style: pre-commit fixes * Update Numpy type hints * style: pre-commit fixes * re-run ci Signed-off-by: Michael Carlstrom <[email protected]> * re-run ci Signed-off-by: Michael Carlstrom <[email protected]> * remove escape characters * Added tests for ArrayLike in signatures and fixed wrong types for Refs --------- Signed-off-by: Michael Carlstrom <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Ohliger <[email protected]>
1 parent 34a118f commit 2415242

10 files changed

+142
-61
lines changed

include/pybind11/eigen/matrix.h

+14-9
Original file line numberDiff line numberDiff line change
@@ -225,19 +225,22 @@ struct EigenProps {
225225
= !show_c_contiguous && show_order && requires_col_major;
226226

227227
static constexpr auto descriptor
228-
= const_name("numpy.ndarray[") + npy_format_descriptor<Scalar>::name + const_name("[")
228+
= const_name("typing.Annotated[")
229+
+ io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
230+
+ npy_format_descriptor<Scalar>::name + io_name("", "]") + const_name(", \"[")
229231
+ const_name<fixed_rows>(const_name<(size_t) rows>(), const_name("m")) + const_name(", ")
230-
+ const_name<fixed_cols>(const_name<(size_t) cols>(), const_name("n")) + const_name("]")
231-
+
232+
+ const_name<fixed_cols>(const_name<(size_t) cols>(), const_name("n"))
233+
+ const_name("]\"")
232234
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to
233235
// be satisfied: writeable=True (for a mutable reference), and, depending on the map's
234236
// stride options, possibly f_contiguous or c_contiguous. We include them in the
235237
// descriptor output to provide some hint as to why a TypeError is occurring (otherwise
236-
// it can be confusing to see that a function accepts a 'numpy.ndarray[float64[3,2]]' and
237-
// an error message that you *gave* a numpy.ndarray of the right type and dimensions.
238-
const_name<show_writeable>(", flags.writeable", "")
239-
+ const_name<show_c_contiguous>(", flags.c_contiguous", "")
240-
+ const_name<show_f_contiguous>(", flags.f_contiguous", "") + const_name("]");
238+
// it can be confusing to see that a function accepts a
239+
// 'typing.Annotated[numpy.typing.NDArray[numpy.float64], "[3,2]"]' and an error message
240+
// that you *gave* a numpy.ndarray of the right type and dimensions.
241+
+ const_name<show_writeable>(", \"flags.writeable\"", "")
242+
+ const_name<show_c_contiguous>(", \"flags.c_contiguous\"", "")
243+
+ const_name<show_f_contiguous>(", \"flags.f_contiguous\"", "") + const_name("]");
241244
};
242245

243246
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
@@ -441,7 +444,9 @@ struct eigen_map_caster {
441444
}
442445
}
443446

444-
static constexpr auto name = props::descriptor;
447+
// return_descr forces the use of NDArray instead of ArrayLike in args
448+
// since Ref<...> args can only accept arrays.
449+
static constexpr auto name = return_descr(props::descriptor);
445450

446451
// Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
447452
// types but not bound arguments). We still provide them (with an explicitly delete) so that

include/pybind11/eigen/tensor.h

+12-6
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,16 @@ struct eigen_tensor_helper<
124124
template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
125125
struct get_tensor_descriptor {
126126
static constexpr auto details
127-
= const_name<NeedsWriteable>(", flags.writeable", "") + const_name
127+
= const_name<NeedsWriteable>(", \"flags.writeable\"", "") + const_name
128128
< static_cast<int>(Type::Layout)
129-
== static_cast<int>(Eigen::RowMajor) > (", flags.c_contiguous", ", flags.f_contiguous");
129+
== static_cast<int>(Eigen::RowMajor)
130+
> (", \"flags.c_contiguous\"", ", \"flags.f_contiguous\"");
130131
static constexpr auto value
131-
= const_name("numpy.ndarray[") + npy_format_descriptor<typename Type::Scalar>::name
132-
+ const_name("[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
133-
+ const_name("]") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
132+
= const_name("typing.Annotated[")
133+
+ io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
134+
+ npy_format_descriptor<typename Type::Scalar>::name + io_name("", "]")
135+
+ const_name(", \"[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
136+
+ const_name("]\"") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
134137
};
135138

136139
// When EIGEN_AVOID_STL_ARRAY is defined, Eigen::DSizes<T, 0> does not have the begin() member
@@ -502,7 +505,10 @@ struct type_caster<Eigen::TensorMap<Type, Options>,
502505
std::unique_ptr<MapType> value;
503506

504507
public:
505-
static constexpr auto name = get_tensor_descriptor<Type, true, needs_writeable>::value;
508+
// return_descr forces the use of NDArray instead of ArrayLike since refs can only reference
509+
// arrays
510+
static constexpr auto name
511+
= return_descr(get_tensor_descriptor<Type, true, needs_writeable>::value);
506512
explicit operator MapType *() { return value.get(); }
507513
explicit operator MapType &() { return *value; }
508514
explicit operator MapType &&() && { return std::move(*value); }

include/pybind11/numpy.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ inline numpy_internals &get_numpy_internals() {
175175
PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
176176
module_ numpy = module_::import("numpy");
177177
str version_string = numpy.attr("__version__");
178-
179178
module_ numpy_lib = module_::import("numpy.lib");
180179
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
181180
int major_version = numpy_version.attr("major").cast<int>();
@@ -2183,7 +2182,8 @@ vectorize_helper<Func, Return, Args...> vectorize_extractor(const Func &f, Retur
21832182
template <typename T, int Flags>
21842183
struct handle_type_name<array_t<T, Flags>> {
21852184
static constexpr auto name
2186-
= const_name("numpy.ndarray[") + npy_format_descriptor<T>::name + const_name("]");
2185+
= io_name("typing.Annotated[numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
2186+
+ npy_format_descriptor<T>::name + const_name("]");
21872187
};
21882188

21892189
PYBIND11_NAMESPACE_END(detail)

tests/test_eigen_matrix.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -440,4 +440,8 @@ TEST_SUBMODULE(eigen_matrix, m) {
440440
py::module_::import("numpy").attr("ones")(10);
441441
return v[0](5);
442442
});
443+
m.def("round_trip_vector", [](const Eigen::VectorXf &x) -> Eigen::VectorXf { return x; });
444+
m.def("round_trip_dense", [](const DenseMatrixR &m) -> DenseMatrixR { return m; });
445+
m.def("round_trip_dense_ref",
446+
[](const Eigen::Ref<DenseMatrixR> &m) -> Eigen::Ref<DenseMatrixR> { return m; });
443447
}

tests/test_eigen_matrix.py

+40-20
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,20 @@ def test_mutator_descriptors():
9595
with pytest.raises(TypeError) as excinfo:
9696
m.fixed_mutator_r(zc)
9797
assert (
98-
"(arg0: numpy.ndarray[numpy.float32[5, 6],"
99-
" flags.writeable, flags.c_contiguous]) -> None" in str(excinfo.value)
98+
'(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[5, 6]",'
99+
' "flags.writeable", "flags.c_contiguous"]) -> None' in str(excinfo.value)
100100
)
101101
with pytest.raises(TypeError) as excinfo:
102102
m.fixed_mutator_c(zr)
103103
assert (
104-
"(arg0: numpy.ndarray[numpy.float32[5, 6],"
105-
" flags.writeable, flags.f_contiguous]) -> None" in str(excinfo.value)
104+
'(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[5, 6]",'
105+
' "flags.writeable", "flags.f_contiguous"]) -> None' in str(excinfo.value)
106106
)
107107
with pytest.raises(TypeError) as excinfo:
108108
m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype="float32"))
109-
assert "(arg0: numpy.ndarray[numpy.float32[5, 6], flags.writeable]) -> None" in str(
110-
excinfo.value
109+
assert (
110+
'(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[5, 6]", "flags.writeable"]) -> None'
111+
in str(excinfo.value)
111112
)
112113
zr.flags.writeable = False
113114
with pytest.raises(TypeError):
@@ -201,7 +202,7 @@ def test_negative_stride_from_python(msg):
201202
msg(excinfo.value)
202203
== """
203204
double_threer(): incompatible function arguments. The following argument types are supported:
204-
1. (arg0: numpy.ndarray[numpy.float32[1, 3], flags.writeable]) -> None
205+
1. (arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[1, 3]", "flags.writeable"]) -> None
205206
206207
Invoked with: """
207208
+ repr(np.array([5.0, 4.0, 3.0], dtype="float32"))
@@ -213,7 +214,7 @@ def test_negative_stride_from_python(msg):
213214
msg(excinfo.value)
214215
== """
215216
double_threec(): incompatible function arguments. The following argument types are supported:
216-
1. (arg0: numpy.ndarray[numpy.float32[3, 1], flags.writeable]) -> None
217+
1. (arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[3, 1]", "flags.writeable"]) -> None
217218
218219
Invoked with: """
219220
+ repr(np.array([7.0, 4.0, 1.0], dtype="float32"))
@@ -634,37 +635,37 @@ def test_nocopy_wrapper():
634635
with pytest.raises(TypeError) as excinfo:
635636
m.get_elem_nocopy(int_matrix_colmajor)
636637
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
637-
assert ", flags.f_contiguous" in str(excinfo.value)
638+
assert ', "flags.f_contiguous"' in str(excinfo.value)
638639
assert m.get_elem_nocopy(dbl_matrix_colmajor) == 8
639640
with pytest.raises(TypeError) as excinfo:
640641
m.get_elem_nocopy(int_matrix_rowmajor)
641642
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
642-
assert ", flags.f_contiguous" in str(excinfo.value)
643+
assert ', "flags.f_contiguous"' in str(excinfo.value)
643644
with pytest.raises(TypeError) as excinfo:
644645
m.get_elem_nocopy(dbl_matrix_rowmajor)
645646
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
646-
assert ", flags.f_contiguous" in str(excinfo.value)
647+
assert ', "flags.f_contiguous"' in str(excinfo.value)
647648

648649
# For the row-major test, we take a long matrix in row-major, so only the third is allowed:
649650
with pytest.raises(TypeError) as excinfo:
650651
m.get_elem_rm_nocopy(int_matrix_colmajor)
651652
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
652653
excinfo.value
653654
)
654-
assert ", flags.c_contiguous" in str(excinfo.value)
655+
assert ', "flags.c_contiguous"' in str(excinfo.value)
655656
with pytest.raises(TypeError) as excinfo:
656657
m.get_elem_rm_nocopy(dbl_matrix_colmajor)
657658
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
658659
excinfo.value
659660
)
660-
assert ", flags.c_contiguous" in str(excinfo.value)
661+
assert ', "flags.c_contiguous"' in str(excinfo.value)
661662
assert m.get_elem_rm_nocopy(int_matrix_rowmajor) == 8
662663
with pytest.raises(TypeError) as excinfo:
663664
m.get_elem_rm_nocopy(dbl_matrix_rowmajor)
664665
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
665666
excinfo.value
666667
)
667-
assert ", flags.c_contiguous" in str(excinfo.value)
668+
assert ', "flags.c_contiguous"' in str(excinfo.value)
668669

669670

670671
def test_eigen_ref_life_support():
@@ -700,25 +701,25 @@ def test_dense_signature(doc):
700701
assert (
701702
doc(m.double_col)
702703
== """
703-
double_col(arg0: numpy.ndarray[numpy.float32[m, 1]]) -> numpy.ndarray[numpy.float32[m, 1]]
704+
double_col(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, 1]"]) -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, 1]"]
704705
"""
705706
)
706707
assert (
707708
doc(m.double_row)
708709
== """
709-
double_row(arg0: numpy.ndarray[numpy.float32[1, n]]) -> numpy.ndarray[numpy.float32[1, n]]
710+
double_row(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[1, n]"]) -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[1, n]"]
710711
"""
711712
)
712713
assert doc(m.double_complex) == (
713714
"""
714-
double_complex(arg0: numpy.ndarray[numpy.complex64[m, 1]])"""
715-
""" -> numpy.ndarray[numpy.complex64[m, 1]]
715+
double_complex(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex64, "[m, 1]"])"""
716+
""" -> typing.Annotated[numpy.typing.NDArray[numpy.complex64], "[m, 1]"]
716717
"""
717718
)
718719
assert doc(m.double_mat_rm) == (
719720
"""
720-
double_mat_rm(arg0: numpy.ndarray[numpy.float32[m, n]])"""
721-
""" -> numpy.ndarray[numpy.float32[m, n]]
721+
double_mat_rm(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"])"""
722+
""" -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]
722723
"""
723724
)
724725

@@ -817,3 +818,22 @@ def test_custom_operator_new():
817818
o = m.CustomOperatorNew()
818819
np.testing.assert_allclose(o.a, 0.0)
819820
np.testing.assert_allclose(o.b.diagonal(), 1.0)
821+
822+
823+
def test_arraylike_signature(doc):
824+
assert doc(m.round_trip_vector) == (
825+
'round_trip_vector(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, 1]"])'
826+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, 1]"]'
827+
)
828+
assert doc(m.round_trip_dense) == (
829+
'round_trip_dense(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"])'
830+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]'
831+
)
832+
assert doc(m.round_trip_dense_ref) == (
833+
'round_trip_dense_ref(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"])'
834+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"]'
835+
)
836+
m.round_trip_vector([1.0, 2.0])
837+
m.round_trip_dense([[1.0, 2.0], [3.0, 4.0]])
838+
with pytest.raises(TypeError, match="incompatible function arguments"):
839+
m.round_trip_dense_ref([[1.0, 2.0], [3.0, 4.0]])

tests/test_eigen_tensor.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -271,23 +271,46 @@ def test_round_trip_references_actually_refer(m):
271271
@pytest.mark.parametrize("m", submodules)
272272
def test_doc_string(m, doc):
273273
assert (
274-
doc(m.copy_tensor) == "copy_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
274+
doc(m.copy_tensor)
275+
== 'copy_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
275276
)
276277
assert (
277278
doc(m.copy_fixed_tensor)
278-
== "copy_fixed_tensor() -> numpy.ndarray[numpy.float64[3, 5, 2]]"
279+
== 'copy_fixed_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[3, 5, 2]"]'
279280
)
280281
assert (
281282
doc(m.reference_const_tensor)
282-
== "reference_const_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
283+
== 'reference_const_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
283284
)
284285

285-
order_flag = f"flags.{m.needed_options.lower()}_contiguous"
286+
order_flag = f'"flags.{m.needed_options.lower()}_contiguous"'
286287
assert doc(m.round_trip_view_tensor) == (
287-
f"round_trip_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}])"
288-
f" -> numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}]"
288+
f'round_trip_view_tensor(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}])'
289+
f' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}]'
289290
)
290291
assert doc(m.round_trip_const_view_tensor) == (
291-
f"round_trip_const_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], {order_flag}])"
292-
" -> numpy.ndarray[numpy.float64[?, ?, ?]]"
292+
f'round_trip_const_view_tensor(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", {order_flag}])'
293+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
293294
)
295+
296+
297+
@pytest.mark.parametrize("m", submodules)
298+
def test_arraylike_signature(m, doc):
299+
order_flag = f'"flags.{m.needed_options.lower()}_contiguous"'
300+
assert doc(m.round_trip_tensor) == (
301+
'round_trip_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, "[?, ?, ?]"])'
302+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
303+
)
304+
assert doc(m.round_trip_tensor_noconvert) == (
305+
'round_trip_tensor_noconvert(tensor: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"])'
306+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
307+
)
308+
assert doc(m.round_trip_view_tensor) == (
309+
f'round_trip_view_tensor(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}])'
310+
f' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}]'
311+
)
312+
m.round_trip_tensor(tensor_ref.tolist())
313+
with pytest.raises(TypeError, match="incompatible function arguments"):
314+
m.round_trip_tensor_noconvert(tensor_ref.tolist())
315+
with pytest.raises(TypeError, match="incompatible function arguments"):
316+
m.round_trip_view_tensor(tensor_ref.tolist())

tests/test_numpy_array.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -586,4 +586,13 @@ TEST_SUBMODULE(numpy_array, sm) {
586586
sm.def("return_array_pyobject_ptr_from_list", return_array_from_list<PyObject *>);
587587
sm.def("return_array_handle_from_list", return_array_from_list<py::handle>);
588588
sm.def("return_array_object_from_list", return_array_from_list<py::object>);
589+
590+
sm.def(
591+
"round_trip_array_t",
592+
[](const py::array_t<float> &x) -> py::array_t<float> { return x; },
593+
py::arg("x"));
594+
sm.def(
595+
"round_trip_array_t_noconvert",
596+
[](const py::array_t<float> &x) -> py::array_t<float> { return x; },
597+
py::arg("x").noconvert());
589598
}

tests/test_numpy_array.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,13 @@ def test_overload_resolution(msg):
321321
msg(excinfo.value)
322322
== """
323323
overloaded(): incompatible function arguments. The following argument types are supported:
324-
1. (arg0: numpy.ndarray[numpy.float64]) -> str
325-
2. (arg0: numpy.ndarray[numpy.float32]) -> str
326-
3. (arg0: numpy.ndarray[numpy.int32]) -> str
327-
4. (arg0: numpy.ndarray[numpy.uint16]) -> str
328-
5. (arg0: numpy.ndarray[numpy.int64]) -> str
329-
6. (arg0: numpy.ndarray[numpy.complex128]) -> str
330-
7. (arg0: numpy.ndarray[numpy.complex64]) -> str
324+
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]) -> str
325+
2. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32]) -> str
326+
3. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int32]) -> str
327+
4. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.uint16]) -> str
328+
5. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int64]) -> str
329+
6. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex128]) -> str
330+
7. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex64]) -> str
331331
332332
Invoked with: 'not an array'
333333
"""
@@ -343,8 +343,8 @@ def test_overload_resolution(msg):
343343
assert m.overloaded3(np.array([1], dtype="intc")) == "int"
344344
expected_exc = """
345345
overloaded3(): incompatible function arguments. The following argument types are supported:
346-
1. (arg0: numpy.ndarray[numpy.int32]) -> str
347-
2. (arg0: numpy.ndarray[numpy.float64]) -> str
346+
1. (arg0: numpy.typing.NDArray[numpy.int32]) -> str
347+
2. (arg0: numpy.typing.NDArray[numpy.float64]) -> str
348348
349349
Invoked with: """
350350

@@ -528,7 +528,7 @@ def test_index_using_ellipsis():
528528
],
529529
)
530530
def test_format_descriptors_for_floating_point_types(test_func):
531-
assert "numpy.ndarray[numpy.float" in test_func.__doc__
531+
assert "numpy.typing.ArrayLike, numpy.float" in test_func.__doc__
532532

533533

534534
@pytest.mark.parametrize("forcecast", [False, True])
@@ -687,3 +687,17 @@ def test_return_array_object_cpp_loop(return_array, unwrap):
687687
assert isinstance(arr_from_list, np.ndarray)
688688
assert arr_from_list.dtype == np.dtype("O")
689689
assert unwrap(arr_from_list) == [6, "seven", -8.0]
690+
691+
692+
def test_arraylike_signature(doc):
693+
assert (
694+
doc(m.round_trip_array_t)
695+
== "round_trip_array_t(x: typing.Annotated[numpy.typing.ArrayLike, numpy.float32]) -> numpy.typing.NDArray[numpy.float32]"
696+
)
697+
assert (
698+
doc(m.round_trip_array_t_noconvert)
699+
== "round_trip_array_t_noconvert(x: numpy.typing.NDArray[numpy.float32]) -> numpy.typing.NDArray[numpy.float32]"
700+
)
701+
m.round_trip_array_t([1, 2, 3])
702+
with pytest.raises(TypeError, match="incompatible function arguments"):
703+
m.round_trip_array_t_noconvert([1, 2, 3])

tests/test_numpy_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def test_complex_array():
373373
def test_signature(doc):
374374
assert (
375375
doc(m.create_rec_nested)
376-
== "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
376+
== "create_rec_nested(arg0: int) -> numpy.typing.NDArray[NestedStruct]"
377377
)
378378

379379

0 commit comments

Comments
 (0)