Skip to content

Commit d5d7c72

Browse files
committed
Reduce code duplication
Specialize the `Windows` struct with a `variant` field to replace the `AxisWindows` struct
1 parent d3460e8 commit d5d7c72

File tree

4 files changed

+58
-96
lines changed

4 files changed

+58
-96
lines changed

src/impl_methods.rs

+12-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ use rawpointer::PointerExt;
1414

1515
use crate::imp_prelude::*;
1616

17+
use crate::iterators::AxisWindow;
18+
use crate::iterators::GeneralWindow;
1719
use crate::{arraytraits, DimMax};
1820
use crate::argument_traits::AssignElem;
1921
use crate::dimension;
@@ -34,7 +36,7 @@ use crate::zip::{IntoNdProducer, Zip};
3436

3537
use crate::iter::{
3638
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
37-
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows, AxisWindows
39+
IndexedIter, IndexedIterMut, Iter, IterMut, Lanes, LanesMut, Windows
3840
};
3941
use crate::slice::{MultiSliceArg, SliceArg};
4042
use crate::stacking::concatenate;
@@ -1419,12 +1421,12 @@ where
14191421
/// that fit into the array's shape.
14201422
///
14211423
/// This is essentially equivalent to [`.windows_with_stride()`] with unit stride.
1422-
pub fn windows<E>(&self, window_size: E) -> Windows<'_, A, D>
1424+
pub fn windows<E>(&self, window_size: E) -> Windows<'_, A, D, GeneralWindow>
14231425
where
14241426
E: IntoDimension<Dim = D>,
14251427
S: Data,
14261428
{
1427-
Windows::new(self.view(), window_size)
1429+
Windows::new(self.view(), window_size, GeneralWindow)
14281430
}
14291431

14301432
/// Return a window producer and iterable.
@@ -1471,12 +1473,12 @@ where
14711473
/// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃
14721474
/// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┴─────┺━━━━━┻━━━━━┛
14731475
/// ```
1474-
pub fn windows_with_stride<E>(&self, window_size: E, stride: E) -> Windows<'_, A, D>
1476+
pub fn windows_with_stride<E>(&self, window_size: E, stride: E) -> Windows<'_, A, D, GeneralWindow>
14751477
where
14761478
E: IntoDimension<Dim = D>,
14771479
S: Data,
14781480
{
1479-
Windows::new_with_stride(self.view(), window_size, stride)
1481+
Windows::new_with_stride(self.view(), window_size, stride, GeneralWindow)
14801482
}
14811483

14821484
/// Returns a producer which traverses over all windows of a given length along an axis.
@@ -1500,7 +1502,7 @@ where
15001502
/// assert_eq!(window.shape(), &[4, 3, 2]);
15011503
/// }
15021504
/// ```
1503-
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D>
1505+
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> Windows<'_, A, D, AxisWindow>
15041506
where
15051507
S: Data,
15061508
{
@@ -1517,7 +1519,10 @@ where
15171519
self.shape()
15181520
);
15191521

1520-
AxisWindows::new(self.view(), axis, window_size)
1522+
let mut size = self.raw_dim();
1523+
size[axis_index] = window_size;
1524+
1525+
Windows::new(self.view(), size, AxisWindow{index: axis_index})
15211526
}
15221527

15231528
// Return (length, stride) for diagonal

src/iterators/iter.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ pub use crate::indexes::{Indices, IndicesIter};
1111
pub use crate::iterators::{
1212
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksIter,
1313
ExactChunksIterMut, ExactChunksMut, IndexedIter, IndexedIterMut, Iter, IterMut, Lanes,
14-
LanesIter, LanesIterMut, LanesMut, Windows, AxisWindows
14+
LanesIter, LanesIterMut, LanesMut, Windows
1515
};

src/iterators/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ use super::{Dimension, Ix, Ixs};
2626

2727
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
2828
pub use self::lanes::{Lanes, LanesMut};
29-
pub use self::windows::{Windows, AxisWindows};
29+
pub use self::windows::Windows;
30+
pub(crate) use self::windows::{GeneralWindow, AxisWindow};
3031
pub use self::into_iter::IntoIter;
3132

3233
use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};

src/iterators/windows.rs

+43-87
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,24 @@ use crate::Layout;
55
use crate::NdProducer;
66
use crate::Slice;
77

8+
#[derive(Clone)]
9+
pub struct GeneralWindow;
10+
#[derive(Clone)]
11+
pub struct AxisWindow{pub(crate) index: usize}
12+
813
/// Window producer and iterable
914
///
1015
/// See [`.windows()`](ArrayBase::windows) for more
1116
/// information.
12-
pub struct Windows<'a, A, D> {
17+
pub struct Windows<'a, A, D, V> {
1318
base: ArrayView<'a, A, D>,
1419
window: D,
1520
strides: D,
21+
variant: V,
1622
}
1723

