From b5a47f926b5b1f60b38795354e239acc647fc77b Mon Sep 17 00:00:00 2001 From: SparrowLii Date: Fri, 21 May 2021 20:28:48 +0800 Subject: [PATCH] Give MathCell arithmetic ops implementations when MathCell is left value --- src/impl_methods.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++ src/impl_ops.rs | 83 +++++++++++++++++++++++++++++++++++++++++++++ src/math_cell.rs | 47 ++++++++++++++++++++++++- 3 files changed, 211 insertions(+), 1 deletion(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 28098f68f..55bd8d154 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2613,6 +2613,88 @@ where } } +impl ArrayBase + where + AE: Copy, + D: Dimension, + S: Data>, +{ + /// Same as `zip_mut_with`, but just when element type is `MathCell`. + #[inline] + pub(crate) fn zip_cell_with(&self, rhs: &ArrayBase, f: F) + where + S2: Data, + E: Dimension, + F: Fn(&AE, &B) -> AE, + { + if rhs.dim.ndim() == 0 { + // Skip broadcast from 0-dim array + self.zip_cell_with_elem(rhs.get_0d(), f); + } else if self.dim.ndim() == rhs.dim.ndim() && self.shape() == rhs.shape() { + self.zip_cell_with_same_shape(rhs, f); + } else { + let rhs_broadcast = rhs.broadcast_unwrap(self.raw_dim()); + self.zip_cell_with_by_rows(&rhs_broadcast, f); + } + } + + /// Same as `zip_mut_with_elem`, but just when element type is `MathCell`. + pub(crate) fn zip_cell_with_elem(&self, rhs_elem: &B, f: F) + where + F: Fn(&AE, &B) -> AE, + { + match self.as_slice_memory_order() { + Some(slc) => slc.iter().for_each(|x| x.set(f(&x.get(), rhs_elem))), + None => { + let v = self.view(); + v.into_elements_base().for_each(|x| x.set(f(&x.get(), rhs_elem))); + } + } + } + + /// Same as `zip_mut_with_shame_shape`, but just when element type is `MathCell`. + pub(crate) fn zip_cell_with_same_shape(&self, rhs: &ArrayBase, f: F) + where + S2: Data, + E: Dimension, + F: Fn(&AE, &B) -> AE, + { + debug_assert_eq!(self.shape(), rhs.shape()); + + if self.dim.strides_equivalent(&self.strides, &rhs.strides) { + if let Some(self_s) = self.as_slice_memory_order() { + if let Some(rhs_s) = rhs.as_slice_memory_order() { + for (s, r) in self_s.iter().zip(rhs_s) { + s.set(f(&s.get(), r)); + } + return; + } + } + } + + // Otherwise, fall back to the outer iter + self.zip_cell_with_by_rows(rhs, f); + } + + /// Same as `zip_mut_with_by_rows`, but just when element type is `MathCell`. + #[inline(always)] + pub(crate) fn zip_cell_with_by_rows(&self, rhs: &ArrayBase, f: F) + where + S2: Data, + E: Dimension, + F: Fn(&AE, &B) -> AE, + { + debug_assert_eq!(self.shape(), rhs.shape()); + debug_assert_ne!(self.ndim(), 0); + + // break the arrays up into their inner rows + let n = self.ndim(); + let dim = self.raw_dim(); + Zip::from(Lanes::new(self.view(), Axis(n - 1))) + .and(Lanes::new(rhs.broadcast_assume(dim), Axis(n - 1))) + .for_each(move |s_row, r_row| Zip::from(s_row).and(r_row).for_each(|a, b| a.set(f(&a.get(), b)))); + } +} /// Transmute from A to B. /// diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 4c255dfff..9a055fd9e 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -226,6 +226,51 @@ impl<'a, A, S, D, B> $trt for &'a ArrayBase self.map(move |elt| elt.clone() $operator x.clone()) } } + +/// Perform elementwise +#[doc=$doc] +/// between `self` and `rhs`, +/// and return the result. +/// +/// `self` must be a view of `MathCell`. +/// +/// If their shapes disagree, `rhs` is broadcast to the shape of `self`. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S, D, E> $trt<&'a ArrayBase> for ArrayView<'a, MathCell, D> + where + A: Copy + $trt, + B: Clone, + S: Data, + D: Dimension, + E: Dimension, +{ + type Output = ArrayView<'a, MathCell, D>; + fn $mth(self, rhs: &ArrayBase) -> Self::Output + { + self.zip_cell_with(rhs, |x, y| x.clone() $operator y.clone()); + self + } +} + +/// Perform elementwise +#[doc=$doc] +/// between `self` and the scalar `x`, +/// and return the result (based on `self`). +/// +/// `self` must be a view of `MathCell`. +impl<'a, A, D, B> $trt for ArrayView<'a, MathCell, D> +where + A: Copy + $trt, + D: Dimension, + B: ScalarOperand, +{ + type Output = ArrayView<'a, MathCell, D>; + fn $mth(self, y: B) -> ArrayView<'a, MathCell, D> { + self.zip_cell_with_elem(&y, |x, y| x.clone() $operator y.clone()); + self + } +} ); ); @@ -287,6 +332,7 @@ impl<'a, S, D> $trt<&'a ArrayBase> for $scalar mod arithmetic_ops { use super::*; use crate::imp_prelude::*; + use crate::MathCell; use num_complex::Complex; use std::ops::*; @@ -429,6 +475,7 @@ mod arithmetic_ops { mod assign_ops { use super::*; use crate::imp_prelude::*; + use crate::MathCell; macro_rules! impl_assign_op { ($trt:ident, $method:ident, $doc:expr) => { @@ -466,6 +513,42 @@ mod assign_ops { }); } } + + #[doc=$doc] + /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. + /// + /// **Panics** if broadcasting isn’t possible. + impl<'a, A, B, S, D, E> $trt<&'a ArrayBase> for ArrayView<'a, MathCell, D> + where + A: Copy + $trt, + B: Clone, + S: Data, + D: Dimension, + E: Dimension, + { + fn $method(&mut self, rhs: &ArrayBase) { + self.zip_cell_with(rhs, |x, y| { + let mut x = x.clone(); + x.$method(y.clone()); + x + }); + } + } + + #[doc=$doc] + impl<'a, A, D> $trt for ArrayView<'a, MathCell, D> + where + A: Copy + ScalarOperand + $trt, + D: Dimension, + { + fn $method(&mut self, rhs: A) { + self.zip_cell_with_elem(&rhs, |x, y| { + let mut x = x.clone(); + x.$method(y.clone()); + x + }); + } + } }; } diff --git a/src/math_cell.rs b/src/math_cell.rs index f0f8da40b..f5edc0897 100644 --- a/src/math_cell.rs +++ b/src/math_cell.rs @@ -3,7 +3,7 @@ use std::cell::Cell; use std::cmp::Ordering; use std::fmt; -use std::ops::{Deref, DerefMut}; +use std::ops::*; /// A transparent wrapper of [`Cell`](std::cell::Cell) which is identical in every way, except /// it will implement arithmetic operators as well. @@ -88,10 +88,34 @@ impl fmt::Debug for MathCell } } +macro_rules! impl_math_cell_op { + ($trt:ident, $op:tt, $mth:ident) => { + impl $trt for MathCell + where A: $trt + { + type Output = MathCell<>::Output>; + fn $mth(self, other: B) -> MathCell<>::Output> { + MathCell::new(self.into_inner() $op other) + } + } + }; +} + +impl_math_cell_op!(Add, +, add); +impl_math_cell_op!(Sub, -, sub); +impl_math_cell_op!(Mul, *, mul); +impl_math_cell_op!(Div, /, div); +impl_math_cell_op!(Rem, %, rem); +impl_math_cell_op!(BitAnd, &, bitand); +impl_math_cell_op!(BitOr, |, bitor); +impl_math_cell_op!(BitXor, ^, bitxor); +impl_math_cell_op!(Shl, <<, shl); +impl_math_cell_op!(Shr, >>, shr); #[cfg(test)] mod tests { use super::MathCell; + use crate::arr1; #[test] fn test_basic() { @@ -99,4 +123,25 @@ mod tests { c.set(1); assert_eq!(c.get(), 1); } + + #[test] + fn test_math_cell_ops() { + let s = [1, 2, 3, 4, 5, 6]; + let mut a = arr1(&s[0..3]); + let b = arr1(&s[3..6]); + // binary_op + assert_eq!(a.cell_view() + &b, arr1(&[5, 7, 9]).cell_view()); + + // binary_op with scalar + assert_eq!(a.cell_view() * 2, arr1(&[10, 14, 18]).cell_view()); + + // unary_op + let mut a_v = a.cell_view(); + a_v /= &b; + assert_eq!(a_v, arr1(&[2, 2, 3]).cell_view()); + + // unary_op with scalar + a_v <<= 1; + assert_eq!(a_v, arr1(&[4, 4, 6]).cell_view()) + } }