Skip to content

Commit 705efcc

Browse files
sebergrwgkhenryiii
authored
feat: make numpy.h compatible with both NumPy 1.x and 2.x (#5050)
* API: Make `numpy.h` compatible with both NumPy 1.x and 2.x * TST: Update numpy dtype flags test to not covert flags to char * API: Add `numpy2.h` instead and make `numpy.h` safe This means that users of `numpy.h` cannot be broken, but need to update to `numpy2.h` if they want to compile for NumPy 2. Using Macros simply and didn't bother to try to remove unnecessary code paths. * API: Rather than `numpy2.h` use a define for the user. * Thread `PYBIND11_NUMPY2_SUPPORT` through things and try to adept test matrix * Small fixups (shouldn't matter)? * Fixup. Does upgrading scipy help? (it shouldn't?) (Some other small fixup) * Use NumPy 2 nightlies for ubuntu-latest job also * BUG: Fix numpy.bool check * TST: Fix complexwarning * BUG: Fix the fact that only the 50 slot is filled with the copy alias (There were 3 functions all doing the same, only this slot survived 2.x) * TST: One more test tweak * TST: Use "long" name for long, since it changed on windows * TST: Apparently we didn't always have ulong, so just use `L` * TST: Enforce dtype='l' for test as default isn't long anymore on windows * Rename macro and invert logic to PYBIND11_NUMPY_1_ONLY * PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED * Test and code comment expansion * CI: Use pre-releases of numpy/scipy from pip via explicit version * CI: NumPy 2 only available on almalinux (as it is Python >=3.9) * MAINT: Match name more exactly and adopt error phrasing * MAINT: Pushed early, move helper to be private member * fix error message compilation when using NumPy 1.x-only backcompat * silence name shadowing warning * chore: minor optimization Signed-off-by: Henry Schreiner <[email protected]> --------- Signed-off-by: Henry Schreiner <[email protected]> Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]> Co-authored-by: Henry Schreiner <[email protected]>
1 parent e0f2c71 commit 705efcc

File tree

11 files changed

+206
-21
lines changed

11 files changed

+206
-21
lines changed

.github/workflows/ci.yml

+12-1
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,14 @@ jobs:
108108
run: python -m pip install pytest-github-actions-annotate-failures
109109

110110
# First build - C++11 mode and inplace
111-
# More-or-less randomly adding -DPYBIND11_SIMPLE_GIL_MANAGEMENT=ON here.
111+
# More-or-less randomly adding -DPYBIND11_SIMPLE_GIL_MANAGEMENT=ON here
112+
# (same for PYBIND11_NUMPY_1_ONLY, but requires a NumPy 1.x at runtime).
112113
- name: Configure C++11 ${{ matrix.args }}
113114
run: >
114115
cmake -S . -B .
115116
-DPYBIND11_WERROR=ON
116117
-DPYBIND11_SIMPLE_GIL_MANAGEMENT=ON
118+
-DPYBIND11_NUMPY_1_ONLY=ON
117119
-DDOWNLOAD_CATCH=ON
118120
-DDOWNLOAD_EIGEN=ON
119121
-DCMAKE_CXX_STANDARD=11
@@ -138,11 +140,13 @@ jobs:
138140

139141
# Second build - C++17 mode and in a build directory
140142
# More-or-less randomly adding -DPYBIND11_SIMPLE_GIL_MANAGEMENT=OFF here.
143+
# (same for PYBIND11_NUMPY_1_ONLY, but requires a NumPy 1.x at runtime).
141144
- name: Configure C++17
142145
run: >
143146
cmake -S . -B build2
144147
-DPYBIND11_WERROR=ON
145148
-DPYBIND11_SIMPLE_GIL_MANAGEMENT=OFF
149+
-DPYBIND11_NUMPY_1_ONLY=ON
146150
-DDOWNLOAD_CATCH=ON
147151
-DDOWNLOAD_EIGEN=ON
148152
-DCMAKE_CXX_STANDARD=17
@@ -660,6 +664,11 @@ jobs:
660664
run: |
661665
python3 -m pip install cmake -r tests/requirements.txt
662666
667+
- name: Ensure NumPy 2 is used (required Python >= 3.9)
668+
if: matrix.container == 'almalinux:9'
669+
run: |
670+
python3 -m pip install 'numpy>=2.0.0b1' 'scipy>=1.13.0rc1'
671+
663672
- name: Configure
664673
shell: bash
665674
run: >
@@ -895,8 +904,10 @@ jobs:
895904
python-version: ${{ matrix.python }}
896905

897906
- name: Prepare env
907+
# Ensure use of NumPy 2 (via NumPy nightlies but can be changed soon)
898908
run: |
899909
python3 -m pip install -r tests/requirements.txt
910+
python3 -m pip install 'numpy>=2.0.0b1' 'scipy>=1.13.0rc1'
900911
901912
- name: Update CMake
902913
uses: jwlawson/[email protected]

CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,18 @@ option(PYBIND11_TEST "Build pybind11 test suite?" ${PYBIND11_MASTER_PROJECT})
109109
option(PYBIND11_NOPYTHON "Disable search for Python" OFF)
110110
option(PYBIND11_SIMPLE_GIL_MANAGEMENT
111111
"Use simpler GIL management logic that does not support disassociation" OFF)
112+
option(PYBIND11_NUMPY_1_ONLY
113+
"Disable NumPy 2 support to avoid changes to previous pybind11 versions." OFF)
112114
set(PYBIND11_INTERNALS_VERSION
113115
""
114116
CACHE STRING "Override the ABI version, may be used to enable the unstable ABI.")
115117

116118
if(PYBIND11_SIMPLE_GIL_MANAGEMENT)
117119
add_compile_definitions(PYBIND11_SIMPLE_GIL_MANAGEMENT)
118120
endif()
121+
if(PYBIND11_NUMPY_1_ONLY)
122+
add_compile_definitions(PYBIND11_NUMPY_1_ONLY)
123+
endif()
119124

120125
cmake_dependent_option(
121126
USE_PYTHON_INCLUDE_DIR

include/pybind11/cast.h

+12-2
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,9 @@ class type_caster<bool> {
327327
value = false;
328328
return true;
329329
}
330-
if (convert || (std::strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name) == 0)) {
331-
// (allow non-implicit conversion for numpy booleans)
330+
if (convert || is_numpy_bool(src)) {
331+
// (allow non-implicit conversion for numpy booleans), use strncmp
332+
// since NumPy 1.x had an additional trailing underscore.
332333

333334
Py_ssize_t res = -1;
334335
if (src.is_none()) {
@@ -360,6 +361,15 @@ class type_caster<bool> {
360361
return handle(src ? Py_True : Py_False).inc_ref();
361362
}
362363
PYBIND11_TYPE_CASTER(bool, const_name("bool"));
364+
365+
private:
366+
// Test if an object is a NumPy boolean (without fetching the type).
367+
static inline bool is_numpy_bool(handle object) {
368+
const char *type_name = Py_TYPE(object.ptr())->tp_name;
369+
// Name changed to `numpy.bool` in NumPy 2, `numpy.bool_` is needed for 1.x support
370+
return std::strcmp("numpy.bool", type_name) == 0
371+
|| std::strcmp("numpy.bool_", type_name) == 0;
372+
}
363373
};
364374

365375
// Helper class for UTF-{8,16,32} C++ stl strings:

include/pybind11/detail/common.h

+4
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,10 @@ PYBIND11_WARNING_DISABLE_MSVC(4505)
296296
# undef copysign
297297
#endif
298298

299+
#if defined(PYBIND11_NUMPY_1_ONLY)
300+
# define PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED
301+
#endif
302+
299303
#if defined(PYPY_VERSION) && !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
300304
# define PYBIND11_SIMPLE_GIL_MANAGEMENT
301305
#endif

include/pybind11/numpy.h

+116-9
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,15 @@
2929
#include <utility>
3030
#include <vector>
3131

32+
#if defined(PYBIND11_NUMPY_1_ONLY) && !defined(PYBIND11_INTERNAL_NUMPY_1_ONLY_DETECTED)
33+
# error PYBIND11_NUMPY_1_ONLY must be defined before any pybind11 header is included.
34+
#endif
35+
3236
/* This will be true on all flat address space platforms and allows us to reduce the
3337
whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size
3438
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
35-
upon the library user. */
39+
upon the library user.
40+
Note that NumPy 2 now uses ssize_t for `npy_intp` to simplify this. */
3641
static_assert(sizeof(::pybind11::ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t");
3742
static_assert(std::is_signed<Py_intptr_t>::value, "Py_intptr_t must be signed");
3843
// We now can reinterpret_cast between py::ssize_t and Py_intptr_t (MSVC + PyPy cares)
@@ -53,7 +58,8 @@ struct handle_type_name<array> {
5358
template <typename type, typename SFINAE = void>
5459
struct npy_format_descriptor;
5560

56-
struct PyArrayDescr_Proxy {
61+
/* NumPy 1 proxy (always includes legacy fields) */
62+
struct PyArrayDescr1_Proxy {
5763
PyObject_HEAD
5864
PyObject *typeobj;
5965
char kind;
@@ -68,6 +74,43 @@ struct PyArrayDescr_Proxy {
6874
PyObject *names;
6975
};
7076

77+
#ifndef PYBIND11_NUMPY_1_ONLY
78+
struct PyArrayDescr_Proxy {
79+
PyObject_HEAD
80+
PyObject *typeobj;
81+
char kind;
82+
char type;
83+
char byteorder;
84+
char _former_flags;
85+
int type_num;
86+
/* Additional fields are NumPy version specific. */
87+
};
88+
#else
89+
/* NumPy 1.x only, we can expose all fields */
90+
using PyArrayDescr_Proxy = PyArrayDescr1_Proxy;
91+
#endif
92+
93+
/* NumPy 2 proxy, including legacy fields */
94+
struct PyArrayDescr2_Proxy {
95+
PyObject_HEAD
96+
PyObject *typeobj;
97+
char kind;
98+
char type;
99+
char byteorder;
100+
char _former_flags;
101+
int type_num;
102+
std::uint64_t flags;
103+
ssize_t elsize;
104+
ssize_t alignment;
105+
PyObject *metadata;
106+
Py_hash_t hash;
107+
void *reserved_null[2];
108+
/* The following fields only exist if 0 <= type_num < 2056 */
109+
char *subarray;
110+
PyObject *fields;
111+
PyObject *names;
112+
};
113+
71114
struct PyArray_Proxy {
72115
PyObject_HEAD
73116
char *data;
@@ -131,6 +174,14 @@ PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name
131174
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
132175
int major_version = numpy_version.attr("major").cast<int>();
133176

177+
#ifdef PYBIND11_NUMPY_1_ONLY
178+
if (major_version >= 2) {
179+
throw std::runtime_error(
180+
"This extension was built with PYBIND11_NUMPY_1_ONLY defined, "
181+
"but NumPy 2 is used in this process. For NumPy2 compatibility, "
182+
"this extension needs to be rebuilt without the PYBIND11_NUMPY_1_ONLY define.");
183+
}
184+
#endif
134185
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
135186
became a private module. */
136187
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
@@ -203,6 +254,8 @@ struct npy_api {
203254
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
204255
};
205256

257+
unsigned int PyArray_RUNTIME_VERSION_;
258+
206259
struct PyArray_Dims {
207260
Py_intptr_t *ptr;
208261
int len;
@@ -241,6 +294,7 @@ struct npy_api {
241294
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
242295
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
243296
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
297+
#ifdef PYBIND11_NUMPY_1_ONLY
244298
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
245299
PyObject *,
246300
unsigned char,
@@ -249,6 +303,7 @@ struct npy_api {
249303
Py_intptr_t *,
250304
PyObject **,
251305
PyObject *);
306+
#endif
252307
PyObject *(*PyArray_Squeeze_)(PyObject *);
253308
// Unused. Not removed because that affects ABI of the class.
254309
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
@@ -266,7 +321,8 @@ struct npy_api {
266321
API_PyArray_DescrFromScalar = 57,
267322
API_PyArray_FromAny = 69,
268323
API_PyArray_Resize = 80,
269-
API_PyArray_CopyInto = 82,
324+
// CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
325+
API_PyArray_CopyInto = 50,
270326
API_PyArray_NewCopy = 85,
271327
API_PyArray_NewFromDescr = 94,
272328
API_PyArray_DescrNewFromType = 96,
@@ -275,7 +331,9 @@ struct npy_api {
275331
API_PyArray_View = 137,
276332
API_PyArray_DescrConverter = 174,
277333
API_PyArray_EquivTypes = 182,
334+
#ifdef PYBIND11_NUMPY_1_ONLY
278335
API_PyArray_GetArrayParamsFromObject = 278,
336+
#endif
279337
API_PyArray_SetBaseObject = 282
280338
};
281339

@@ -290,7 +348,8 @@ struct npy_api {
290348
npy_api api;
291349
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
292350
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
293-
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7) {
351+
api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
352+
if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
294353
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
295354
}
296355
DECL_NPY_API(PyArray_Type);
@@ -309,7 +368,9 @@ struct npy_api {
309368
DECL_NPY_API(PyArray_View);
310369
DECL_NPY_API(PyArray_DescrConverter);
311370
DECL_NPY_API(PyArray_EquivTypes);
371+
#ifdef PYBIND11_NUMPY_1_ONLY
312372
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
373+
#endif
313374
DECL_NPY_API(PyArray_SetBaseObject);
314375

315376
#undef DECL_NPY_API
@@ -331,6 +392,14 @@ inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
331392
return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
332393
}
333394

395+
inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
396+
return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
397+
}
398+
399+
inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
400+
return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
401+
}
402+
334403
inline bool check_flags(const void *ptr, int flag) {
335404
return (flag == (array_proxy(ptr)->flags & flag));
336405
}
@@ -610,10 +679,32 @@ class dtype : public object {
610679
}
611680

612681
/// Size of the data type in bytes.
682+
#ifdef PYBIND11_NUMPY_1_ONLY
613683
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
684+
#else
685+
ssize_t itemsize() const {
686+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
687+
return detail::array_descriptor1_proxy(m_ptr)->elsize;
688+
}
689+
return detail::array_descriptor2_proxy(m_ptr)->elsize;
690+
}
691+
#endif
614692

615693
/// Returns true for structured data types.
694+
#ifdef PYBIND11_NUMPY_1_ONLY
616695
bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
696+
#else
697+
bool has_fields() const {
698+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
699+
return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
700+
}
701+
const auto *proxy = detail::array_descriptor2_proxy(m_ptr);
702+
if (proxy->type_num < 0 || proxy->type_num >= 2056) {
703+
return false;
704+
}
705+
return proxy->names != nullptr;
706+
}
707+
#endif
617708

618709
/// Single-character code for dtype's kind.
619710
/// For example, floating point types are 'f' and integral types are 'i'.
@@ -639,11 +730,29 @@ class dtype : public object {
639730
/// Single character for byteorder
640731
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
641732

642-
/// Alignment of the data type
733+
/// Alignment of the data type
734+
#ifdef PYBIND11_NUMPY_1_ONLY
643735
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
736+
#else
737+
ssize_t alignment() const {
738+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
739+
return detail::array_descriptor1_proxy(m_ptr)->alignment;
740+
}
741+
return detail::array_descriptor2_proxy(m_ptr)->alignment;
742+
}
743+
#endif
644744

645-
/// Flags for the array descriptor
745+
/// Flags for the array descriptor
746+
#ifdef PYBIND11_NUMPY_1_ONLY
646747
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
748+
#else
749+
std::uint64_t flags() const {
750+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
751+
return (unsigned char) detail::array_descriptor1_proxy(m_ptr)->flags;
752+
}
753+
return detail::array_descriptor2_proxy(m_ptr)->flags;
754+
}
755+
#endif
647756

648757
private:
649758
static object &_dtype_from_pep3118() {
@@ -810,9 +919,7 @@ class array : public buffer {
810919
}
811920

812921
/// Byte size of a single element
813-
ssize_t itemsize() const {
814-
return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
815-
}
922+
ssize_t itemsize() const { return dtype().itemsize(); }
816923

817924
/// Total number of bytes
818925
ssize_t nbytes() const { return size() * itemsize(); }

tests/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,5 @@ def pytest_report_header(config):
218218
f" {pybind11_tests.cpp_std}"
219219
f" {pybind11_tests.PYBIND11_INTERNALS_ID}"
220220
f" PYBIND11_SIMPLE_GIL_MANAGEMENT={pybind11_tests.PYBIND11_SIMPLE_GIL_MANAGEMENT}"
221+
f" PYBIND11_NUMPY_1_ONLY={pybind11_tests.PYBIND11_NUMPY_1_ONLY}"
221222
)

tests/pybind11_tests.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,12 @@ PYBIND11_MODULE(pybind11_tests, m) {
9595
#else
9696
false;
9797
#endif
98+
m.attr("PYBIND11_NUMPY_1_ONLY") =
99+
#if defined(PYBIND11_NUMPY_1_ONLY)
100+
true;
101+
#else
102+
false;
103+
#endif
98104

99105
bind_ConstructorStats(m);
100106

tests/test_eigen_matrix.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,9 @@ def test_both_ref_mutators():
608608
def test_nocopy_wrapper():
609609
# get_elem requires a column-contiguous matrix reference, but should be
610610
# callable with other types of matrix (via copying):
611-
int_matrix_colmajor = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], order="F")
611+
int_matrix_colmajor = np.array(
612+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="l", order="F"
613+
)
612614
dbl_matrix_colmajor = np.array(
613615
int_matrix_colmajor, dtype="double", order="F", copy=True
614616
)

0 commit comments

Comments
 (0)