Skip to content

Commit 6771709

Browse files
committed
Added usage of numpy.typing with typing.Annotated
1 parent 924261e commit 6771709

8 files changed

+67
-58
lines changed

include/pybind11/eigen/matrix.h

+11-8
Original file line numberDiff line numberDiff line change
@@ -225,19 +225,22 @@ struct EigenProps {
225225
= !show_c_contiguous && show_order && requires_col_major;
226226

227227
static constexpr auto descriptor
228-
= const_name("numpy.ndarray[") + npy_format_descriptor<Scalar>::name + const_name("[")
228+
= const_name("typing.Annotated[")
229+
+ io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
230+
+ npy_format_descriptor<Scalar>::name + io_name("", "]") + const_name(", \"[")
229231
+ const_name<fixed_rows>(const_name<(size_t) rows>(), const_name("m")) + const_name(", ")
230-
+ const_name<fixed_cols>(const_name<(size_t) cols>(), const_name("n")) + const_name("]")
231-
+
232+
+ const_name<fixed_cols>(const_name<(size_t) cols>(), const_name("n"))
233+
+ const_name("]\"") +
232234
// For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to
233235
// be satisfied: writeable=True (for a mutable reference), and, depending on the map's
234236
// stride options, possibly f_contiguous or c_contiguous. We include them in the
235237
// descriptor output to provide some hint as to why a TypeError is occurring (otherwise
236-
// it can be confusing to see that a function accepts a 'numpy.ndarray[float64[3,2]]' and
237-
// an error message that you *gave* a numpy.ndarray of the right type and dimensions.
238-
const_name<show_writeable>(", flags.writeable", "")
239-
+ const_name<show_c_contiguous>(", flags.c_contiguous", "")
240-
+ const_name<show_f_contiguous>(", flags.f_contiguous", "") + const_name("]");
238+
// it can be confusing to see that a function accepts a
239+
// 'typing.Annotated[numpy.typing.NDArray[numpy.float64], "[3,2]"]' and an error message
240+
// that you *gave* a numpy.ndarray of the right type and dimensions.
241+
const_name<show_writeable>(", \"flags.writeable\"", "")
242+
+ const_name<show_c_contiguous>(", \"flags.c_contiguous\"", "")
243+
+ const_name<show_f_contiguous>(", \"flags.f_contiguous\"", "") + const_name("]");
241244
};
242245

243246
// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,

include/pybind11/eigen/tensor.h

