Skip to content

Commit 4c26e16

Browse files
jonasBossbluss
authored andcommitted
Add AxisWindows
Move core logic of creating the Windows logic into seperate fn. fix some bugs, make the tests pass Reduce code duplication Specialize the `Windows` struct with a `variant` field to replace the `AxisWindows` struct
1 parent ec0ffa6 commit 4c26e16

File tree

4 files changed

+169
-41
lines changed

4 files changed

+169
-41
lines changed

src/impl_methods.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use crate::iter::{
4646
AxisChunksIterMut,
4747
AxisIter,
4848
AxisIterMut,
49+
AxisWindows,
4950
ExactChunks,
5051
ExactChunksMut,
5152
IndexedIter,
@@ -1521,7 +1522,7 @@ where
15211522
/// assert_eq!(window.shape(), &[4, 3, 2]);
15221523
/// }
15231524
/// ```
1524-
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> Windows<'_, A, D>
1525+
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D>
15251526
where S: Data
15261527
{
15271528
let axis_index = axis.index();
@@ -1537,10 +1538,7 @@ where
15371538
self.shape()
15381539
);
15391540

1540-
let mut size = self.raw_dim();
1541-
size[axis_index] = window_size;
1542-
1543-
Windows::new(self.view(), size)
1541+
AxisWindows::new(self.view(), axis, window_size)
15441542
}
15451543

15461544
// Return (length, stride) for diagonal

