Skip to content

Commit f163e14

Browse files
authored
Merge pull request #1305 from jonasBoss/axis_windows_dimension
Change `NdProducer::Dim` of `axis_windows()` to `Ix1`
2 parents 45009ff + 21fb817 commit f163e14

File tree

5 files changed

+185
-41
lines changed

5 files changed

+185
-41
lines changed

src/impl_methods.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ use crate::iter::{
4646
AxisChunksIterMut,
4747
AxisIter,
4848
AxisIterMut,
49+
AxisWindows,
4950
ExactChunks,
5051
ExactChunksMut,
5152
IndexedIter,
@@ -1521,7 +1522,7 @@ where
15211522
/// assert_eq!(window.shape(), &[4, 3, 2]);
15221523
/// }
15231524
/// ```
1524-
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> Windows<'_, A, D>
1525+
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D>
15251526
where S: Data
15261527
{
15271528
let axis_index = axis.index();
@@ -1537,10 +1538,7 @@ where
15371538
self.shape()
15381539
);
15391540

1540-
let mut size = self.raw_dim();
1541-
size[axis_index] = window_size;
1542-
1543-
Windows::new(self.view(), size)
1541+
AxisWindows::new(self.view(), axis, window_size)
15441542
}
15451543

15461544
// Return (length, stride) for diagonal

