Skip to content

Commit b70085c

Browse files
committed
shape: Use reshape_dim function in .to_shape()
1 parent f96977f commit b70085c

File tree

2 files changed

+48
-18
lines changed

2 files changed

+48
-18
lines changed

src/impl_methods.rs

+28-17
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ use crate::dimension::{
2323
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2424
};
2525
use crate::dimension::broadcast::co_broadcast;
26+
use crate::dimension::reshape_dim;
2627
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2728
use crate::math_cell::MathCell;
2829
use crate::itertools::zip;
2930
use crate::AxisDescription;
30-
use crate::Layout;
3131
use crate::order::Order;
3232
use crate::shape_builder::ShapeArg;
3333
use crate::zip::{IntoNdProducer, Zip};
@@ -1641,27 +1641,38 @@ where
16411641
A: Clone,
16421642
S: Data,
16431643
{
1644-
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
1644+
let len = self.dim.size();
1645+
if size_of_shape_checked(&shape) != Ok(len) {
16451646
return Err(error::incompatible_shapes(&self.dim, &shape));
16461647
}
1647-
let layout = self.layout_impl();
16481648

1649-
unsafe {
1650-
if layout.is(Layout::CORDER) && order == Order::RowMajor {
1651-
let strides = shape.default_strides();
1652-
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1653-
} else if layout.is(Layout::FORDER) && order == Order::ColumnMajor {
1654-
let strides = shape.fortran_strides();
1655-
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1656-
} else {
1657-
let (shape, view) = match order {
1658-
Order::RowMajor => (shape.set_f(false), self.view()),
1659-
Order::ColumnMajor => (shape.set_f(true), self.t()),
1660-
};
1661-
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
1662-
shape, view.into_iter(), A::clone)))
1649+
// Create a view if the length is 0, safe because the array and new shape is empty.
1650+
if len == 0 {
1651+
unsafe {
1652+
return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr())));
16631653
}
16641654
}
1655+
1656+
// Try to reshape the array as a view into the existing data
1657+
match reshape_dim(&self.dim, &self.strides, &shape, order) {
1658+
Ok(to_strides) => unsafe {
1659+
return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides)));
1660+
}
1661+
Err(err) if err.kind() == ErrorKind::IncompatibleShape => {
1662+
return Err(error::incompatible_shapes(&self.dim, &shape));
1663+
}
1664+
_otherwise => { }
1665+
}
1666+
1667+
// otherwise create a new array and copy the elements
1668+
unsafe {
1669+
let (shape, view) = match order {
1670+
Order::RowMajor => (shape.set_f(false), self.view()),
1671+
Order::ColumnMajor => (shape.set_f(true), self.t()),
1672+
};
1673+
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
1674+
shape, view.into_iter(), A::clone)))
1675+
}
16651676
}
16661677

16671678
/// Transform the array into `shape`; any shape with the same number of

tests/reshape.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ fn to_shape_add_axis() {
133133
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();
134134

135135
assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view());
136-
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned());
136+
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view());
137137
}
138138

139139

@@ -150,10 +150,29 @@ fn to_shape_copy_stride() {
150150
assert!(lin2.is_owned());
151151
}
152152

153+
154+
#[test]
155+
fn to_shape_zero_len() {
156+
let v = array![[1, 2, 3, 4], [5, 6, 7, 8]];
157+
let vs = v.slice(s![.., ..0]);
158+
let lin1 = vs.to_shape(0).unwrap();
159+
assert_eq!(lin1, array![]);
160+
assert!(lin1.is_view());
161+
}
162+
153163
#[test]
154164
#[should_panic(expected = "IncompatibleShape")]
155165
fn to_shape_error1() {
156166
let data = [1, 2, 3, 4, 5, 6, 7, 8];
157167
let v = aview1(&data);
158168
let _u = v.to_shape((2, 5)).unwrap();
159169
}
170+
171+
#[test]
172+
#[should_panic(expected = "IncompatibleShape")]
173+
fn to_shape_error2() {
174+
// overflow
175+
let data = [3, 4, 5, 6, 7, 8];
176+
let v = aview1(&data);
177+
let _u = v.to_shape((2, usize::MAX)).unwrap();
178+
}

0 commit comments

Comments
 (0)