src/iterators/iter.rs

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub use crate::iterators::{
1313
AxisChunksIterMut,
1414
AxisIter,
1515
AxisIterMut,
16+
AxisWindows,
1617
ExactChunks,
1718
ExactChunksIter,
1819
ExactChunksIterMut,

src/iterators/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use super::{Dimension, Ix, Ixs};
2828
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2929
pub use self::into_iter::IntoIter;
3030
pub use self::lanes::{Lanes, LanesMut};
31-
pub use self::windows::Windows;
31+
pub use self::windows::{AxisWindows, Windows};
3232

3333
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
3434

src/iterators/windows.rs

+164-35
Original file line numberDiff line numberDiff line change
@@ -41,41 +41,7 @@ impl<'a, A, D: Dimension> Windows<'a, A, D>
4141
let strides = axis_strides.into_dimension();
4242
let window_strides = a.strides.clone();
4343

44-
ndassert!(
45-
a.ndim() == window.ndim(),
46-
concat!(
47-
"Window dimension {} does not match array dimension {} ",
48-
"(with array of shape {:?})"
49-
),
50-
window.ndim(),
51-
a.ndim(),
52-
a.shape()
53-
);
54-
55-
ndassert!(
56-
a.ndim() == strides.ndim(),
57-
concat!(
58-
"Stride dimension {} does not match array dimension {} ",
59-
"(with array of shape {:?})"
60-
),
61-
strides.ndim(),
62-
a.ndim(),
63-
a.shape()
64-
);
65-
66-
let mut base = a;
67-
base.slice_each_axis_inplace(|ax_desc| {
68-
let len = ax_desc.len;
69-
let wsz = window[ax_desc.axis.index()];
70-
let stride = strides[ax_desc.axis.index()];
71-
72-
if len < wsz {
73-
Slice::new(0, Some(0), 1)
74-
} else {
75-
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
76-
}
77-
});
78-
44+
let base = build_base(a, window.clone(), strides);
7945
Windows {
8046
base: base.into_raw_view(),
8147
life: PhantomData,
@@ -160,3 +126,166 @@ impl_iterator! {
160126

161127
send_sync_read_only!(Windows);
162128
send_sync_read_only!(WindowsIter);
129+
130+
/// Window producer and iterable
131+
///
132+
/// See [`.axis_windows()`](ArrayBase::axis_windows) for more
133+
/// information.
134+
pub struct AxisWindows<'a, A, D>
135+
{
136+
base: ArrayView<'a, A, D>,
137+
axis_idx: usize,
138+
window: D,
139+
strides: D,
140+
}
141+
142+
impl<'a, A, D: Dimension> AxisWindows<'a, A, D>
143+
{
144+
pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
145+
{
146+
let window_strides = a.strides.clone();
147+
let axis_idx = axis.index();
148+
149+
let mut window = a.raw_dim();
150+
window[axis_idx] = window_size;
151+
152+
let ndim = window.ndim();
153+
let mut unit_stride = D::zeros(ndim);
154+
unit_stride.slice_mut().fill(1);
155+
156+
let base = build_base(a, window.clone(), unit_stride);
157+
AxisWindows {
158+
base,
159+
axis_idx,
160+
window,
161+
strides: window_strides,
162+
}
163+
}
164+
}
165+
166+
impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D>
167+
{
168+
type Item = ArrayView<'a, A, D>;
169+
type Dim = Ix1;
170+
type Ptr = *mut A;
171+
type Stride = isize;
172+
173+
fn raw_dim(&self) -> Ix1
174+
{
175+
Ix1(self.base.raw_dim()[self.axis_idx])
176+
}
177+
178+
fn layout(&self) -> Layout
179+
{
180+
self.base.layout()
181+
}
182+
183+
fn as_ptr(&self) -> *mut A
184+
{
185+
self.base.as_ptr() as *mut _
186+
}
187+
188+
fn contiguous_stride(&self) -> isize
189+
{
190+
self.base.contiguous_stride()
191+
}
192+
193+
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item
194+
{
195+
ArrayView::new_(ptr, self.window.clone(), self.strides.clone())
196+
}
197+
198+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A
199+
{
200+
let mut d = D::zeros(self.base.ndim());
201+
d[self.axis_idx] = i[0];
202+
self.base.uget_ptr(&d)
203+
}
204+
205+
fn stride_of(&self, axis: Axis) -> isize
206+
{
207+
assert_eq!(axis, Axis(0));
208+
self.base.stride_of(Axis(self.axis_idx))
209+
}
210+
211+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
212+
{
213+
assert_eq!(axis, Axis(0));
214+
let (a, b) = self.base.split_at(Axis(self.axis_idx), index);
215+
(
216+
AxisWindows {
217+
base: a,
218+
axis_idx: self.axis_idx,
219+
window: self.window.clone(),
220+
strides: self.strides.clone(),
221+
},
222+
AxisWindows {
223+
base: b,
224+
axis_idx: self.axis_idx,
225+
window: self.window,
226+
strides: self.strides,
227+
},
228+
)
229+
}
230+
231+
private_impl!{}
232+
}
233+
234+
impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D>
235+
where
236+
D: Dimension,
237+
A: 'a,
238+
{
239+
type Item = <Self::IntoIter as Iterator>::Item;
240+
type IntoIter = WindowsIter<'a, A, D>;
241+
fn into_iter(self) -> Self::IntoIter
242+
{
243+
WindowsIter {
244+
iter: self.base.into_base_iter(),
245+
life: PhantomData,
246+
window: self.window,
247+
strides: self.strides,
248+
}
249+
}
250+
}
251+
252+
/// build the base array of the `Windows` and `AxisWindows` structs
253+
fn build_base<'a, A, D>(a: ArrayView<'a, A, D>, window: D, strides: D) -> ArrayView<'a, A, D>
254+
where D: Dimension
255+
{
256+
ndassert!(
257+
a.ndim() == window.ndim(),
258+
concat!(
259+
"Window dimension {} does not match array dimension {} ",
260+
"(with array of shape {:?})"
261+
),
262+
window.ndim(),
263+
a.ndim(),
264+
a.shape()
265+
);
266+
267+
ndassert!(
268+
a.ndim() == strides.ndim(),
269+
concat!(
270+
"Stride dimension {} does not match array dimension {} ",
271+
"(with array of shape {:?})"
272+
),
273+
strides.ndim(),
274+
a.ndim(),
275+
a.shape()
276+
);
277+
278+
let mut base = a;
279+
base.slice_each_axis_inplace(|ax_desc| {
280+
let len = ax_desc.len;
281+
let wsz = window[ax_desc.axis.index()];
282+
let stride = strides[ax_desc.axis.index()];
283+
284+
if len < wsz {
285+
Slice::new(0, Some(0), 1)
286+
} else {
287+
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
288+
}
289+
});
290+
base
291+
}

0 commit comments

Comments
 (0)