Skip to content

Commit a8d7f79

Browse files
committed
create AxisWindows struct and implement it
1 parent 9447328 commit a8d7f79

File tree

4 files changed

+123
-5
lines changed

4 files changed

+123
-5
lines changed

src/impl_methods.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::zip::{IntoNdProducer, Zip};
3434

3535
use crate::iter::{
3636
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
37-
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows,
37+
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, AxisWindows
3838
};
3939
use crate::slice::{MultiSliceArg, SliceArg};
4040
use crate::stacking::concatenate;
@@ -1500,7 +1500,7 @@ where
15001500
/// assert_eq!(window.shape(), &[4, 3, 2]);
15011501
/// }
15021502
/// ```
1503-
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> Windows<'_, A, D>
1503+
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D>
15041504
where
15051505
S: Data,
15061506
{
@@ -1520,7 +1520,7 @@ where
15201520
let mut size = self.raw_dim();
15211521
size[axis_index] = window_size;
15221522

1523-
Windows::new(self.view(), size)
1523+
AxisWindows::new(self.view(), axis, window_size)
15241524
}
15251525

15261526
// Return (length, stride) for diagonal

src/iterators/iter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ pub use crate::indexes::{Indices, IndicesIter};
1111
pub use crate::iterators::{
1212
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksIter,
1313
ExactChunksIterMut, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes,
14-
LanesIter, LanesIterMut, LanesMut, Windows,
14+
LanesIter, LanesIterMut, LanesMut, Windows, AxisWindows
1515
};

src/iterators/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use super::{Dimension, Ix, Ixs};
2626

2727
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2828
pub use self::lanes::{Lanes, LanesMut};
29-
pub use self::windows::Windows;
29+
pub use self::windows::{Windows, AxisWindows};
3030
pub use self::into_iter::IntoIter;
3131

3232
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};

src/iterators/windows.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,121 @@ impl_iterator! {
147147
}
148148
}
149149
}
150+
151+
/// Window producer and iterable
152+
///
153+
/// See [`.axis_windows()`](ArrayBase::axis_windows) for more
154+
/// information.
155+
pub struct AxisWindows<'a, A, D>{
156+
base: ArrayView<'a, A, D>,
157+
window_size: usize,
158+
axis_idx: usize,
159+
}
160+
161+
impl<'a, A, D: Dimension> AxisWindows<'a, A, D> {
162+
pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
163+
{
164+
let mut base = a;
165+
let len = base.raw_dim()[axis.index()];
166+
let indices = if len < window_size {
167+
Slice::new(0, Some(0), 1)
168+
} else {
169+
Slice::new(0, Some((len - window_size + 1) as isize), 1)
170+
};
171+
base.slice_axis_inplace(axis, indices);
172+
173+
AxisWindows {
174+
base,
175+
window_size,
176+
axis_idx: axis.index(),
177+
}
178+
}
179+
180+
fn window(&self) -> D{
181+
let mut window = self.base.raw_dim();
182+
window[self.axis_idx] = self.window_size;
183+
window
184+
}
185+
186+
fn strides_(&self) -> D{
187+
let mut strides = D::zeros(self.base.ndim());
188+
strides.slice_mut().fill(1);
189+
strides
190+
}
191+
}
192+
193+
194+
impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> {
195+
type Item = ArrayView<'a, A, D>;
196+
type Dim = Ix1;
197+
type Ptr = *mut A;
198+
type Stride = isize;
199+
200+
fn raw_dim(&self) -> Ix1 {
201+
Ix1(self.base.raw_dim()[self.axis_idx])
202+
}
203+
204+
fn layout(&self) -> Layout {
205+
self.base.layout()
206+
}
207+
208+
fn as_ptr(&self) -> *mut A {
209+
self.base.as_ptr() as *mut _
210+
}
211+
212+
fn contiguous_stride(&self) -> isize {
213+
self.base.contiguous_stride()
214+
}
215+
216+
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
217+
ArrayView::new_(ptr, self.window(),
218+
self.strides_())
219+
}
220+
221+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
222+
let mut d = D::zeros(self.base.ndim());
223+
d[self.axis_idx] = i[0];
224+
self.base.uget_ptr(&d)
225+
}
226+
227+
fn stride_of(&self, axis: Axis) -> isize {
228+
assert_eq!(axis, Axis(0));
229+
self.base.stride_of(Axis(self.axis_idx))
230+
}
231+
232+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
233+
assert_eq!(axis, Axis(0));
234+
let (a, b) = self.base.split_at(Axis(self.axis_idx), index);
235+
(AxisWindows {
236+
base: a,
237+
window_size: self.window_size,
238+
axis_idx: self.axis_idx,
239+
240+
},
241+
AxisWindows {
242+
base: b,
243+
window_size: self.window_size,
244+
axis_idx: self.axis_idx,
245+
})
246+
}
247+
248+
private_impl!{}
249+
}
250+
251+
impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D>
252+
where
253+
D: Dimension,
254+
A: 'a,
255+
{
256+
type Item = <Self::IntoIter as Iterator>::Item;
257+
type IntoIter = WindowsIter<'a, A, D>;
258+
fn into_iter(self) -> Self::IntoIter {
259+
let window = self.window();
260+
let strides = self.strides_();
261+
WindowsIter {
262+
iter: self.base.into_elements_base(),
263+
window,
264+
strides,
265+
}
266+
}
267+
}

0 commit comments

Comments
 (0)