Skip to content

Commit 77af339

Browse files
committed
Significant rewrite to avoid using thread_locals as much as possible.
Since we can avoid them by checking this atomic, the cmake config conditional shouldn't be necessary. The slower path (with thread_locals and extra checks) only comes in when a second interpreter is actually instanciated.
1 parent 9fbd36c commit 77af339

File tree

8 files changed

+166
-151
lines changed

8 files changed

+166
-151
lines changed

CMakeLists.txt

-4
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ option(PYBIND11_DISABLE_HANDLE_TYPE_NAME_DEFAULT_IMPLEMENTATION
9292
"To enforce that a handle_type_name<> specialization exists" OFF)
9393
option(PYBIND11_SIMPLE_GIL_MANAGEMENT
9494
"Use simpler GIL management logic that does not support disassociation" OFF)
95-
option(PYBIND11_SUBINTERPRETER_SUPPORT "Enable support for sub-interpreters" OFF)
9695
option(PYBIND11_NUMPY_1_ONLY
9796
"Disable NumPy 2 support to avoid changes to previous pybind11 versions." OFF)
9897
set(PYBIND11_INTERNALS_VERSION
@@ -106,9 +105,6 @@ endif()
106105
if(PYBIND11_SIMPLE_GIL_MANAGEMENT)
107106
add_compile_definitions(PYBIND11_SIMPLE_GIL_MANAGEMENT)
108107
endif()
109-
if(PYBIND11_SUBINTERPRETER_SUPPORT)
110-
add_compile_definitions(PYBIND11_SUBINTERPRETER_SUPPORT)
111-
endif()
112108
if(PYBIND11_NUMPY_1_ONLY)
113109
add_compile_definitions(PYBIND11_NUMPY_1_ONLY)
114110
endif()

include/pybind11/detail/common.h

+8
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,13 @@
291291
# define PYBIND11_ASSERT_GIL_HELD_INCREF_DECREF
292292
#endif
293293

294+
// Slightly faster code paths are available when this is NOT defined, so undefine it for impls
295+
// that do not have subinterpreter. Nothing breaks if this is defined but the impl does not
296+
// actually support subinterpreters.
297+
#if PY_VERSION_HEX >= 0x030C0000 && !defined(PYPY_VERSION) && !defined(GRAALVM_PYTHON)
298+
# define PYBIND11_SUBINTERPRETER_SUPPORT
299+
#endif
300+
294301
// #define PYBIND11_STR_LEGACY_PERMISSIVE
295302
// If DEFINED, pybind11::str can hold PyUnicodeObject or PyBytesObject
296303
// (probably surprising and never documented, but this was the
@@ -466,6 +473,7 @@ PYBIND11_WARNING_DISABLE_CLANG("-Wgnu-zero-variadic-macro-arguments")
466473
return m.ptr(); \
467474
} \
468475
int PYBIND11_CONCAT(pybind11_exec_, name)(PyObject * pm) { \
476+
pybind11::detail::get_interpreter_count()++; \
469477
try { \
470478
auto m = pybind11::reinterpret_borrow<::pybind11::module_>(pm); \
471479
PYBIND11_CONCAT(pybind11_init_, name)(m); \

include/pybind11/detail/internals.h

+126-106
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <pybind11/conduit/pybind11_platform_abi_id.h>
1919
#include <pybind11/pytypes.h>
2020

21+
#include <atomic>
2122
#include <exception>
2223
#include <mutex>
2324
#include <thread>
@@ -257,29 +258,46 @@ struct type_info {
257258
"__pybind11_module_local_v" PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) \
258259
PYBIND11_COMPILER_TYPE_LEADING_UNDERSCORE PYBIND11_PLATFORM_ABI_ID "__"
259260

261+
inline PyThreadState *get_thread_state_unchecked() {
262+
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)
263+
return PyThreadState_GET();
264+
#elif PY_VERSION_HEX < 0x030D0000
265+
return _PyThreadState_UncheckedGet();
266+
#else
267+
return PyThreadState_GetUnchecked();
268+
#endif
269+
}
270+
271+
/// We use this count to figure out if there are or have been multiple sub-interpreters active at
272+
/// any point. This must never decrease while any interpreter may be running in any thread!
273+
inline std::atomic<int> &get_interpreter_count() {
274+
static std::atomic<int> counter(0);
275+
return counter;
276+
}
277+
260278
/// Each module locally stores a pointer to the `internals` data. The data
261279
/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
262280
inline internals **&get_internals_pp() {
263-
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON) || PY_VERSION_HEX < 0x030C0000 \
264-
|| !defined(PYBIND11_SUBINTERPRETER_SUPPORT)
265-
static internals **internals_pp = nullptr;
266-
#else
267-
static thread_local internals **internals_pp = nullptr;
268-
// This is one per interpreter, we cache it but if the thread changed
269-
// then we need to invalidate our cache
270-
// the caller will find the right value and set it if its null
271-
static thread_local PyThreadState *tstate_cached = nullptr;
272-
# if PY_VERSION_HEX < 0x030D0000
273-
PyThreadState *tstate = _PyThreadState_UncheckedGet();
274-
# else
275-
PyThreadState *tstate = PyThreadState_GetUnchecked();
276-
# endif
277-
if (tstate != tstate_cached) {
278-
tstate_cached = tstate;
279-
internals_pp = nullptr;
281+
#ifdef PYBIND11_SUBINTERPRETER_SUPPORT
282+
if (get_interpreter_count() > 1) {
283+
// Internals is one per interpreter. When multiple interpreters are alive in different
284+
// threads we have to allow them to have different internals, so we need a thread_local.
285+
static thread_local internals **t_internals_pp = nullptr;
286+
// Whenever the interpreter changes we need to invalidate the internals_pp. That is slow,
287+
// so we only do it when the PyThreadState has changed, which indicates the interpreter
288+
// might have changed as well.
289+
static thread_local PyThreadState *tstate_cached = nullptr;
290+
auto *tstate = get_thread_state_unchecked();
291+
if (tstate != tstate_cached) {
292+
tstate_cached = tstate;
293+
// the caller will fetch the instance from the state dict or create a new one
294+
t_internals_pp = nullptr;
295+
}
296+
return t_internals_pp;
280297
}
281298
#endif
282-
return internals_pp;
299+
static internals **s_internals_pp = nullptr;
300+
return s_internals_pp;
283301
}
284302

285303
// forward decl
@@ -410,20 +428,6 @@ inline object get_python_state_dict() {
410428
return state_dict;
411429
}
412430

413-
inline object get_internals_obj_from_state_dict(handle state_dict) {
414-
return reinterpret_steal<object>(
415-
dict_getitemstringref(state_dict.ptr(), PYBIND11_INTERNALS_ID));
416-
}
417-
418-
inline internals **get_internals_pp_from_capsule(handle obj) {
419-
void *raw_ptr = PyCapsule_GetPointer(obj.ptr(), /*name=*/nullptr);
420-
if (raw_ptr == nullptr) {
421-
raise_from(PyExc_SystemError, "pybind11::detail::get_internals_pp_from_capsule() FAILED");
422-
throw error_already_set();
423-
}
424-
return static_cast<internals **>(raw_ptr);
425-
}
426-
427431
inline uint64_t round_up_to_next_pow2(uint64_t x) {
428432
// Round-up to the next power of two.
429433
// See https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
@@ -438,29 +442,51 @@ inline uint64_t round_up_to_next_pow2(uint64_t x) {
438442
return x;
439443
}
440444

445+
#if defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
446+
using internals_safe_gil_scoped_acquire = gil_scoped_acquire;
447+
#else
448+
// Cannot use py::gil_scoped_acquire inside get_internals since that calls get_internals.
449+
struct internals_safe_gil_scoped_acquire {
450+
internals_safe_gil_scoped_acquire() : state(PyGILState_Ensure()) {}
451+
internals_safe_gil_scoped_acquire(const internals_safe_gil_scoped_acquire &) = delete;
452+
internals_safe_gil_scoped_acquire &operator=(const internals_safe_gil_scoped_acquire &)
453+
= delete;
454+
internals_safe_gil_scoped_acquire(internals_safe_gil_scoped_acquire &&) = delete;
455+
internals_safe_gil_scoped_acquire &operator=(internals_safe_gil_scoped_acquire &&) = delete;
456+
~internals_safe_gil_scoped_acquire() { PyGILState_Release(state); }
457+
const PyGILState_STATE state;
458+
};
459+
#endif
460+
461+
template <typename InternalsType>
462+
inline InternalsType **find_internals_pp(char const *state_dict_key) {
463+
dict state_dict = get_python_state_dict();
464+
auto internals_obj
465+
= reinterpret_steal<object>(dict_getitemstringref(state_dict.ptr(), state_dict_key));
466+
if (internals_obj) {
467+
void *raw_ptr = PyCapsule_GetPointer(internals_obj.ptr(), /*name=*/nullptr);
468+
if (!raw_ptr) {
469+
pybind11_fail("find_or_create_internals_pp: broken capsule!");
470+
} else {
471+
return reinterpret_cast<InternalsType **>(raw_ptr);
472+
}
473+
}
474+
return nullptr;
475+
}
476+
441477
/// Return a reference to the current `internals` data
442478
PYBIND11_NOINLINE internals &get_internals() {
443-
auto **&internals_pp = get_internals_pp();
479+
auto &internals_pp = get_internals_pp();
444480
if (internals_pp && *internals_pp) {
445481
return **internals_pp;
446482
}
447483

448-
// Ensure that the GIL is held since we will need to make Python calls.
449-
// Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
450-
struct gil_scoped_acquire_local {
451-
gil_scoped_acquire_local() : state(PyGILState_Ensure()) {}
452-
gil_scoped_acquire_local(const gil_scoped_acquire_local &) = delete;
453-
gil_scoped_acquire_local &operator=(const gil_scoped_acquire_local &) = delete;
454-
~gil_scoped_acquire_local() { PyGILState_Release(state); }
455-
const PyGILState_STATE state;
456-
} gil;
484+
internals_safe_gil_scoped_acquire gil;
457485

458486
error_scope err_scope;
459487

460-
dict state_dict = get_python_state_dict();
461-
if (object internals_obj = get_internals_obj_from_state_dict(state_dict)) {
462-
internals_pp = get_internals_pp_from_capsule(internals_obj);
463-
}
488+
internals_pp = find_internals_pp<internals>(PYBIND11_INTERNALS_ID);
489+
464490
if (internals_pp && *internals_pp) {
465491
// We loaded the internals through `state_dict`, which means that our `error_already_set`
466492
// and `builtin_exception` may be different local classes than the ones set up in the
@@ -479,8 +505,11 @@ PYBIND11_NOINLINE internals &get_internals() {
479505
#endif
480506
} else {
481507
if (!internals_pp) {
482-
internals_pp = new internals *();
508+
internals_pp = new internals *(nullptr);
509+
dict state = get_python_state_dict();
510+
state[PYBIND11_INTERNALS_ID] = capsule(reinterpret_cast<void *>(internals_pp));
483511
}
512+
484513
auto *&internals_ptr = *internals_pp;
485514
internals_ptr = new internals();
486515

@@ -498,7 +527,6 @@ PYBIND11_NOINLINE internals &get_internals() {
498527
}
499528

500529
internals_ptr->istate = tstate->interp;
501-
state_dict[PYBIND11_INTERNALS_ID] = capsule(reinterpret_cast<void *>(internals_pp));
502530
internals_ptr->registered_exception_translators.push_front(&translate_exception);
503531
internals_ptr->static_property_type = make_static_property_type();
504532
internals_ptr->default_metaclass = make_default_metaclass();
@@ -528,69 +556,61 @@ struct local_internals {
528556
std::forward_list<ExceptionTranslator> registered_exception_translators;
529557
};
530558

559+
inline local_internals **&get_local_internals_pp() {
560+
#ifdef PYBIND11_SUBINTERPRETER_SUPPORT
561+
if (get_interpreter_count() > 1) {
562+
// Internals is one per interpreter. When multiple interpreters are alive in different
563+
// threads we have to allow them to have different internals, so we need a thread_local.
564+
static thread_local local_internals **t_internals_pp = nullptr;
565+
// Whenever the interpreter changes we need to invalidate the internals_pp. That is slow,
566+
// so we only do it when the PyThreadState has changed, which indicates the interpreter
567+
// might have changed as well.
568+
static thread_local PyThreadState *tstate_cached = nullptr;
569+
auto *tstate = get_thread_state_unchecked();
570+
if (tstate != tstate_cached) {
571+
tstate_cached = tstate;
572+
// the caller will fetch the instance from the state dict or create a new one
573+
t_internals_pp = nullptr;
574+
}
575+
return t_internals_pp;
576+
}
577+
#endif
578+
static local_internals **s_internals_pp = nullptr;
579+
return s_internals_pp;
580+
}
581+
582+
/// A string key uniquely describing this module
583+
inline char const *get_local_internals_id() {
584+
// Use the address of this static itself as part of the key, so that the value is uniquely tied
585+
// to where the module is loaded in memory
586+
static const std::string this_module_idstr
587+
= PYBIND11_MODULE_LOCAL_ID
588+
+ std::to_string(reinterpret_cast<uintptr_t>(&this_module_idstr));
589+
return this_module_idstr.c_str();
590+
}
591+
531592
/// Works like `get_internals`, but for things which are locally registered.
532593
inline local_internals &get_local_internals() {
533-
// Current static can be created in the interpreter finalization routine. If the later will be
534-
// destroyed in another static variable destructor, creation of this static there will cause
535-
// static deinitialization fiasco. In order to avoid it we avoid destruction of the
536-
// local_internals static. One can read more about the problem and current solution here:
537-
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
538-
539-
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON) || PY_VERSION_HEX < 0x030C0000 \
540-
|| !defined(PYBIND11_SUBINTERPRETER_SUPPORT)
541-
static auto *locals = new local_internals();
542-
#else
543-
static thread_local local_internals *locals = nullptr;
544-
// This is one per interpreter, we cache it but if the interpreter changed
545-
// then we need to invalidate our cache and re-fetch from the state dict
546-
static thread_local PyThreadState *tstate_cached = nullptr;
547-
# if PY_VERSION_HEX < 0x030D0000
548-
PyThreadState *tstate = _PyThreadState_UncheckedGet();
549-
# else
550-
PyThreadState *tstate = PyThreadState_GetUnchecked();
551-
# endif
552-
if (!tstate) {
553-
pybind11_fail(
554-
"pybind11::detail::get_local_internals() called without a current python thread");
594+
auto &local_internals_pp = get_local_internals_pp();
595+
if (local_internals_pp && *local_internals_pp) {
596+
return **local_internals_pp;
555597
}
556-
if (tstate != tstate_cached) {
557-
// we create a unique value at first run which is based on a pointer to
558-
// a (non-thread_local) static value in this function, then multiple
559-
// loaded modules using this code will still each have a unique key.
560-
static const std::string this_module_idstr
561-
= PYBIND11_MODULE_LOCAL_ID
562-
+ std::to_string(reinterpret_cast<uintptr_t>(&this_module_idstr));
563-
564-
// Ensure that the GIL is held since we will need to make Python calls.
565-
// Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
566-
struct gil_scoped_acquire_local {
567-
gil_scoped_acquire_local() : state(PyGILState_Ensure()) {}
568-
gil_scoped_acquire_local(const gil_scoped_acquire_local &) = delete;
569-
gil_scoped_acquire_local &operator=(const gil_scoped_acquire_local &) = delete;
570-
~gil_scoped_acquire_local() { PyGILState_Release(state); }
571-
const PyGILState_STATE state;
572-
} gil;
573-
574-
error_scope err_scope;
575-
dict state_dict = get_python_state_dict();
576-
object local_capsule = reinterpret_steal<object>(
577-
dict_getitemstringref(state_dict.ptr(), this_module_idstr.c_str()));
578-
if (!local_capsule) {
579-
locals = new local_internals();
580-
state_dict[this_module_idstr.c_str()] = capsule(reinterpret_cast<void *>(locals));
581-
} else {
582-
void *ptr = PyCapsule_GetPointer(local_capsule.ptr(), nullptr);
583-
if (!ptr) {
584-
raise_from(PyExc_SystemError, "pybind11::detail::get_local_internals() FAILED");
585-
throw error_already_set();
586-
}
587-
locals = reinterpret_cast<local_internals *>(ptr);
588-
}
589-
tstate_cached = tstate;
598+
599+
internals_safe_gil_scoped_acquire gil;
600+
601+
error_scope err_scope;
602+
603+
local_internals_pp = find_internals_pp<local_internals>(get_local_internals_id());
604+
if (!local_internals_pp) {
605+
local_internals_pp = new local_internals *(nullptr);
606+
dict state = get_python_state_dict();
607+
state[get_local_internals_id()] = capsule(reinterpret_cast<void *>(local_internals_pp));
608+
}
609+
if (!*local_internals_pp) {
610+
*local_internals_pp = new local_internals();
590611
}
591-
#endif
592612

593-
return *locals;
613+
return **local_internals_pp;
594614
}
595615

596616
#ifdef Py_GIL_DISABLED

include/pybind11/detail/type_caster_base.h

-10
Original file line numberDiff line numberDiff line change
@@ -497,16 +497,6 @@ PYBIND11_NOINLINE handle get_object_handle(const void *ptr, const detail::type_i
497497
});
498498
}
499499

500-
inline PyThreadState *get_thread_state_unchecked() {
501-
#if defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)
502-
return PyThreadState_GET();
503-
#elif PY_VERSION_HEX < 0x030D0000
504-
return _PyThreadState_UncheckedGet();
505-
#else
506-
return PyThreadState_GetUnchecked();
507-
#endif
508-
}
509-
510500
// Forward declarations
511501
void keep_alive_impl(handle nurse, handle patient);
512502
inline PyObject *make_new_instance(PyTypeObject *type);

0 commit comments

Comments
 (0)