+8-5
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,16 @@ struct eigen_tensor_helper<
124124
template <typename Type, bool ShowDetails, bool NeedsWriteable = false>
125125
struct get_tensor_descriptor {
126126
static constexpr auto details
127-
= const_name<NeedsWriteable>(", flags.writeable", "") + const_name
127+
= const_name<NeedsWriteable>(", \"flags.writeable\"", "") + const_name
128128
< static_cast<int>(Type::Layout)
129-
== static_cast<int>(Eigen::RowMajor) > (", flags.c_contiguous", ", flags.f_contiguous");
129+
== static_cast<int>(Eigen::RowMajor)
130+
> (", \"flags.c_contiguous\"", ", \"flags.f_contiguous\"");
130131
static constexpr auto value
131-
= const_name("numpy.ndarray[") + npy_format_descriptor<typename Type::Scalar>::name
132-
+ const_name("[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
133-
+ const_name("]") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
132+
= const_name("typing.Annotated[")
133+
+ io_name("numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
134+
+ npy_format_descriptor<typename Type::Scalar>::name + io_name("", "]")
135+
+ const_name(", \"[") + eigen_tensor_helper<remove_cv_t<Type>>::dimensions_descriptor
136+
+ const_name("]\"") + const_name<ShowDetails>(details, const_name("")) + const_name("]");
134137
};
135138

136139
// When EIGEN_AVOID_STL_ARRAY is defined, Eigen::DSizes<T, 0> does not have the begin() member

include/pybind11/numpy.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -2183,7 +2183,8 @@ vectorize_helper<Func, Return, Args...> vectorize_extractor(const Func &f, Retur
21832183
template <typename T, int Flags>
21842184
struct handle_type_name<array_t<T, Flags>> {
21852185
static constexpr auto name
2186-
= const_name("numpy.ndarray[") + npy_format_descriptor<T>::name + const_name("]");
2186+
= io_name("typing.Annotated[numpy.typing.ArrayLike, ", "numpy.typing.NDArray[")
2187+
+ npy_format_descriptor<T>::name + const_name("]");
21872188
};
21882189

21892190
PYBIND11_NAMESPACE_END(detail)

tests/test_eigen_matrix.py

+21-20
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,20 @@ def test_mutator_descriptors():
9595
with pytest.raises(TypeError) as excinfo:
9696
m.fixed_mutator_r(zc)
9797
assert (
98-
"(arg0: numpy.ndarray[numpy.float32[5, 6],"
99-
" flags.writeable, flags.c_contiguous]) -> None" in str(excinfo.value)
98+
'(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]",'
99+
' "flags.writeable", "flags.c_contiguous"]) -> None' in str(excinfo.value)
100100
)
101101
with pytest.raises(TypeError) as excinfo:
102102
m.fixed_mutator_c(zr)
103103
assert (
104-
"(arg0: numpy.ndarray[numpy.float32[5, 6],"
105-
" flags.writeable, flags.f_contiguous]) -> None" in str(excinfo.value)
104+
'(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]",'
105+
' "flags.writeable", "flags.f_contiguous"]) -> None' in str(excinfo.value)
106106
)
107107
with pytest.raises(TypeError) as excinfo:
108108
m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype="float32"))
109-
assert "(arg0: numpy.ndarray[numpy.float32[5, 6], flags.writeable]) -> None" in str(
110-
excinfo.value
109+
assert (
110+
'(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]", "flags.writeable"]) -> None'
111+
in str(excinfo.value)
111112
)
112113
zr.flags.writeable = False
113114
with pytest.raises(TypeError):
@@ -201,7 +202,7 @@ def test_negative_stride_from_python(msg):
201202
msg(excinfo.value)
202203
== """
203204
double_threer(): incompatible function arguments. The following argument types are supported:
204-
1. (arg0: numpy.ndarray[numpy.float32[1, 3], flags.writeable]) -> None
205+
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[1, 3]", "flags.writeable"]) -> None
205206
206207
Invoked with: """
207208
+ repr(np.array([5.0, 4.0, 3.0], dtype="float32"))
@@ -213,7 +214,7 @@ def test_negative_stride_from_python(msg):
213214
msg(excinfo.value)
214215
== """
215216
double_threec(): incompatible function arguments. The following argument types are supported:
216-
1. (arg0: numpy.ndarray[numpy.float32[3, 1], flags.writeable]) -> None
217+
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[3, 1]", "flags.writeable"]) -> None
217218
218219
Invoked with: """
219220
+ repr(np.array([7.0, 4.0, 1.0], dtype="float32"))
@@ -634,37 +635,37 @@ def test_nocopy_wrapper():
634635
with pytest.raises(TypeError) as excinfo:
635636
m.get_elem_nocopy(int_matrix_colmajor)
636637
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
637-
assert ", flags.f_contiguous" in str(excinfo.value)
638+
assert ', "flags.f_contiguous"' in str(excinfo.value)
638639
assert m.get_elem_nocopy(dbl_matrix_colmajor) == 8
639640
with pytest.raises(TypeError) as excinfo:
640641
m.get_elem_nocopy(int_matrix_rowmajor)
641642
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
642-
assert ", flags.f_contiguous" in str(excinfo.value)
643+
assert ', "flags.f_contiguous"' in str(excinfo.value)
643644
with pytest.raises(TypeError) as excinfo:
644645
m.get_elem_nocopy(dbl_matrix_rowmajor)
645646
assert "get_elem_nocopy(): incompatible function arguments." in str(excinfo.value)
646-
assert ", flags.f_contiguous" in str(excinfo.value)
647+
assert ', "flags.f_contiguous"' in str(excinfo.value)
647648

648649
# For the row-major test, we take a long matrix in row-major, so only the third is allowed:
649650
with pytest.raises(TypeError) as excinfo:
650651
m.get_elem_rm_nocopy(int_matrix_colmajor)
651652
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
652653
excinfo.value
653654
)
654-
assert ", flags.c_contiguous" in str(excinfo.value)
655+
assert ', "flags.c_contiguous"' in str(excinfo.value)
655656
with pytest.raises(TypeError) as excinfo:
656657
m.get_elem_rm_nocopy(dbl_matrix_colmajor)
657658
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
658659
excinfo.value
659660
)
660-
assert ", flags.c_contiguous" in str(excinfo.value)
661+
assert ', "flags.c_contiguous"' in str(excinfo.value)
661662
assert m.get_elem_rm_nocopy(int_matrix_rowmajor) == 8
662663
with pytest.raises(TypeError) as excinfo:
663664
m.get_elem_rm_nocopy(dbl_matrix_rowmajor)
664665
assert "get_elem_rm_nocopy(): incompatible function arguments." in str(
665666
excinfo.value
666667
)
667-
assert ", flags.c_contiguous" in str(excinfo.value)
668+
assert ', "flags.c_contiguous"' in str(excinfo.value)
668669

