Skip to content

Commit

Permalink
Make PyDateTime_IMPORT FFI wrapper thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Oct 15, 2024
1 parent 29c6f4b commit 76f547e
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions pyo3-ffi/src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::cell::UnsafeCell;
use std::os::raw::c_char;
use std::os::raw::c_int;
use std::ptr;
use std::sync::Once;
#[cfg(not(PyPy))]
use {crate::PyCapsule_Import, std::ffi::CString};
#[cfg(not(any(PyPy, GraalPy)))]
Expand Down Expand Up @@ -600,7 +601,7 @@ unsafe impl Sync for PyDateTime_CAPI {}
/// `PyDateTime_IMPORT` is called
#[inline]
pub unsafe fn PyDateTimeAPI() -> *mut PyDateTime_CAPI {
*PyDateTimeAPI_impl.0.get()
*PyDateTimeAPI_impl.ptr.get()
}

#[inline]
Expand All @@ -610,20 +611,27 @@ pub unsafe fn PyDateTime_TimeZone_UTC() -> *mut PyObject {

/// Populates the `PyDateTimeAPI` object
pub unsafe fn PyDateTime_IMPORT() {
// PyPy expects the C-API to be initialized via PyDateTime_Import, so trying to use
// `PyCapsule_Import` will behave unexpectedly in pypy.
#[cfg(PyPy)]
let py_datetime_c_api = PyDateTime_Import();

#[cfg(not(PyPy))]
let py_datetime_c_api = {
// PyDateTime_CAPSULE_NAME is a macro in C
let PyDateTime_CAPSULE_NAME = CString::new("datetime.datetime_CAPI").unwrap();

PyCapsule_Import(PyDateTime_CAPSULE_NAME.as_ptr(), 1) as *mut PyDateTime_CAPI
};

*PyDateTimeAPI_impl.0.get() = py_datetime_c_api;
if !PyDateTimeAPI_impl.once.is_completed() {
// PyPy expects the C-API to be initialized via PyDateTime_Import, so trying to use
// `PyCapsule_Import` will behave unexpectedly in pypy.
#[cfg(PyPy)]
let py_datetime_c_api = PyDateTime_Import();

#[cfg(not(PyPy))]
let py_datetime_c_api = {
// PyDateTime_CAPSULE_NAME is a macro in C
let PyDateTime_CAPSULE_NAME = CString::new("datetime.datetime_CAPI").unwrap();

PyCapsule_Import(PyDateTime_CAPSULE_NAME.as_ptr(), 1) as *mut PyDateTime_CAPI
};

// Protect against race conditions when the datetime API is concurrently
// initialized in multiple threads. UnsafeCell.get() cannot panic so this
// won't panic either.
PyDateTimeAPI_impl.once.call_once(|| {
*PyDateTimeAPI_impl.ptr.get() = py_datetime_c_api;
});
}
}

// skipped non-limited PyDateTime_TimeZone_UTC
Expand Down Expand Up @@ -739,8 +747,13 @@ extern "C" {

// Rust specific implementation details

struct PyDateTimeAPISingleton(UnsafeCell<*mut PyDateTime_CAPI>);
struct PyDateTimeAPISingleton {
once: Once,
ptr: UnsafeCell<*mut PyDateTime_CAPI>,
}
unsafe impl Sync for PyDateTimeAPISingleton {}

static PyDateTimeAPI_impl: PyDateTimeAPISingleton =
PyDateTimeAPISingleton(UnsafeCell::new(ptr::null_mut()));
static PyDateTimeAPI_impl: PyDateTimeAPISingleton = PyDateTimeAPISingleton {
once: Once::new(),
ptr: UnsafeCell::new(ptr::null_mut()),
};

0 comments on commit 76f547e

Please sign in to comment.