src/iterators/iter.rs

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub use crate::iterators::{
1313
AxisChunksIterMut,
1414
AxisIter,
1515
AxisIterMut,
16+
AxisWindows,
1617
ExactChunks,
1718
ExactChunksIter,
1819
ExactChunksIterMut,

src/iterators/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use super::{Dimension, Ix, Ixs};
2828
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2929
pub use self::into_iter::IntoIter;
3030
pub use self::lanes::{Lanes, LanesMut};
31-
pub use self::windows::Windows;
31+
pub use self::windows::{AxisWindows, Windows};
3232

3333
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
3434

src/iterators/windows.rs

+164-35
Original file line numberDiff line numberDiff line change
@@ -41,41 +41,7 @@ impl<'a, A, D: Dimension> Windows<'a, A, D>
4141
let strides = axis_strides.into_dimension();
4242
let window_strides = a.strides.clone();
4343

44-
ndassert!(
45-
a.ndim() == window.ndim(),
46-
concat!(
47-
"Window dimension {} does not match array dimension {} ",
48-
"(with array of shape {:?})"
49-
),
50-
window.ndim(),
51-
a.ndim(),
52-
a.shape()
53-
);
54-
55-
ndassert!(
56-
a.ndim() == strides.ndim(),
57-
concat!(
58-
"Stride dimension {} does not match array dimension {} ",
59-
"(with array of shape {:?})"
60-
),
61-
strides.ndim(),
62-
a.ndim(),
63-
a.shape()
64-
);
65-
66-
let mut base = a;
67-
base.slice_each_axis_inplace(|ax_desc| {
68-
let len = ax_desc.len;
69-
let wsz = window[ax_desc.axis.index()];
70-
let stride = strides[ax_desc.axis.index()];
71-
72-
if len < wsz {
73-
Slice::new(0, Some(0), 1)
74-
} else {
75-
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
76-
}
77-
});
78-
44+
let base = build_base(a, window.clone(), strides);
7945
Windows {
8046
base: base.into_raw_view(),
8147
life: PhantomData,
@@ -160,3 +126,166 @@ impl_iterator! {
160126

161127
send_sync_read_only!(Windows);
162128
send_sync_read_only!(WindowsIter);
129+
130+
/// Window producer and iterable
131+
///
132+
/// See [`.axis_windows()`](ArrayBase::axis_windows) for more
133+
/// information.
134+
pub struct AxisWindows<'a, A, D>
135+
{
136+
base: ArrayView<'a, A, D>,
137+
axis_idx: usize,
138+
window: D,
139+
strides: D,
140+
}
141+
142+
impl<'a, A, D: Dimension> AxisWindows<'a, A, D>
143+
{
144+
pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
145+
{
146+
let window_strides = a.strides.clone();
147+
let axis_idx = axis.index();
148+
149+
let mut window = a.raw_dim();
150+
window[axis_idx] = window_size;
151+
152+
let ndim = window.ndim();
153+
let mut unit_stride = D::zeros(ndim);
154+
unit_stride.slice_mut().fill(1);
155+
156+
let base = build_base(a, window.clone(), unit_stride);
157+
AxisWindows {
158+
base,
159+
axis_idx,
160+
window,
161+
strides: window_strides,
162+
}
163+
}
164+
}
165+
166+
impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D>
167+
{
168+
type Item = ArrayView<'a, A, D>;
169+
type Dim = Ix1;
170+
type Ptr = *mut A;
171+
type Stride = isize;
172+
173+
fn raw_dim(&self) -> Ix1
174+
{
175+
Ix1(self.base.raw_dim()[self.axis_idx])
176+
}
177+
178+
fn layout(&self) -> Layout
179+
{
180+
self.base.layout()
181+
}
182+
183+
fn as_ptr(&self) -> *mut A
184+
{
185+
self.base.as_ptr() as *mut _
186+
}
187+
188+
fn contiguous_stride(&self) -> isize
189+
{
190+
self.base.contiguous_stride()
191+
}
192+
193+
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item
194+
{
195+
ArrayView::new_(ptr, self.window.clone(), self.strides.clone())
196+
}
197+
198+
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A
199+
{
200+
let mut d = D::zeros(self.base.ndim());
201+
d[self.axis_idx] = i[0];
202+
self.base.uget_ptr(&d)
203+
}
204+
205+
fn stride_of(&self, axis: Axis) -> isize
206+
{
207+
assert_eq!(axis, Axis(0));
208+
self.base.stride_of(Axis(self.axis_idx))
209+
}
210+
211+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
212+
{
213+
assert_eq!(axis, Axis(0));
214+
let (a, b) = self.base.split_at(Axis(self.axis_idx), index);
215+
(
216+
AxisWindows {
217+
base: a,
218+
axis_idx: self.axis_idx,
219+
window: self.window.clone(),
220+
strides: self.strides.clone(),
221+
},
222+
AxisWindows {
223+
base: b,
224+
axis_idx: self.axis_idx,
225+
window: self.window,
226+
strides: self.strides,
227+
},
228+
)
229+
}
230+
231+
private_impl!{}
232+
}
233+
234+
impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D>
235+
where
236+
D: Dimension,
237+
A: 'a,
238+
{
239+
type Item = <Self::IntoIter as Iterator>::Item;
240+
type IntoIter = WindowsIter<'a, A, D>;
241+
fn into_iter(self) -> Self::IntoIter
242+
{
243+
WindowsIter {
244+
iter: self.base.into_base_iter(),
245+
life: PhantomData,
246+
window: self.window,
247+
strides: self.strides,
248+
}
249+
}
250+
}
251+
252+
/// build the base array of the `Windows` and `AxisWindows` structs
253+
fn build_base<A, D>(a: ArrayView<A, D>, window: D, strides: D) -> ArrayView<A, D>
254+
where D: Dimension
255+
{
256+
ndassert!(
257+
a.ndim() == window.ndim(),
258+
concat!(
259+
"Window dimension {} does not match array dimension {} ",
260+
"(with array of shape {:?})"
261+
),
262+
window.ndim(),
263+
a.ndim(),
264+
a.shape()
265+
);
266+
267+
ndassert!(
268+
a.ndim() == strides.ndim(),
269+
concat!(
270+
"Stride dimension {} does not match array dimension {} ",
271+
"(with array of shape {:?})"
272+
),
273+
strides.ndim(),
274+
a.ndim(),
275+
a.shape()
276+
);
277+
278+
let mut base = a;
279+
base.slice_each_axis_inplace(|ax_desc| {
280+
let len = ax_desc.len;
281+
let wsz = window[ax_desc.axis.index()];
282+
let stride = strides[ax_desc.axis.index()];
283+
284+
if len < wsz {
285+
Slice::new(0, Some(0), 1)
286+
} else {
287+
Slice::new(0, Some((len - wsz + 1) as isize), stride as isize)
288+
}
289+
});
290+
base
291+
}

tests/windows.rs

+16
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,22 @@ fn test_axis_windows_3d()
278278
]);
279279
}
280280

281+
#[test]
282+
fn tests_axis_windows_3d_zips_with_1d()
283+
{
284+
let a = Array::from_iter(0..27)
285+
.into_shape_with_order((3, 3, 3))
286+
.unwrap();
287+
let mut b = Array::zeros(2);
288+
289+
Zip::from(b.view_mut())
290+
.and(a.axis_windows(Axis(1), 2))
291+
.for_each(|b, a| {
292+
*b = a.sum();
293+
});
294+
assert_eq!(b,arr1(&[207, 261]));
295+
}
296+
281297
#[test]
282298
fn test_window_neg_stride()
283299
{

0 commit comments

Comments
 (0)