669670

670671
def test_eigen_ref_life_support():
@@ -700,25 +701,25 @@ def test_dense_signature(doc):
700701
assert (
701702
doc(m.double_col)
702703
== """
703-
double_col(arg0: numpy.ndarray[numpy.float32[m, 1]]) -> numpy.ndarray[numpy.float32[m, 1]]
704+
double_col(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, 1]"]) -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, 1]"]
704705
"""
705706
)
706707
assert (
707708
doc(m.double_row)
708709
== """
709-
double_row(arg0: numpy.ndarray[numpy.float32[1, n]]) -> numpy.ndarray[numpy.float32[1, n]]
710+
double_row(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[1, n]"]) -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[1, n]"]
710711
"""
711712
)
712713
assert doc(m.double_complex) == (
713714
"""
714-
double_complex(arg0: numpy.ndarray[numpy.complex64[m, 1]])"""
715-
""" -> numpy.ndarray[numpy.complex64[m, 1]]
715+
double_complex(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex64, "[m, 1]"])"""
716+
""" -> typing.Annotated[numpy.typing.NDArray[numpy.complex64], "[m, 1]"]
716717
"""
717718
)
718719
assert doc(m.double_mat_rm) == (
719720
"""
720-
double_mat_rm(arg0: numpy.ndarray[numpy.float32[m, n]])"""
721-
""" -> numpy.ndarray[numpy.float32[m, n]]
721+
double_mat_rm(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"])"""
722+
""" -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]
722723
"""
723724
)
724725

tests/test_eigen_tensor.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -271,23 +271,24 @@ def test_round_trip_references_actually_refer(m):
271271
@pytest.mark.parametrize("m", submodules)
272272
def test_doc_string(m, doc):
273273
assert (
274-
doc(m.copy_tensor) == "copy_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
274+
doc(m.copy_tensor)
275+
== 'copy_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
275276
)
276277
assert (
277278
doc(m.copy_fixed_tensor)
278-
== "copy_fixed_tensor() -> numpy.ndarray[numpy.float64[3, 5, 2]]"
279+
== 'copy_fixed_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[3, 5, 2]"]'
279280
)
280281
assert (
281282
doc(m.reference_const_tensor)
282-
== "reference_const_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
283+
== 'reference_const_tensor() -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
283284
)
284285

285-
order_flag = f"flags.{m.needed_options.lower()}_contiguous"
286+
order_flag = f'"flags.{m.needed_options.lower()}_contiguous"'
286287
assert doc(m.round_trip_view_tensor) == (
287-
f"round_trip_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}])"
288-
f" -> numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}]"
288+
f'round_trip_view_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, "[?, ?, ?]", "flags.writeable", {order_flag}])'
289+
f' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]", "flags.writeable", {order_flag}]'
289290
)
290291
assert doc(m.round_trip_const_view_tensor) == (
291-
f"round_trip_const_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], {order_flag}])"
292-
" -> numpy.ndarray[numpy.float64[?, ?, ?]]"
292+
f'round_trip_const_view_tensor(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64, "[?, ?, ?]", {order_flag}])'
293+
' -> typing.Annotated[numpy.typing.NDArray[numpy.float64], "[?, ?, ?]"]'
293294
)

