Skip to content

Commit ac9e3da

Browse files
committed
Make mapv_into_any() work for ArcArray, resolves rust-ndarray#1280
1 parent 492b274 commit ac9e3da

File tree

2 files changed

+60
-12
lines changed

2 files changed

+60
-12
lines changed

src/impl_methods.rs

+32-10
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use std::mem::{size_of, ManuallyDrop};
1717
use crate::imp_prelude::*;
1818

1919
use crate::argument_traits::AssignElem;
20+
use crate::data_traits::RawDataSubst;
2021
use crate::dimension;
2122
use crate::dimension::broadcast::co_broadcast;
2223
use crate::dimension::reshape_dim;
@@ -2814,15 +2815,29 @@ where
28142815
/// map is performed as in [`mapv`].
28152816
///
28162817
/// Elements are visited in arbitrary order.
2817-
///
2818+
///
2819+
/// Note that the compiler will need some hint about the return type, which
2820+
/// is generic over [`DataOwned`], and can thus be an [`Array`] or
2821+
/// [`ArcArray`]. Example:
2822+
///
2823+
/// ```rust
2824+
/// # use ndarray::{array, Array};
2825+
/// let a = array![[1., 2., 3.]];
2826+
/// let a_plus_one: Array<_, _> = a.mapv_into_any(|a| a + 1.);
2827+
/// ```
2828+
///
28182829
/// [`mapv_into`]: ArrayBase::mapv_into
28192830
/// [`mapv`]: ArrayBase::mapv
2820-
pub fn mapv_into_any<B, F>(self, mut f: F) -> Array<B, D>
2831+
pub fn mapv_into_any<B, F, T>(self, mut f: F) -> ArrayBase<T, D>
28212832
where
2822-
S: DataMut,
2833+
S: DataMut<Elem = A>,
28232834
F: FnMut(A) -> B,
28242835
A: Clone + 'static,
28252836
B: 'static,
2837+
T: DataOwned<Elem = B> + RawDataSubst<A> + 'static, // lets us introspect on the types of array representations containing different data elements
2838+
<T as RawDataSubst<A>>::Output: RawData, // required by mapv_into()
2839+
ArrayBase<<T as RawDataSubst<A>>::Output, D>: From<ArrayBase<S, D>>, // required by into() to convert from the DataMut array representation of S to the DataOwned array representation of T
2840+
ArrayBase<T, D>: From<Array<B, D>>, // required by mapv()
28262841
{
28272842
if core::any::TypeId::of::<A>() == core::any::TypeId::of::<B>() {
28282843
// A and B are the same type.
@@ -2832,16 +2847,23 @@ where
28322847
// Safe because A and B are the same type.
28332848
unsafe { unlimited_transmute::<B, A>(b) }
28342849
};
2835-
// Delegate to mapv_into() using the wrapped closure.
2836-
// Convert output to a uniquely owned array of type Array<A, D>.
2837-
let output = self.mapv_into(f).into_owned();
2838-
// Change the return type from Array<A, D> to Array<B, D>.
2839-
// Again, safe because A and B are the same type.
2840-
unsafe { unlimited_transmute::<Array<A, D>, Array<B, D>>(output) }
2850+
// Delegate to mapv_into() to map from element type A to type A.
2851+
let output = self.mapv_into(f);
2852+
// Convert from S's data storage to T's data storage.
2853+
// Suppose `T is `OwnedRepr<B>`.
2854+
// Then `<T as RawDataSubst<A>>::Output` is `OwnedRepr<A>`.
2855+
let output: ArrayBase<<T as RawDataSubst<A>>::Output, D> = output.into();
2856+
// Since A == B and T stores elements of type B, it should be true
2857+
// that <T as RawDataSubst<A>>::Output == T.
2858+
// Verify that this is indeed the case.
2859+
assert!(core::any::TypeId::of::<<T as RawDataSubst<A>>::Output>() == core::any::TypeId::of::<T>());
2860+
// Now we can safely transmute the element type from A to the
2861+
// identical type B, keeping the same data storage.
2862+
unsafe { unlimited_transmute::<ArrayBase<<T as RawDataSubst<A>>::Output, D>, ArrayBase<T,D>>(output) }
28412863
} else {
28422864
// A and B are not the same type.
28432865
// Fallback to mapv().
2844-
self.mapv(f)
2866+
self.mapv(f).into()
28452867
}
28462868
}
28472869

tests/array.rs

+28-2
Original file line numberDiff line numberDiff line change
@@ -1054,15 +1054,41 @@ fn mapv_into_any_same_type()
10541054
{
10551055
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
10561056
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1057-
assert_eq!(a.mapv_into_any(|a| a + 1.), a_plus_one);
1057+
let b: Array<_, _> = a.mapv_into_any(|a| a + 1.);
1058+
assert_eq!(b, a_plus_one);
10581059
}
10591060

10601061
#[test]
10611062
fn mapv_into_any_diff_types()
10621063
{
10631064
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
10641065
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1065-
assert_eq!(a.mapv_into_any(|a| a.round() as i32 % 2 == 0), a_even);
1066+
let b: Array<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1067+
assert_eq!(b, a_even);
1068+
}
1069+
1070+
#[test]
1071+
fn mapv_into_any_arcarray_same_type() {
1072+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1073+
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1074+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
1075+
assert_eq!(b, a_plus_one);
1076+
}
1077+
1078+
#[test]
1079+
fn mapv_into_any_arcarray_diff_types() {
1080+
let a: ArcArray<f64, _> = array![[1., 2., 3.], [4., 5., 6.]].into_shared();
1081+
let a_even: Array<bool, _> = array![[false, true, false], [true, false, true]];
1082+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a.round() as i32 % 2 == 0);
1083+
assert_eq!(b, a_even);
1084+
}
1085+
1086+
#[test]
1087+
fn mapv_into_any_diff_outer_types() {
1088+
let a: Array<f64, _> = array![[1., 2., 3.], [4., 5., 6.]];
1089+
let a_plus_one: Array<f64, _> = array![[2., 3., 4.], [5., 6., 7.]];
1090+
let b: ArcArray<_, _> = a.mapv_into_any(|a| a + 1.);
1091+
assert_eq!(b, a_plus_one);
10661092
}
10671093

10681094
#[test]

0 commit comments

Comments
 (0)