Skip to content

Commit f1f0e5a

Browse files
committed
shape: Add trait ShapeArg
1 parent c04752f commit f1f0e5a

File tree

4 files changed

+192
-2
lines changed

4 files changed

+192
-2
lines changed

src/impl_methods.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ use crate::dimension::broadcast::co_broadcast;
2626
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2727
use crate::math_cell::MathCell;
2828
use crate::itertools::zip;
29-
use crate::zip::{IntoNdProducer, Zip};
3029
use crate::AxisDescription;
30+
use crate::zip::{IntoNdProducer, Zip};
3131

3232
use crate::iter::{
3333
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ pub use crate::stacking::{concatenate, stack, stack_new_axis};
163163

164164
pub use crate::math_cell::MathCell;
165165
pub use crate::impl_views::IndexLonger;
166-
pub use crate::shape_builder::{Shape, ShapeBuilder, StrideShape};
166+
pub use crate::shape_builder::{Shape, ShapeBuilder, ShapeArg, StrideShape};
167167

168168
#[macro_use]
169169
mod macro_utils;

src/shape_builder.rs

+31
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::dimension::IntoDimension;
22
use crate::Dimension;
3+
use crate::order::Order;
34

45
/// A contiguous array shape of n dimensions.
56
///
@@ -184,3 +185,33 @@ where
184185
self.dim.size()
185186
}
186187
}
188+
189+
190+
/// Array shape argument with optional order parameter
191+
///
192+
/// Shape or array dimension argument, with optional [`Order`] parameter.
193+
///
194+
/// This is an argument conversion trait that is used to accept an array shape and
195+
/// (optionally) an ordering argument.
196+
///
197+
/// See for example [`.to_shape()`](crate::ArrayBase::to_shape).
198+
pub trait ShapeArg {
199+
type Dim: Dimension;
200+
fn into_shape_and_order(self) -> (Self::Dim, Option<Order>);
201+
}
202+
203+
impl<T> ShapeArg for T where T: IntoDimension {
204+
type Dim = T::Dim;
205+
206+
fn into_shape_and_order(self) -> (Self::Dim, Option<Order>) {
207+
(self.into_dimension(), None)
208+
}
209+
}
210+
211+
impl<T> ShapeArg for (T, Order) where T: IntoDimension {
212+
type Dim = T::Dim;
213+
214+
fn into_shape_and_order(self) -> (Self::Dim, Option<Order>) {
215+
(self.0.into_dimension(), Some(self.1))
216+
}
217+
}

