Skip to content

Commit b2a7d0b

Browse files
jturner314bluss
authored andcommitted
Generalize lhs scalar ops to more combos of types
This doesn't have a noticeable impact on the results of the `scalar_add_2` and `scalar_add_strided_2` benchmarks.
1 parent 65e13d9 commit b2a7d0b

File tree

1 file changed

+60
-74
lines changed

1 file changed

+60
-74
lines changed

src/impl_ops.rs

+60-74
Original file line numberDiff line numberDiff line change
@@ -166,56 +166,42 @@ impl<'a, A, S, D, B, C> $trt<B> for &'a ArrayBase<S, D>
166166
);
167167
);
168168

169-
// Pick the expression $a for commutative and $b for ordered binop
170-
macro_rules! if_commutative {
171-
(Commute { $a:expr } or { $b:expr }) => {
172-
$a
173-
};
174-
(Ordered { $a:expr } or { $b:expr }) => {
175-
$b
176-
};
177-
}
178-
179169
macro_rules! impl_scalar_lhs_op {
180-
// $commutative flag. Reuse the self + scalar impl if we can.
181-
// We can do this safely since these are the primitive numeric types
182-
($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
183-
// these have no doc -- they are not visible in rustdoc
184-
// Perform elementwise
185-
// between the scalar `self` and array `rhs`,
186-
// and return the result (based on `self`).
187-
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
188-
where S: DataOwned<Elem=$scalar> + DataMut,
189-
D: Dimension,
170+
($scalar:ty, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
171+
/// Perform elementwise
172+
#[doc=$doc]
173+
/// between the scalar `self` and array `rhs`,
174+
/// and return the result (based on `self`).
175+
impl<A, S, D> $trt<ArrayBase<S, D>> for $scalar
176+
where
177+
$scalar: Clone + $trt<A, Output=A>,
178+
A: Clone,
179+
S: DataOwned<Elem=A> + DataMut,
180+
D: Dimension,
190181
{
191182
type Output = ArrayBase<S, D>;
192-
fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
193-
if_commutative!($commutative {
194-
rhs.$mth(self)
195-
} or {{
196-
let mut rhs = rhs;
197-
rhs.unordered_foreach_mut(move |elt| {
198-
*elt = self $operator *elt;
199-
});
200-
rhs
201-
}})
183+
fn $mth(self, mut rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
184+
rhs.unordered_foreach_mut(move |elt| {
185+
*elt = self.clone() $operator elt.clone();
186+
});
187+
rhs
202188
}
203189
}
204190

205-
// Perform elementwise
206-
// between the scalar `self` and array `rhs`,
207-
// and return the result as a new `Array`.
208-
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
209-
where S: Data<Elem=$scalar>,
210-
D: Dimension,
191+
/// Perform elementwise
192+
#[doc=$doc]
193+
/// between the scalar `self` and array `rhs`,
194+
/// and return the result as a new `Array`.
195+
impl<'a, A, S, D, B> $trt<&'a ArrayBase<S, D>> for $scalar
196+
where
197+
$scalar: Clone + $trt<A, Output=B>,
198+
A: Clone,
199+
S: Data<Elem=A>,
200+
D: Dimension,
211201
{
212-
type Output = Array<$scalar, D>;
213-
fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<$scalar, D> {
214-
if_commutative!($commutative {
215-
rhs.$mth(self)
216-
} or {
217-
self.$mth(rhs.to_owned())
218-
})
202+
type Output = Array<B, D>;
203+
fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<B, D> {
204+
rhs.map(move |elt| self.clone() $operator elt.clone())
219205
}
220206
}
221207
);
@@ -241,16 +227,16 @@ mod arithmetic_ops {
241227

242228
macro_rules! all_scalar_ops {
243229
($int_scalar:ty) => (
244-
impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
245-
impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
246-
impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
247-
impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
248-
impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
249-
impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
250-
impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
251-
impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
252-
impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
253-
impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
230+
impl_scalar_lhs_op!($int_scalar, +, Add, add, "addition");
231+
impl_scalar_lhs_op!($int_scalar, -, Sub, sub, "subtraction");
232+
impl_scalar_lhs_op!($int_scalar, *, Mul, mul, "multiplication");
233+
impl_scalar_lhs_op!($int_scalar, /, Div, div, "division");
234+
impl_scalar_lhs_op!($int_scalar, %, Rem, rem, "remainder");
235+
impl_scalar_lhs_op!($int_scalar, &, BitAnd, bitand, "bit and");
236+
impl_scalar_lhs_op!($int_scalar, |, BitOr, bitor, "bit or");
237+
impl_scalar_lhs_op!($int_scalar, ^, BitXor, bitxor, "bit xor");
238+
impl_scalar_lhs_op!($int_scalar, <<, Shl, shl, "left shift");
239+
impl_scalar_lhs_op!($int_scalar, >>, Shr, shr, "right shift");
254240
);
255241
}
256242
all_scalar_ops!(i8);
@@ -264,31 +250,31 @@ mod arithmetic_ops {
264250
all_scalar_ops!(i128);
265251
all_scalar_ops!(u128);
266252

267-
impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
268-
impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
269-
impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
253+
impl_scalar_lhs_op!(bool, &, BitAnd, bitand, "bit and");
254+
impl_scalar_lhs_op!(bool, |, BitOr, bitor, "bit or");
255+
impl_scalar_lhs_op!(bool, ^, BitXor, bitxor, "bit xor");
270256

271-
impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
272-
impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
273-
impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
274-
impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
275-
impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
257+
impl_scalar_lhs_op!(f32, +, Add, add, "addition");
258+
impl_scalar_lhs_op!(f32, -, Sub, sub, "subtraction");
259+
impl_scalar_lhs_op!(f32, *, Mul, mul, "multiplication");
260+
impl_scalar_lhs_op!(f32, /, Div, div, "division");
261+
impl_scalar_lhs_op!(f32, %, Rem, rem, "remainder");
276262

277-
impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
278-
impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
279-
impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
280-
impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
281-
impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
263+
impl_scalar_lhs_op!(f64, +, Add, add, "addition");
264+
impl_scalar_lhs_op!(f64, -, Sub, sub, "subtraction");
265+
impl_scalar_lhs_op!(f64, *, Mul, mul, "multiplication");
266+
impl_scalar_lhs_op!(f64, /, Div, div, "division");
267+
impl_scalar_lhs_op!(f64, %, Rem, rem, "remainder");
282268

283-
impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
284-
impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
285-
impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
286-
impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
269+
impl_scalar_lhs_op!(Complex<f32>, +, Add, add, "addition");
270+
impl_scalar_lhs_op!(Complex<f32>, -, Sub, sub, "subtraction");
271+
impl_scalar_lhs_op!(Complex<f32>, *, Mul, mul, "multiplication");
272+
impl_scalar_lhs_op!(Complex<f32>, /, Div, div, "division");
287273

288-
impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
289-
impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
290-
impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
291-
impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
274+
impl_scalar_lhs_op!(Complex<f64>, +, Add, add, "addition");
275+
impl_scalar_lhs_op!(Complex<f64>, -, Sub, sub, "subtraction");
276+
impl_scalar_lhs_op!(Complex<f64>, *, Mul, mul, "multiplication");
277+
impl_scalar_lhs_op!(Complex<f64>, /, Div, div, "division");
292278

293279
impl<A, S, D> Neg for ArrayBase<S, D>
294280
where

0 commit comments

Comments
 (0)