Skip to content

Commit 3373ea9

Browse files
committed
Make mapv_into_any() work for ArcArray, resolves #1280
1 parent 0740695 commit 3373ea9

File tree

2 files changed

+60
-12
lines changed

2 files changed

+60
-12
lines changed

src/impl_methods.rs

+32-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::imp_prelude::*;
1616

1717
use crate::{arraytraits, DimMax};
1818
use crate::argument_traits::AssignElem;
19+
use crate::data_traits::RawDataSubst;
1920
use crate::dimension;
2021
use crate::dimension::IntoDimension;
2122
use crate::dimension::{
@@ -2586,15 +2587,29 @@ where
25862587
/// map is performed as in [`mapv`].
25872588
///
25882589
/// Elements are visited in arbitrary order.
2589-
///
2590+
///
2591+
/// Note that the compiler will need some hint about the return type, which
2592+
/// is generic over [`DataOwned`], and can thus be an [`Array`] or
2593+
/// [`ArcArray`]. Example:
2594+
///
2595+
/// ```rust
2596+
/// # use ndarray::{array, Array};
2597+
/// let a = array![[1., 2., 3.]];
2598+
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
2599+
/// ```
2600+
///
25902601
/// [`mapv_into`]: ArrayBase::mapv_into
25912602
/// [`mapv`]: ArrayBase::mapv
2592-
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
2603+
pub fn mapv_into_any<B, F, T>(self, mut f: F) -> ArrayBase<T, D>
25932604
where
2594-
S: DataMut,
2605+
S: DataMut<Elem = A>,
25952606
F: FnMut(A) -> B,
25962607
A: Clone + 'static,
25972608
B: 'static,
2609+
T: DataOwned<Elem = B> + RawDataSubst<A> + 'static, // lets us introspect on the types of array representations containing different data elements
2610+
<T as RawDataSubst<A>>::Output: RawData, // required by mapv_into()
2611+
ArrayBase<<T as RawDataSubst<A>>::Output, D>: From<ArrayBase<S, D>>, // required by into() to convert from the DataMut array representation of S to the DataOwned array representation of T
2612+
ArrayBase<T, D>: From<Array<B, D>>, // required by mapv()
25982613
{
25992614
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
26002615
// A and B are the same type.
@@ -2604,16 +2619,23 @@ where
26042619
// Safe because A and B are the same type.
26052620
unsafe { unlimited_transmute::<B, A>(b) }
26062621
};
2607-
// Delegate to mapv_into() using the wrapped closure.
2608-
// Convert output to a uniquely owned array of type Array<A, D>.
2609-
let output = self.mapv_into(f).into_owned();
2610-
// Change the return type from Array<A, D> to Array<B, D>.
2611-
// Again, safe because A and B are the same type.
2612-
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
2622+
// Delegate to mapv_into() to map from element type A to type A.
2623+
let output = self.mapv_into(f);
2624+
// Convert from S's data storage to T's data storage.
2625+
// Suppose `T is `OwnedRepr<B>`.
2626+
// Then `<T as RawDataSubst<A>>::Output` is `OwnedRepr<A>`.
2627+
let output: ArrayBase<<T as RawDataSubst<A>>::Output, D> = output.into();
2628+
// Since A == B and T stores elements of type B, it should be true
2629+
// that <T as RawDataSubst<A>>::Output == T.
2630+
// Verify that this is indeed the case.
2631+
assert!(core::any::TypeId::of::<<T as RawDataSubst<A>>::Output>() == core::any::TypeId::of::<T>());
2632+
// Now we can safely transmute the element type from A to the
2633+
// identical type B, keeping the same data storage.
2634+
unsafe { unlimited_transmute::<ArrayBase<<T as RawDataSubst<A>>::Output, D>, ArrayBase<T,D>>(output) }
26132635
} else {
26142636
// A and B are not the same type.
26152637
// Fallback to mapv().
2616-
self.mapv(f)
2638+
self.mapv(f).into()
26172639
}
26182640
}
26192641

tests/array.rs

+28-2
Original file line numberDiff line numberDiff line change
@@ -995,14 +995,40 @@ fn map1() {
995995
fn mapv_into_any_same_type() {
996996
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
997997
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
998-
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
998+
let b: Array<_, _> = a.mapv_into_any(|a| a + 1.);
999+
assert_eq!(b, a_plus_one);
9991000
}
10001001

10011002
#[test]
10021003
fn mapv_into_any_diff_types() {
10031004
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
10041005
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1005-
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
1006+
let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1007+
assert_eq!(b, a_even);
1008+
}
1009+
1010+
#[test]
1011+
fn mapv_into_any_arcarray_same_type() {
1012+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1013+
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1014+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
1015+
assert_eq!(b, a_plus_one);
1016+
}
1017+
1018+
#[test]
1019+
fn mapv_into_any_arcarray_diff_types() {
1020+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1021+
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1022+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1023+
assert_eq!(b, a_even);
1024+
}
1025+
1026+
#[test]
1027+
fn mapv_into_any_diff_outer_types() {
1028+
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
1029+
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1030+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
1031+
assert_eq!(b, a_plus_one);
10061032
}
10071033

10081034
#[test]

0 commit comments

Comments
 (0)