18-
impl<'a, A, D: Dimension> Windows<'a, A, D> {
19-
pub(crate) fn new<E>(a: ArrayView<'a, A, D>, window_size: E) -> Self
24+
impl<'a, A, D: Dimension, V> Windows<'a, A, D, V> {
25+
pub(crate) fn new<E>(a: ArrayView<'a, A, D>, window_size: E, variant: V) -> Self
2026
where
2127
E: IntoDimension<Dim = D>,
2228
{
@@ -26,10 +32,15 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
2632
let mut unit_stride = D::zeros(ndim);
2733
unit_stride.slice_mut().fill(1);
2834

29-
Windows::new_with_stride(a, window, unit_stride)
35+
Windows::new_with_stride(a, window, unit_stride, variant)
3036
}
3137

32-
pub(crate) fn new_with_stride<E>(a: ArrayView<'a, A, D>, window_size: E, axis_strides: E) -> Self
38+
pub(crate) fn new_with_stride<E>(
39+
a: ArrayView<'a, A, D>,
40+
window_size: E,
41+
axis_strides: E,
42+
variant: V,
43+
) -> Self
3344
where
3445
E: IntoDimension<Dim = D>,
3546
{
@@ -77,6 +88,7 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> {
7788
base,
7889
window,
7990
strides: window_strides,
91+
variant,
8092
}
8193
}
8294
}
@@ -88,8 +100,9 @@ impl_ndproducer! {
88100
base,
89101
window,
90102
strides,
103+
variant,
91104
}
92-
Windows<'a, A, D> {
105+
Windows<'a, A, D, GeneralWindow> {
93106
type Item = ArrayView<'a, A, D>;
94107
type Dim = D;
95108

@@ -100,7 +113,7 @@ impl_ndproducer! {
100113
}
101114
}
102115

103-
impl<'a, A, D> IntoIterator for Windows<'a, A, D>
116+
impl<'a, A, D, V> IntoIterator for Windows<'a, A, D, V>
104117
where
105118
D: Dimension,
106119
A: 'a,
@@ -148,55 +161,14 @@ impl_iterator! {
148161
}
149162
}
150163

151-
/// Window producer and iterable
152-
///
153-
/// See [`.axis_windows()`](ArrayBase::axis_windows) for more
154-
/// information.
155-
pub struct AxisWindows<'a, A, D>{
156-
base: ArrayView<'a, A, D>,
157-
axis_idx: usize,
158-
window: D,
159-
strides: D,
160-
}
161-
162-
impl<'a, A, D: Dimension> AxisWindows<'a, A, D> {
163-
pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
164-
{
165-
let strides = a.strides.clone();
166-
let mut base = a;
167-
let axis_idx = axis.index();
168-
let mut window = base.raw_dim();
169-
window[axis_idx] = window_size;
170-
171-
base.slice_each_axis_inplace(|ax_desc| {
172-
let len = ax_desc.len;
173-
let wsz = window[ax_desc.axis.index()];
174-
175-
if len < wsz {
176-
Slice::new(0, Some(0), 1)
177-
} else {
178-
Slice::new(0, Some((len - wsz + 1) as isize), 1)
179-
}
180-
});
181-
182-
AxisWindows {
183-
base,
184-
axis_idx,
185-
window,
186-
strides,
187-
}
188-
}
189-
}
190-
191-
192-
impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> {
164+
impl<'a, A, D: Dimension> NdProducer for Windows<'a, A, D, AxisWindow> {
193165
type Item = ArrayView<'a, A, D>;
194166
type Dim = Ix1;
195167
type Ptr = *mut A;
196168
type Stride = isize;
197169

198170
fn raw_dim(&self) -> Ix1 {
199-
Ix1(self.base.raw_dim()[self.axis_idx])
171+
Ix1(self.base.raw_dim()[self.variant.index])
200172
}
201173

202174
fn layout(&self) -> Layout {
@@ -212,54 +184,38 @@ impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> {
212184
}
213185

214186
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
215-
ArrayView::new_(ptr, self.window.clone(),
216-
self.strides.clone())
187+
ArrayView::new_(ptr, self.window.clone(), self.strides.clone())
217188
}
218189

219190
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
220191
let mut d = D::zeros(self.base.ndim());
221-
d[self.axis_idx] = i[0];
192+
d[self.variant.index] = i[0];
222193
self.base.uget_ptr(&d)
223194
}
224195

225196
fn stride_of(&self, axis: Axis) -> isize {
226197
assert_eq!(axis, Axis(0));
227-
self.base.stride_of(Axis(self.axis_idx))
198+
self.base.stride_of(Axis(self.variant.index))
228199
}
229200

230201
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
231202
assert_eq!(axis, Axis(0));
232-
let (a, b) = self.base.split_at(Axis(self.axis_idx), index);
233-
(AxisWindows {
234-
base: a,
235-
axis_idx: self.axis_idx,
236-
window: self.window.clone(),
237-
strides: self.strides.clone()
238-
239-
},
240-
AxisWindows {
241-
base: b,
242-
axis_idx: self.axis_idx,
243-
window: self.window,
244-
strides: self.strides,
245-
})
246-
}
247-
248-
private_impl!{}
249-
}
250-
251-
impl<'a, A, D> IntoIterator for AxisWindows<'a, A, D>
252-
where
253-
D: Dimension,
254-
A: 'a,
255-
{
256-
type Item = <Self::IntoIter as Iterator>::Item;
257-
type IntoIter = WindowsIter<'a, A, D>;
258-
fn into_iter(self) -> Self::IntoIter {
259-
WindowsIter {
260-
iter: self.base.into_elements_base(),
261-
window: self.window,
262-
strides: self.strides,
263-
}
264-
}
203+
let (a, b) = self.base.split_at(Axis(self.variant.index), index);
204+
(
205+
Windows {
206+
base: a,
207+
window: self.window.clone(),
208+
strides: self.strides.clone(),
209+
variant: self.variant.clone(),
210+
},
211+
Windows {
212+
base: b,
213+
window: self.window,
214+
strides: self.strides,
215+
variant: self.variant,
216+
},
217+
)
218+
}
219+
220+
private_impl! {}
265221
}

0 commit comments

Comments
 (0)