tests/test_numpy_array.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,13 @@ def test_overload_resolution(msg):
321321
msg(excinfo.value)
322322
== """
323323
overloaded(): incompatible function arguments. The following argument types are supported:
324-
1. (arg0: numpy.ndarray[numpy.float64]) -> str
325-
2. (arg0: numpy.ndarray[numpy.float32]) -> str
326-
3. (arg0: numpy.ndarray[numpy.int32]) -> str
327-
4. (arg0: numpy.ndarray[numpy.uint16]) -> str
328-
5. (arg0: numpy.ndarray[numpy.int64]) -> str
329-
6. (arg0: numpy.ndarray[numpy.complex128]) -> str
330-
7. (arg0: numpy.ndarray[numpy.complex64]) -> str
324+
1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]) -> str
325+
2. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32]) -> str
326+
3. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int32]) -> str
327+
4. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.uint16]) -> str
328+
5. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int64]) -> str
329+
6. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex128]) -> str
330+
7. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.complex64]) -> str
331331
332332
Invoked with: 'not an array'
333333
"""
@@ -343,8 +343,8 @@ def test_overload_resolution(msg):
343343
assert m.overloaded3(np.array([1], dtype="intc")) == "int"
344344
expected_exc = """
345345
overloaded3(): incompatible function arguments. The following argument types are supported:
346-
1. (arg0: numpy.ndarray[numpy.int32]) -> str
347-
2. (arg0: numpy.ndarray[numpy.float64]) -> str
346+
1. (arg0: numpy.typing.NDArray[numpy.int32]) -> str
347+
2. (arg0: numpy.typing.NDArray[numpy.float64]) -> str
348348
349349
Invoked with: """
350350

@@ -528,7 +528,7 @@ def test_index_using_ellipsis():
528528
],
529529
)
530530
def test_format_descriptors_for_floating_point_types(test_func):
531-
assert "numpy.ndarray[numpy.float" in test_func.__doc__
531+
assert "typing.Annotated[numpy.typing.ArrayLike, numpy.float" in test_func.__doc__
532532

533533

534534
@pytest.mark.parametrize("forcecast", [False, True])

tests/test_numpy_dtypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def test_complex_array():
373373
def test_signature(doc):
374374
assert (
375375
doc(m.create_rec_nested)
376-
== "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]"
376+
== "create_rec_nested(arg0: int) -> numpy.typing.NDArray[NestedStruct]"
377377
)
378378

379379

tests/test_numpy_vectorize.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_docs(doc):
150150
assert (
151151
doc(m.vectorized_func)
152152
== """
153-
vectorized_func(arg0: numpy.ndarray[numpy.int32], arg1: numpy.ndarray[numpy.float32], arg2: numpy.ndarray[numpy.float64]) -> object
153+
vectorized_func(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.int32], arg1: typing.Annotated[numpy.typing.ArrayLike, numpy.float32], arg2: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]) -> object
154154
"""
155155
)
156156

@@ -212,12 +212,12 @@ def test_passthrough_arguments(doc):
212212
+ ", ".join(
213213
[
214214
"arg0: float",
215-
"arg1: numpy.ndarray[numpy.float64]",
216-
"arg2: numpy.ndarray[numpy.float64]",
217-
"arg3: numpy.ndarray[numpy.int32]",
215+
"arg1: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]",
216+
"arg2: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]",
217+
"arg3: typing.Annotated[numpy.typing.ArrayLike, numpy.int32]",
218218
"arg4: int",
219219
"arg5: m.numpy_vectorize.NonPODClass",
220-
"arg6: numpy.ndarray[numpy.float64]",
220+
"arg6: typing.Annotated[numpy.typing.ArrayLike, numpy.float64]",
221221
]
222222
)
223223
+ ") -> object"

0 commit comments

Comments
 (0)