Skip to content

Commit

Permalink
Make PyDateTime_IMPORT FFI wrapper thread-safe (#4623)
Browse files Browse the repository at this point in the history
* Make PyDateTime_IMPORT FFI wrapper thread-safe

* add changelog entry

* add error checking for PyCapsule_Import call
  • Loading branch information
ngoldbaum authored Oct 20, 2024
1 parent 7c39f1c commit 7909eb6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
1 change: 1 addition & 0 deletions newsfragments/4623.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* The FFI wrapper for the PyDateTime_IMPORT macro is now thread-safe.
43 changes: 30 additions & 13 deletions pyo3-ffi/src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{PyObject, PyObject_TypeCheck, PyTypeObject, Py_TYPE};
use std::os::raw::c_char;
use std::os::raw::c_int;
use std::ptr;
use std::sync::Once;
use std::{cell::UnsafeCell, ffi::CStr};
#[cfg(not(any(PyPy, GraalPy)))]
use {crate::Py_hash_t, std::os::raw::c_uchar};
Expand Down Expand Up @@ -602,21 +603,32 @@ pub const PyDateTime_CAPSULE_NAME: &CStr = c_str!("datetime.datetime_CAPI");
/// `PyDateTime_IMPORT` is called
#[inline]
pub unsafe fn PyDateTimeAPI() -> *mut PyDateTime_CAPI {
*PyDateTimeAPI_impl.0.get()
*PyDateTimeAPI_impl.ptr.get()
}

/// Populates the `PyDateTimeAPI` object
pub unsafe fn PyDateTime_IMPORT() {
// PyPy expects the C-API to be initialized via PyDateTime_Import, so trying to use
// `PyCapsule_Import` will behave unexpectedly in pypy.
#[cfg(PyPy)]
let py_datetime_c_api = PyDateTime_Import();

#[cfg(not(PyPy))]
let py_datetime_c_api =
PyCapsule_Import(PyDateTime_CAPSULE_NAME.as_ptr(), 1) as *mut PyDateTime_CAPI;
if !PyDateTimeAPI_impl.once.is_completed() {
// PyPy expects the C-API to be initialized via PyDateTime_Import, so trying to use
// `PyCapsule_Import` will behave unexpectedly in pypy.
#[cfg(PyPy)]
let py_datetime_c_api = PyDateTime_Import();

#[cfg(not(PyPy))]
let py_datetime_c_api =
PyCapsule_Import(PyDateTime_CAPSULE_NAME.as_ptr(), 1) as *mut PyDateTime_CAPI;

if py_datetime_c_api.is_null() {
return;
}

*PyDateTimeAPI_impl.0.get() = py_datetime_c_api;
// Protect against race conditions when the datetime API is concurrently
// initialized in multiple threads. UnsafeCell.get() cannot panic so this
// won't panic either.
PyDateTimeAPI_impl.once.call_once(|| {
*PyDateTimeAPI_impl.ptr.get() = py_datetime_c_api;
});
}
}

#[inline]
Expand Down Expand Up @@ -735,8 +747,13 @@ extern "C" {

// Rust specific implementation details

struct PyDateTimeAPISingleton(UnsafeCell<*mut PyDateTime_CAPI>);
struct PyDateTimeAPISingleton {
once: Once,
ptr: UnsafeCell<*mut PyDateTime_CAPI>,
}
unsafe impl Sync for PyDateTimeAPISingleton {}

static PyDateTimeAPI_impl: PyDateTimeAPISingleton =
PyDateTimeAPISingleton(UnsafeCell::new(ptr::null_mut()));
static PyDateTimeAPI_impl: PyDateTimeAPISingleton = PyDateTimeAPISingleton {
once: Once::new(),
ptr: UnsafeCell::new(ptr::null_mut()),
};

0 comments on commit 7909eb6

Please sign in to comment.