Skip to content

Commit 2ac3c9f

Browse files
committed
Make mapv_into_any() work for ArcArray, resolves #1280
1 parent 0740695 commit 2ac3c9f

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

src/data_traits.rs

+13
Original file line numberDiff line numberDiff line change
@@ -766,3 +766,16 @@ impl<'a, A: 'a, B: 'a> RawDataSubst<B> for CowRepr<'a, A> {
766766
}
767767
}
768768
}
769+
770+
/// Plug the data element type of one owned array representation into another.
771+
///
772+
/// For example, `<OwnedRepr<f64> as Plug<f32>>::Type` has type `OwnedRepr<f32>`.
773+
pub trait Plug<A> {
774+
type Type;
775+
}
776+
impl <A, B> Plug<A> for crate::OwnedRepr<B> {
777+
type Type = crate::OwnedRepr<A>;
778+
}
779+
impl <A, B> Plug<A> for crate::OwnedArcRepr<B> {
780+
type Type = crate::OwnedArcRepr<A>;
781+
}

src/impl_methods.rs

+30-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::imp_prelude::*;
1616

1717
use crate::{arraytraits, DimMax};
1818
use crate::argument_traits::AssignElem;
19+
use crate::data_traits::Plug;
1920
use crate::dimension;
2021
use crate::dimension::IntoDimension;
2122
use crate::dimension::{
@@ -2586,15 +2587,29 @@ where
25862587
/// map is performed as in [`mapv`].
25872588
///
25882589
/// Elements are visited in arbitrary order.
2589-
///
2590+
///
2591+
/// Note that the compiler will need some hint about the return type, which
2592+
/// is generic over [`DataOwned`], and can thus be an [`Array`] or
2593+
/// [`ArcArray`]. Example:
2594+
///
2595+
/// ```rust
2596+
/// # use ndarray::{array, Array};
2597+
/// let a = array![[1., 2., 3.]];
2598+
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
2599+
/// ```
2600+
///
25902601
/// [`mapv_into`]: ArrayBase::mapv_into
25912602
/// [`mapv`]: ArrayBase::mapv
2592-
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
2603+
pub fn mapv_into_any<B, F, T>(self, mut f: F) -> ArrayBase<T, D>
25932604
where
2594-
S: DataMut,
2605+
S: DataMut<Elem = A>,
25952606
F: FnMut(A) -> B,
25962607
A: Clone + 'static,
25972608
B: 'static,
2609+
T: DataOwned<Elem = B> + Plug<A>, // lets us introspect on the types of array representations containing different data elements
2610+
<T as Plug<A>>::Type: RawData, // required by mapv_into()
2611+
ArrayBase<<T as Plug<A>>::Type, D>: From<ArrayBase<S, D>>, // required by into() to convert from the DataMut array representation of S to the DataOwned array representation of T
2612+
ArrayBase<T, D>: From<Array<B, D>>, // required by mapv()
25982613
{
25992614
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
26002615
// A and B are the same type.
@@ -2604,16 +2619,21 @@ where
26042619
// Safe because A and B are the same type.
26052620
unsafe { unlimited_transmute::<B, A>(b) }
26062621
};
2607-
// Delegate to mapv_into() using the wrapped closure.
2608-
// Convert output to a uniquely owned array of type Array<A, D>.
2609-
let output = self.mapv_into(f).into_owned();
2610-
// Change the return type from Array<A, D> to Array<B, D>.
2611-
// Again, safe because A and B are the same type.
2612-
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
2622+
// Delegate to mapv_into() to map from element type A to type A.
2623+
let output = self.mapv_into(f);
2624+
// Convert from S's array representation to T's array representation.
2625+
// Suppose `T is `OwnedRepr<B>`.
2626+
// Then `<T as Plug<A>>::Type` is `OwnedRepr<A>`.
2627+
let output: ArrayBase<<T as Plug<A>>::Type, D> = output.into();
2628+
// Map from T's array representation with data element type A to T's
2629+
// array representation with element type B.
2630+
// This is safe because A and B are the same type and the array
2631+
// representations are also the same higher kinded type.
2632+
unsafe { unlimited_transmute::<ArrayBase<<T as Plug<A>>::Type, D>, ArrayBase<T,D>>(output) }
26132633
} else {
26142634
// A and B are not the same type.
26152635
// Fallback to mapv().
2616-
self.mapv(f)
2636+
self.mapv(f).into()
26172637
}
26182638
}
26192639

tests/array.rs

+28-2
Original file line numberDiff line numberDiff line change
@@ -995,14 +995,40 @@ fn map1() {
995995
fn mapv_into_any_same_type() {
996996
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
997997
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
998-
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
998+
let b: Array<_, _> = a.mapv_into_any(|a| a + 1.);
999+
assert_eq!(b, a_plus_one);
9991000
}
10001001

10011002
#[test]
10021003
fn mapv_into_any_diff_types() {
10031004
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
10041005
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1005-
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
1006+
let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1007+
assert_eq!(b, a_even);
1008+
}
1009+
1010+
#[test]
1011+
fn mapv_into_any_arcarray_same_type() {
1012+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1013+
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1014+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
1015+
assert_eq!(b, a_plus_one);
1016+
}
1017+
1018+
#[test]
1019+
fn mapv_into_any_arcarray_diff_types() {
1020+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1021+
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1022+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1023+
assert_eq!(b, a_even);
1024+
}
1025+
1026+
#[test]
1027+
fn mapv_into_any_diff_outer_types() {
1028+
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
1029+
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1030+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
1031+
assert_eq!(b, a_plus_one);
10061032
}
10071033

10081034
#[test]

0 commit comments

Comments
 (0)