Skip to content

Commit 8032aec

Browse files
committed
shape: Add tests for .to_shape()
1 parent 3794952 commit 8032aec

File tree

2 files changed

+160
-59
lines changed

2 files changed

+160
-59
lines changed

tests/array.rs

+1-59
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)]
99

1010
use defmac::defmac;
11-
use itertools::{enumerate, zip, Itertools};
11+
use itertools::{zip, Itertools};
1212
use ndarray::prelude::*;
1313
use ndarray::{arr3, rcarr2};
1414
use ndarray::indices;
@@ -1370,64 +1370,6 @@ fn transpose_view_mut() {
13701370
assert_eq!(at, arr2(&[[1, 4], [2, 5], [3, 7]]));
13711371
}
13721372

1373-
#[test]
1374-
fn reshape() {
1375-
let data = [1, 2, 3, 4, 5, 6, 7, 8];
1376-
let v = aview1(&data);
1377-
let u = v.into_shape((3, 3));
1378-
assert!(u.is_err());
1379-
let u = v.into_shape((2, 2, 2));
1380-
assert!(u.is_ok());
1381-
let u = u.unwrap();
1382-
assert_eq!(u.shape(), &[2, 2, 2]);
1383-
let s = u.into_shape((4, 2)).unwrap();
1384-
assert_eq!(s.shape(), &[4, 2]);
1385-
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));
1386-
}
1387-
1388-
#[test]
1389-
#[should_panic(expected = "IncompatibleShape")]
1390-
fn reshape_error1() {
1391-
let data = [1, 2, 3, 4, 5, 6, 7, 8];
1392-
let v = aview1(&data);
1393-
let _u = v.into_shape((2, 5)).unwrap();
1394-
}
1395-
1396-
#[test]
1397-
#[should_panic(expected = "IncompatibleLayout")]
1398-
fn reshape_error2() {
1399-
let data = [1, 2, 3, 4, 5, 6, 7, 8];
1400-
let v = aview1(&data);
1401-
let mut u = v.into_shape((2, 2, 2)).unwrap();
1402-
u.swap_axes(0, 1);
1403-
let _s = u.into_shape((2, 4)).unwrap();
1404-
}
1405-
1406-
#[test]
1407-
fn reshape_f() {
1408-
let mut u = Array::zeros((3, 4).f());
1409-
for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) {
1410-
*elt = i as i32;
1411-
}
1412-
let v = u.view();
1413-
println!("{:?}", v);
1414-
1415-
// noop ok
1416-
let v2 = v.into_shape((3, 4));
1417-
assert!(v2.is_ok());
1418-
assert_eq!(v, v2.unwrap());
1419-
1420-
let u = v.into_shape((3, 2, 2));
1421-
assert!(u.is_ok());
1422-
let u = u.unwrap();
1423-
println!("{:?}", u);
1424-
assert_eq!(u.shape(), &[3, 2, 2]);
1425-
let s = u.into_shape((4, 3)).unwrap();
1426-
println!("{:?}", s);
1427-
assert_eq!(s.shape(), &[4, 3]);
1428-
assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]));
1429-
}
1430-
14311373
#[test]
14321374
#[allow(clippy::cognitive_complexity)]
14331375
fn insert_axis() {

tests/reshape.rs

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
use ndarray::prelude::*;
2+
3+
use itertools::enumerate;
4+
5+
use ndarray::Order;
6+
7+
#[test]
8+
fn reshape() {
9+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
10+
let v = aview1(&data);
11+
let u = v.into_shape((3, 3));
12+
assert!(u.is_err());
13+
let u = v.into_shape((2, 2, 2));
14+
assert!(u.is_ok());
15+
let u = u.unwrap();
16+
assert_eq!(u.shape(), &[2, 2, 2]);
17+
let s = u.into_shape((4, 2)).unwrap();
18+
assert_eq!(s.shape(), &[4, 2]);
19+
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));
20+
}
21+
22+
#[test]
23+
#[should_panic(expected = "IncompatibleShape")]
24+
fn reshape_error1() {
25+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
26+
let v = aview1(&data);
27+
let _u = v.into_shape((2, 5)).unwrap();
28+
}
29+
30+
#[test]
31+
#[should_panic(expected = "IncompatibleLayout")]
32+
fn reshape_error2() {
33+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
34+
let v = aview1(&data);
35+
let mut u = v.into_shape((2, 2, 2)).unwrap();
36+
u.swap_axes(0, 1);
37+
let _s = u.into_shape((2, 4)).unwrap();
38+
}
39+
40+
#[test]
41+
fn reshape_f() {
42+
let mut u = Array::zeros((3, 4).f());
43+
for (i, elt) in enumerate(u.as_slice_memory_order_mut().unwrap()) {
44+
*elt = i as i32;
45+
}
46+
let v = u.view();
47+
println!("{:?}", v);
48+
49+
// noop ok
50+
let v2 = v.into_shape((3, 4));
51+
assert!(v2.is_ok());
52+
assert_eq!(v, v2.unwrap());
53+
54+
let u = v.into_shape((3, 2, 2));
55+
assert!(u.is_ok());
56+
let u = u.unwrap();
57+
println!("{:?}", u);
58+
assert_eq!(u.shape(), &[3, 2, 2]);
59+
let s = u.into_shape((4, 3)).unwrap();
60+
println!("{:?}", s);
61+
assert_eq!(s.shape(), &[4, 3]);
62+
assert_eq!(s, aview2(&[[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]));
63+
}
64+
65+
66+
#[test]
67+
fn to_shape_easy() {
68+
// 1D -> C -> C
69+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
70+
let v = aview1(&data);
71+
let u = v.to_shape(((3, 3), Order::RowMajor));
72+
assert!(u.is_err());
73+
74+
let u = v.to_shape(((2, 2, 2), Order::C));
75+
assert!(u.is_ok());
76+
77+
let u = u.unwrap();
78+
assert!(u.is_view());
79+
assert_eq!(u.shape(), &[2, 2, 2]);
80+
assert_eq!(u, array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
81+
82+
let s = u.to_shape((4, 2)).unwrap();
83+
assert_eq!(s.shape(), &[4, 2]);
84+
assert_eq!(s, aview2(&[[1, 2], [3, 4], [5, 6], [7, 8]]));
85+
86+
// 1D -> F -> F
87+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
88+
let v = aview1(&data);
89+
let u = v.to_shape(((3, 3), Order::ColumnMajor));
90+
assert!(u.is_err());
91+
92+
let u = v.to_shape(((2, 2, 2), Order::ColumnMajor));
93+
assert!(u.is_ok());
94+
95+
let u = u.unwrap();
96+
assert!(u.is_view());
97+
assert_eq!(u.shape(), &[2, 2, 2]);
98+
assert_eq!(u, array![[[1, 5], [3, 7]], [[2, 6], [4, 8]]]);
99+
100+
let s = u.to_shape(((4, 2), Order::ColumnMajor)).unwrap();
101+
assert_eq!(s.shape(), &[4, 2]);
102+
assert_eq!(s, array![[1, 5], [2, 6], [3, 7], [4, 8]]);
103+
}
104+
105+
#[test]
106+
fn to_shape_copy() {
107+
// 1D -> C -> F
108+
let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]);
109+
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();
110+
assert_eq!(u.shape(), &[4, 2]);
111+
assert_eq!(u, array![[1, 2], [3, 4], [5, 6], [7, 8]]);
112+
113+
let u = u.to_shape(((2, 4), Order::ColumnMajor)).unwrap();
114+
assert_eq!(u.shape(), &[2, 4]);
115+
assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]);
116+
117+
// 1D -> F -> C
118+
let v = ArrayView::from(&[1, 2, 3, 4, 5, 6, 7, 8]);
119+
let u = v.to_shape(((4, 2), Order::ColumnMajor)).unwrap();
120+
assert_eq!(u.shape(), &[4, 2]);
121+
assert_eq!(u, array![[1, 5], [2, 6], [3, 7], [4, 8]]);
122+
123+
let u = u.to_shape((2, 4)).unwrap();
124+
assert_eq!(u.shape(), &[2, 4]);
125+
assert_eq!(u, array![[1, 5, 2, 6], [3, 7, 4, 8]]);
126+
}
127+
128+
#[test]
129+
fn to_shape_add_axis() {
130+
// 1D -> C -> C
131+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
132+
let v = aview1(&data);
133+
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();
134+
135+
assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view());
136+
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned());
137+
}
138+
139+
140+
#[test]
141+
fn to_shape_copy_stride() {
142+
let v = array![[1, 2, 3, 4], [5, 6, 7, 8]];
143+
let vs = v.slice(s![.., ..3]);
144+
let lin1 = vs.to_shape(6).unwrap();
145+
assert_eq!(lin1, array![1, 2, 3, 5, 6, 7]);
146+
assert!(lin1.is_owned());
147+
148+
let lin2 = vs.to_shape((6, Order::ColumnMajor)).unwrap();
149+
assert_eq!(lin2, array![1, 5, 2, 6, 3, 7]);
150+
assert!(lin2.is_owned());
151+
}
152+
153+
#[test]
154+
#[should_panic(expected = "IncompatibleShape")]
155+
fn to_shape_error1() {
156+
let data = [1, 2, 3, 4, 5, 6, 7, 8];
157+
let v = aview1(&data);
158+
let _u = v.to_shape((2, 5)).unwrap();
159+
}

0 commit comments

Comments
 (0)