tests/reshape.rs

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
use ndarray::prelude::*;
2+
3+
use itertools::enumerate;
4+
5+
use ndarray::Order;
6+
7+
#[test]
8+
fn reshape() {
9+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
10+
let v = aview1(&data);
11+
let u = v.into_shape((3, 3));
12+
assert!(u.is_err());
13+
let u = v.into_shape((2, 2, 2));
14+
assert!(u.is_ok());
15+
let u = u.unwrap();
16+
assert_eq!(u.shape(), &[2, 2, 2]);
17+
let s = u.into_shape((4, 2)).unwrap();
18+
assert_eq!(s.shape(), &[4, 2]);
19+
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));
20+
}
21+
22+
#[test]
23+
#[should_panic(expected = "IncompatibleShape")]
24+
fn reshape_error1() {
25+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
26+
let v = aview1(&data);
27+
let _u = v.into_shape((2, 5)).unwrap();
28+
}
29+
30+
#[test]
31+
#[should_panic(expected = "IncompatibleLayout")]
32+
fn reshape_error2() {
33+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
34+
let v = aview1(&data);
35+
let mut u = v.into_shape((2, 2, 2)).unwrap();
36+
u.swap_axes(0, 1);
37+
let _s = u.into_shape((2, 4)).unwrap();
38+
}
39+
40+
#[test]
41+
fn reshape_f() {
42+
let mut u = Array::zeros((3, 4).f());
43+
for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) {
44+
*elt = i as i32;
45+
}
46+
let v = u.view();
47+
println!("{:?}", v);
48+
49+
// noop ok
50+
let v2 = v.into_shape((3, 4));
51+
assert!(v2.is_ok());
52+
assert_eq!(v, v2.unwrap());
53+
54+
let u = v.into_shape((3, 2, 2));
55+
assert!(u.is_ok());
56+
let u = u.unwrap();
57+
println!("{:?}", u);
58+
assert_eq!(u.shape(), &[3, 2, 2]);
59+
let s = u.into_shape((4, 3)).unwrap();
60+
println!("{:?}", s);
61+
assert_eq!(s.shape(), &[4, 3]);
62+
assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]));
63+
}
64+
65+
66+
#[test]
67+
fn to_shape_easy() {
68+
// 1D -> C -> C
69+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
70+
let v = aview1(&data);
71+
let u = v.to_shape(((3, 3), Order::RowMajor));
72+
assert!(u.is_err());
73+
74+
let u = v.to_shape(((2, 2, 2), Order::C));
75+
assert!(u.is_ok());
76+
77+
let u = u.unwrap();
78+
assert!(u.is_view());
79+
assert_eq!(u.shape(), &[2, 2, 2]);
80+
assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
81+
82+
let s = u.to_shape((4, 2)).unwrap();
83+
assert_eq!(s.shape(), &[4, 2]);
84+
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));
85+
86+
// 1D -> F -> F
87+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
88+
let v = aview1(&data);
89+
let u = v.to_shape(((3, 3), Order::ColumnMajor));
90+
assert!(u.is_err());
91+
92+
let u = v.to_shape(((2, 2, 2), Order::ColumnMajor));
93+
assert!(u.is_ok());
94+
95+
let u = u.unwrap();
96+
assert!(u.is_view());
97+
assert_eq!(u.shape(), &[2, 2, 2]);
98+
assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]);
99+
100+
let s = u.to_shape(((4, 2), Order::ColumnMajor)).unwrap();
101+
assert_eq!(s.shape(), &[4, 2]);
102+
assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]);
103+
}
104+
105+
#[test]
106+
fn to_shape_copy() {
107+
// 1D -> C -> F
108+
let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]);
109+
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();
110+
assert_eq!(u.shape(), &[4, 2]);
111+
assert_eq!(u, array![[1, 2], [3, 4], [5, 6], [7, 8]]);
112+
113+
let u = u.to_shape(((2, 4), Order::ColumnMajor)).unwrap();
114+
assert_eq!(u.shape(), &[2, 4]);
115+
assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]);
116+
117+
// 1D -> F -> C
118+
let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]);
119+
let u = v.to_shape(((4, 2), Order::ColumnMajor)).unwrap();
120+
assert_eq!(u.shape(), &[4, 2]);
121+
assert_eq!(u, array![[1, 5], [2, 6], [3, 7], [4, 8]]);
122+
123+
let u = u.to_shape((2, 4)).unwrap();
124+
assert_eq!(u.shape(), &[2, 4]);
125+
assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]);
126+
}
127+
128+
#[test]
129+
fn to_shape_add_axis() {
130+
// 1D -> C -> C
131+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
132+
let v = aview1(&data);
133+
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();
134+
135+
assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view());
136+
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned());
137+
}
138+
139+
140+
#[test]
141+
fn to_shape_copy_stride() {
142+
let v = array![[1, 2, 3, 4], [5, 6, 7, 8]];
143+
let vs = v.slice(s![.., ..3]);
144+
let lin1 = vs.to_shape(6).unwrap();
145+
assert_eq!(lin1, array![1, 2, 3, 5, 6, 7]);
146+
assert!(lin1.is_owned());
147+
148+
let lin2 = vs.to_shape((6, Order::ColumnMajor)).unwrap();
149+
assert_eq!(lin2, array![1, 5, 2, 6, 3, 7]);
150+
assert!(lin2.is_owned());
151+
}
152+
153+
#[test]
154+
#[should_panic(expected = "IncompatibleShape")]
155+
fn to_shape_error1() {
156+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
157+
let v = aview1(&data);
158+
let _u = v.to_shape((2, 5)).unwrap();
159+
}

0 commit comments

Comments
 (0)