Skip to content

Commit e896df1

Browse files
committed
shape: Use reshape_dim function in .to_shape()
1 parent e6fa50e commit e896df1

File tree

2 files changed

+48
-17
lines changed

2 files changed

+48
-17
lines changed

src/impl_methods.rs

+28-16
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::dimension::{
2323
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
2424
};
2525
use crate::dimension::broadcast::co_broadcast;
26+
use crate::dimension::reshape_dim;
2627
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2728
use crate::math_cell::MathCell;
2829
use crate::itertools::zip;
@@ -1641,26 +1642,37 @@ where
16411642
A: Clone,
16421643
S: Data,
16431644
{
1644-
if size_of_shape_checked(&shape) != Ok(self.dim.size()) {
1645+
let len = self.dim.size();
1646+
if size_of_shape_checked(&shape) != Ok(len) {
16451647
return Err(error::incompatible_shapes(&self.dim, &shape));
16461648
}
1647-
let layout = self.layout_impl();
16481649

1649-
unsafe {
1650-
if layout.is(Layout::CORDER) && order == Order::RowMajor {
1651-
let strides = shape.default_strides();
1652-
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1653-
} else if layout.is(Layout::FORDER) && order == Order::ColumnMajor {
1654-
let strides = shape.fortran_strides();
1655-
Ok(CowArray::from(ArrayView::new(self.ptr, shape, strides)))
1656-
} else {
1657-
let (shape, view) = match order {
1658-
Order::RowMajor => (shape.set_f(false), self.view()),
1659-
Order::ColumnMajor => (shape.set_f(true), self.t()),
1660-
};
1661-
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
1662-
shape, view.into_iter(), A::clone)))
1650+
// Create a view if the length is 0, safe because the array and new shape is empty.
1651+
if len == 0 {
1652+
unsafe {
1653+
return Ok(CowArray::from(ArrayView::from_shape_ptr(shape, self.as_ptr())));
1654+
}
1655+
}
1656+
1657+
// Try to reshape the array as a view into the existing data
1658+
match reshape_dim(&self.dim, &self.strides, &shape, order) {
1659+
Ok(to_strides) => unsafe {
1660+
return Ok(CowArray::from(ArrayView::new(self.ptr, shape, to_strides)));
16631661
}
1662+
Err(err) if err.kind() == ErrorKind::IncompatibleShape => {
1663+
return Err(error::incompatible_shapes(&self.dim, &shape));
1664+
}
1665+
_otherwise => { }
1666+
}
1667+
1668+
// otherwise create a new array and copy the elements
1669+
unsafe {
1670+
let (shape, view) = match order {
1671+
Order::RowMajor => (shape.set_f(false), self.view()),
1672+
Order::ColumnMajor => (shape.set_f(true), self.t()),
1673+
};
1674+
Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(
1675+
shape, view.into_iter(), A::clone)))
16641676
}
16651677
}
16661678

tests/reshape.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ fn to_shape_add_axis() {
133133
let u = v.to_shape(((4, 2), Order::RowMajor)).unwrap();
134134

135135
assert!(u.to_shape(((1, 4, 2), Order::RowMajor)).unwrap().is_view());
136-
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_owned());
136+
assert!(u.to_shape(((1, 4, 2), Order::ColumnMajor)).unwrap().is_view());
137137
}
138138

139139

@@ -150,10 +150,29 @@ fn to_shape_copy_stride() {
150150
assert!(lin2.is_owned());
151151
}
152152

153+
154+
#[test]
155+
fn to_shape_zero_len() {
156+
let v = array![[1, 2, 3, 4], [5, 6, 7, 8]];
157+
let vs = v.slice(s![.., ..0]);
158+
let lin1 = vs.to_shape(0).unwrap();
159+
assert_eq!(lin1, array![]);
160+
assert!(lin1.is_view());
161+
}
162+
153163
#[test]
154164
#[should_panic(expected = "IncompatibleShape")]
155165
fn to_shape_error1() {
156166
let data = [1, 2, 3, 4, 5, 6, 7, 8];
157167
let v = aview1(&data);
158168
let _u = v.to_shape((2, 5)).unwrap();
159169
}
170+
171+
#[test]
172+
#[should_panic(expected = "IncompatibleShape")]
173+
fn to_shape_error2() {
174+
// overflow
175+
let data = [3, 4, 5, 6, 7, 8];
176+
let v = aview1(&data);
177+
let _u = v.to_shape((2, usize::MAX)).unwrap();
178+
}

0 commit comments

Comments
 (0)