Skip to content

Commit 6b0942f

Browse files
committed
fix some bugs, make the tests pass
1 parent a8d7f79 commit 6b0942f

File tree

2 files changed

+29
-34
lines changed

2 files changed

+29
-34
lines changed

src/impl_methods.rs

-3
Original file line numberDiff line numberDiff line change
@@ -1517,9 +1517,6 @@ where
15171517
self.shape()
15181518
);
15191519

1520-
let mut size = self.raw_dim();
1521-
size[axis_index] = window_size;
1522-
15231520
AxisWindows::new(self.view(), axis, window_size)
15241521
}
15251522

src/iterators/windows.rs

+29-31
Original file line numberDiff line numberDiff line change
@@ -154,40 +154,38 @@ impl_iterator! {
154154
/// information.
155155
pub struct AxisWindows<'a, A, D>{
156156
base: ArrayView<'a, A, D>,
157-
window_size: usize,
158157
axis_idx: usize,
158+
window: D,
159+
strides: D,
159160
}
160161

161162
impl<'a, A, D: Dimension> AxisWindows<'a, A, D> {
162163
pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
163-
{
164+
{
165+
let strides = a.strides.clone();
164166
let mut base = a;
165-
let len = base.raw_dim()[axis.index()];
166-
let indices = if len < window_size {
167-
Slice::new(0, Some(0), 1)
168-
} else {
169-
Slice::new(0, Some((len - window_size + 1) as isize), 1)
170-
};
171-
base.slice_axis_inplace(axis, indices);
167+
let axis_idx = axis.index();
168+
let mut window = base.raw_dim();
169+
window[axis_idx] = window_size;
170+
171+
base.slice_each_axis_inplace(|ax_desc| {
172+
let len = ax_desc.len;
173+
let wsz = window[ax_desc.axis.index()];
174+
175+
if len < wsz {
176+
Slice::new(0, Some(0), 1)
177+
} else {
178+
Slice::new(0, Some((len - wsz + 1) as isize), 1)
179+
}
180+
});
172181

173182
AxisWindows {
174183
base,
175-
window_size,
176-
axis_idx: axis.index(),
184+
axis_idx,
185+
window,
186+
strides,
177187
}
178188
}
179-
180-
fn window(&self) -> D{
181-
let mut window = self.base.raw_dim();
182-
window[self.axis_idx] = self.window_size;
183-
window
184-
}
185-
186-
fn strides_(&self) -> D{
187-
let mut strides = D::zeros(self.base.ndim());
188-
strides.slice_mut().fill(1);
189-
strides
190-
}
191189
}
192190

193191

@@ -214,8 +212,8 @@ impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> {
214212
}
215213

216214
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
217-
ArrayView::new_(ptr, self.window(),
218-
self.strides_())
215+
ArrayView::new_(ptr, self.window.clone(),
216+
self.strides.clone())
219217
}
220218

221219
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
@@ -234,14 +232,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisWindows<'a, A, D> {
234232
let (a, b) = self.base.split_at(Axis(self.axis_idx), index);
235233
(AxisWindows {
236234
base: a,
237-
window_size: self.window_size,
238235
axis_idx: self.axis_idx,
236+
window: self.window.clone(),
237+
strides: self.strides.clone()
239238

240239
},
241240
AxisWindows {
242241
base: b,
243-
window_size: self.window_size,
244242
axis_idx: self.axis_idx,
243+
window: self.window,
244+
strides: self.strides,
245245
})
246246
}
247247

@@ -256,12 +256,10 @@ where
256256
type Item = <Self::IntoIter as Iterator>::Item;
257257
type IntoIter = WindowsIter<'a, A, D>;
258258
fn into_iter(self) -> Self::IntoIter {
259-
let window = self.window();
260-
let strides = self.strides_();
261259
WindowsIter {
262260
iter: self.base.into_elements_base(),
263-
window,
264-
strides,
261+
window: self.window,
262+
strides: self.strides,
265263
}
266264
}
267265
}

0 commit comments

Comments
 (0)