From b31fd11c9c8484ad12e3cb1fa8ab5d8ea7844488 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 29 Dec 2023 16:09:30 -0500 Subject: [PATCH] add PyAnyMethods for binary operators also pow fixes #3709 --- newsfragments/33712.added.md | 1 + src/types/any.rs | 84 ++++++++++++++++++++++++++++++++++++ tests/test_arithmetics.rs | 7 +++ 3 files changed, 92 insertions(+) create mode 100644 newsfragments/33712.added.md diff --git a/newsfragments/33712.added.md b/newsfragments/33712.added.md new file mode 100644 index 00000000000..d7390f77c14 --- /dev/null +++ b/newsfragments/33712.added.md @@ -0,0 +1 @@ +Added methods to `PyAnyMethods` for binary operators (`add`, `sub`, etc.) diff --git a/src/types/any.rs b/src/types/any.rs index edc74879cb6..1cb9a057b9d 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -1208,6 +1208,38 @@ pub trait PyAnyMethods<'py> { where O: ToPyObject; + /// Computes `self + other`. + fn add(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self - other`. + fn sub(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self * other`. + fn mul(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self / other`. + fn div(&self, other: O) -> PyResult> + where + O: ToPyObject; + + /// Computes `self ** other % modulus` (`pow(self, other, modulus)`). + /// `py.None()` may be passed for the `modulus`. + fn pow(&self, other: O1, modulus: O2) -> PyResult> + where + O1: ToPyObject, + O2: ToPyObject; + + /// Computes `self & other`. + fn bitand(&self, other: O) -> PyResult> + where + O: ToPyObject; + /// Determines whether this object appears callable. /// /// This is equivalent to Python's [`callable()`][1] function. @@ -1680,6 +1712,26 @@ pub trait PyAnyMethods<'py> { fn py_super(&self) -> PyResult>; } +macro_rules! implement_binop { + ($name:ident, $c_api:ident, $op:expr) => { + #[doc = concat!("Computes `self ", $op, " other`.")] + fn $name(&self, other: O) -> PyResult> + where + O: ToPyObject, + { + fn inner<'py>( + any: &Bound<'py, PyAny>, + other: Bound<'_, PyAny>, + ) -> PyResult> { + unsafe { ffi::$c_api(any.as_ptr(), other.as_ptr()).assume_owned_or_err(any.py()) } + } + + let py = self.py(); + inner(self, other.to_object(py).into_bound(py)) + } + }; +} + impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { #[inline] fn is(&self, other: &T) -> bool { @@ -1855,6 +1907,38 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> { .and_then(|any| any.is_truthy()) } + implement_binop!(add, PyNumber_Add, "+"); + implement_binop!(sub, PyNumber_Subtract, "-"); + implement_binop!(mul, PyNumber_Multiply, "*"); + implement_binop!(div, PyNumber_TrueDivide, "/"); + implement_binop!(bitand, PyNumber_And, "&"); + + /// Computes `self ** other % modulus` (`pow(self, other, modulus)`). + /// `py.None()` may be passed for the `modulus`. + fn pow(&self, other: O1, modulus: O2) -> PyResult> + where + O1: ToPyObject, + O2: ToPyObject, + { + fn inner<'py>( + any: &Bound<'py, PyAny>, + other: Bound<'_, PyAny>, + modulus: Bound<'_, PyAny>, + ) -> PyResult> { + unsafe { + ffi::PyNumber_Power(any.as_ptr(), other.as_ptr(), modulus.as_ptr()) + .assume_owned_or_err(any.py()) + } + } + + let py = self.py(); + inner( + self, + other.to_object(py).into_bound(py), + modulus.to_object(py).into_bound(py), + ) + } + fn is_callable(&self) -> bool { unsafe { ffi::PyCallable_Check(self.as_ptr()) != 0 } } diff --git a/tests/test_arithmetics.rs b/tests/test_arithmetics.rs index 86078080176..5da2ede8c81 100644 --- a/tests/test_arithmetics.rs +++ b/tests/test_arithmetics.rs @@ -233,6 +233,13 @@ fn binary_arithmetic() { py_expect_exception!(py, c, "1 ** c", PyTypeError); py_run!(py, c, "assert pow(c, 1, 100) == 'BA ** 1 (mod: Some(100))'"); + + let c: Bound<'_, PyAny> = c.extract().unwrap(); + assert_eq!(c.add(&c).unwrap().extract::<&str>().unwrap(), "BA + BA"); + assert_eq!( + c.pow(&c, py.None()).unwrap().extract::<&str>().unwrap(), + "BA ** BA (mod: None)" + ); }); }