Skip to content

Commit 76f547e

Browse files
committed
Make PyDateTime_IMPORT FFI wrapper thread-safe
1 parent 29c6f4b commit 76f547e

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

pyo3-ffi/src/datetime.rs

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::cell::UnsafeCell;
1212
use std::os::raw::c_char;
1313
use std::os::raw::c_int;
1414
use std::ptr;
15+
use std::sync::Once;
1516
#[cfg(not(PyPy))]
1617
use {crate::PyCapsule_Import, std::ffi::CString};
1718
#[cfg(not(any(PyPy, GraalPy)))]
@@ -600,7 +601,7 @@ unsafe impl Sync for PyDateTime_CAPI {}
600601
/// `PyDateTime_IMPORT` is called
601602
#[inline]
602603
pub unsafe fn PyDateTimeAPI() -> *mut PyDateTime_CAPI {
603-
*PyDateTimeAPI_impl.0.get()
604+
*PyDateTimeAPI_impl.ptr.get()
604605
}
605606

606607
#[inline]
@@ -610,20 +611,27 @@ pub unsafe fn PyDateTime_TimeZone_UTC() -> *mut PyObject {
610611

611612
/// Populates the `PyDateTimeAPI` object
612613
pub unsafe fn PyDateTime_IMPORT() {
613-
// PyPy expects the C-API to be initialized via PyDateTime_Import, so trying to use
614-
// `PyCapsule_Import` will behave unexpectedly in pypy.
615-
#[cfg(PyPy)]
616-
let py_datetime_c_api = PyDateTime_Import();
617-
618-
#[cfg(not(PyPy))]
619-
let py_datetime_c_api = {
620-
// PyDateTime_CAPSULE_NAME is a macro in C
621-
let PyDateTime_CAPSULE_NAME = CString::new("datetime.datetime_CAPI").unwrap();
622-
623-
PyCapsule_Import(PyDateTime_CAPSULE_NAME.as_ptr(), 1) as *mut PyDateTime_CAPI
624-
};
625-
626-
*PyDateTimeAPI_impl.0.get() = py_datetime_c_api;
614+
if !PyDateTimeAPI_impl.once.is_completed() {
615+
// PyPy expects the C-API to be initialized via PyDateTime_Import, so trying to use
616+
// `PyCapsule_Import` will behave unexpectedly in pypy.
617+
#[cfg(PyPy)]
618+
let py_datetime_c_api = PyDateTime_Import();
619+
620+
#[cfg(not(PyPy))]
621+
let py_datetime_c_api = {
622+
// PyDateTime_CAPSULE_NAME is a macro in C
623+
let PyDateTime_CAPSULE_NAME = CString::new("datetime.datetime_CAPI").unwrap();
624+
625+
PyCapsule_Import(PyDateTime_CAPSULE_NAME.as_ptr(), 1) as *mut PyDateTime_CAPI
626+
};
627+
628+
// Protect against race conditions when the datetime API is concurrently
629+
// initialized in multiple threads. UnsafeCell.get() cannot panic so this
630+
// won't panic either.
631+
PyDateTimeAPI_impl.once.call_once(|| {
632+
*PyDateTimeAPI_impl.ptr.get() = py_datetime_c_api;
633+
});
634+
}
627635
}
628636

629637
// skipped non-limited PyDateTime_TimeZone_UTC
@@ -739,8 +747,13 @@ extern "C" {
739747

740748
// Rust specific implementation details
741749

742-
struct PyDateTimeAPISingleton(UnsafeCell<*mut PyDateTime_CAPI>);
750+
struct PyDateTimeAPISingleton {
751+
once: Once,
752+
ptr: UnsafeCell<*mut PyDateTime_CAPI>,
753+
}
743754
unsafe impl Sync for PyDateTimeAPISingleton {}
744755

745-
static PyDateTimeAPI_impl: PyDateTimeAPISingleton =
746-
PyDateTimeAPISingleton(UnsafeCell::new(ptr::null_mut()));
756+
static PyDateTimeAPI_impl: PyDateTimeAPISingleton = PyDateTimeAPISingleton {
757+
once: Once::new(),
758+
ptr: UnsafeCell::new(ptr::null_mut()),
759+
};

0 commit comments

Comments
 (0)