Skip to content

Commit f220572

Browse files
authored
Merge pull request #885 from SparrowLii/negative_strides
Fix memory continuity judgment when stride is negative
2 parents 39bb1c5 + e6a4f10 commit f220572

File tree

11 files changed

+208
-51
lines changed

11 files changed

+208
-51
lines changed

blas-tests/tests/oper.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ where
173173
S2: Data<Elem = A>,
174174
{
175175
let ((m, _), k) = (lhs.dim(), rhs.dim());
176-
reference_mat_mul(lhs, &rhs.to_owned().into_shape((k, 1)).unwrap())
176+
reference_mat_mul(lhs, &rhs.as_standard_layout().into_shape((k, 1)).unwrap())
177177
.into_shape(m)
178178
.unwrap()
179179
}
@@ -186,7 +186,7 @@ where
186186
S2: Data<Elem = A>,
187187
{
188188
let (m, (_, n)) = (lhs.dim(), rhs.dim());
189-
reference_mat_mul(&lhs.to_owned().into_shape((1, m)).unwrap(), rhs)
189+
reference_mat_mul(&lhs.as_standard_layout().into_shape((1, m)).unwrap(), rhs)
190190
.into_shape(n)
191191
.unwrap()
192192
}

src/dimension/dimension_trait.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -287,17 +287,16 @@ pub trait Dimension:
287287
return true;
288288
}
289289
if dim.ndim() == 1 {
290-
return false;
290+
return strides[0] as isize == -1;
291291
}
292292
let order = strides._fastest_varying_stride_order();
293293
let strides = strides.slice();
294294

295-
// FIXME: Negative strides
296295
let dim_slice = dim.slice();
297296
let mut cstride = 1;
298297
for &i in order.slice() {
299298
// a dimension of length 1 can have unequal strides
300-
if dim_slice[i] != 1 && strides[i] != cstride {
299+
if dim_slice[i] != 1 && (strides[i] as isize).abs() as usize != cstride {
301300
return false;
302301
}
303302
cstride *= dim_slice[i];
@@ -308,16 +307,17 @@ pub trait Dimension:
308307
/// Return the axis ordering corresponding to the fastest variation
309308
/// (in ascending order).
310309
///
311-
/// Assumes that no stride value appears twice. This cannot yield the correct
312-
/// result the strides are not positive.
310+
/// Assumes that no stride value appears twice.
313311
#[doc(hidden)]
314312
fn _fastest_varying_stride_order(&self) -> Self {
315313
let mut indices = self.clone();
316314
for (i, elt) in enumerate(indices.slice_mut()) {
317315
*elt = i;
318316
}
319317
let strides = self.slice();
320-
indices.slice_mut().sort_by_key(|&i| strides[i]);
318+
indices
319+
.slice_mut()
320+
.sort_by_key(|&i| (strides[i] as isize).abs());
321321
indices
322322
}
323323

@@ -646,7 +646,7 @@ impl Dimension for Dim<[Ix; 2]> {
646646

647647
#[inline]
648648
fn _fastest_varying_stride_order(&self) -> Self {
649-
if get!(self, 0) as Ixs <= get!(self, 1) as Ixs {
649+
if (get!(self, 0) as Ixs).abs() <= (get!(self, 1) as Ixs).abs() {
650650
Ix2(0, 1)
651651
} else {
652652
Ix2(1, 0)
@@ -806,7 +806,7 @@ impl Dimension for Dim<[Ix; 3]> {
806806
let mut order = Ix3(0, 1, 2);
807807
macro_rules! swap {
808808
($stride:expr, $order:expr, $x:expr, $y:expr) => {
809-
if $stride[$x] > $stride[$y] {
809+
if ($stride[$x] as isize).abs() > ($stride[$y] as isize).abs() {
810810
$stride.swap($x, $y);
811811
$order.ixm().swap($x, $y);
812812
}

src/dimension/mod.rs

+46-21
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,12 @@ pub fn stride_offset(n: Ix, stride: Ix) -> isize {
4646
/// There is overlap if, when iterating through the dimensions in order of
4747
/// increasing stride, the current stride is less than or equal to the maximum
4848
/// possible offset along the preceding axes. (Axes of length ≤1 are ignored.)
49-
///
50-
/// The current implementation assumes that strides of axes with length > 1 are
51-
/// nonnegative. Additionally, it does not check for overflow.
5249
pub fn dim_stride_overlap<D: Dimension>(dim: &D, strides: &D) -> bool {
5350
let order = strides._fastest_varying_stride_order();
5451
let mut sum_prev_offsets = 0;
5552
for &index in order.slice() {
5653
let d = dim[index];
57-
let s = strides[index] as isize;
54+
let s = (strides[index] as isize).abs();
5855
match d {
5956
0 => return false,
6057
1 => {}
@@ -210,11 +207,7 @@ where
210207
///
211208
/// 2. The product of non-zero axis lengths must not exceed `isize::MAX`.
212209
///
213-
/// 3. For axes with length > 1, the stride must be nonnegative. This is
214-
/// necessary to make sure the pointer cannot move backwards outside the
215-
/// slice. For axes with length ≤ 1, the stride can be anything.
216-
///
217-
/// 4. If the array will be empty (any axes are zero-length), the difference
210+
/// 3. If the array will be empty (any axes are zero-length), the difference
218211
/// between the least address and greatest address accessible by moving
219212
/// along all axes must be ≤ `data.len()`. (It's fine in this case to move
220213
/// one byte past the end of the slice since the pointers will be offset but
@@ -225,13 +218,19 @@ where
225218
/// `data.len()`. This and #3 ensure that all dereferenceable pointers point
226219
/// to elements within the slice.
227220
///
228-
/// 5. The strides must not allow any element to be referenced by two different
221+
/// 4. The strides must not allow any element to be referenced by two different
229222
/// indices.
230223
///
231224
/// Note that since slices cannot contain more than `isize::MAX` bytes,
232225
/// condition 4 is sufficient to guarantee that the absolute difference in
233226
/// units of `A` and in units of bytes between the least address and greatest
234227
/// address accessible by moving along all axes does not exceed `isize::MAX`.
228+
///
229+
/// Warning: This function is sufficient to check the invariants of ArrayBase only
230+
/// if the pointer to the first element of the array is chosen such that the element
231+
/// with the smallest memory address is at the start of data. (In other words, the
232+
/// pointer to the first element of the array must be computed using offset_from_ptr_to_memory
233+
/// so that negative strides are correctly handled.)
235234
pub(crate) fn can_index_slice<A, D: Dimension>(
236235
data: &[A],
237236
dim: &D,
@@ -248,7 +247,7 @@ fn can_index_slice_impl<D: Dimension>(
248247
dim: &D,
249248
strides: &D,
250249
) -> Result<(), ShapeError> {
251-
// Check condition 4.
250+
// Check condition 3.
252251
let is_empty = dim.slice().iter().any(|&d| d == 0);
253252
if is_empty && max_offset > data_len {
254253
return Err(from_kind(ErrorKind::OutOfBounds));
@@ -257,15 +256,7 @@ fn can_index_slice_impl<D: Dimension>(
257256
return Err(from_kind(ErrorKind::OutOfBounds));
258257
}
259258

260-
// Check condition 3.
261-
for (&d, &s) in izip!(dim.slice(), strides.slice()) {
262-
let s = s as isize;
263-
if d > 1 && s < 0 {
264-
return Err(from_kind(ErrorKind::Unsupported));
265-
}
266-
}
267-
268-
// Check condition 5.
259+
// Check condition 4.
269260
if !is_empty && dim_stride_overlap(dim, strides) {
270261
return Err(from_kind(ErrorKind::Unsupported));
271262
}
@@ -289,6 +280,19 @@ pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option
289280
Some(offset)
290281
}
291282

283+
/// Checks if strides are non-negative.
284+
pub fn strides_non_negative<D>(strides: &D) -> Result<(), ShapeError>
285+
where
286+
D: Dimension,
287+
{
288+
for &stride in strides.slice() {
289+
if (stride as isize) < 0 {
290+
return Err(from_kind(ErrorKind::Unsupported));
291+
}
292+
}
293+
Ok(())
294+
}
295+
292296
/// Implementation-specific extensions to `Dimension`
293297
pub trait DimensionExt {
294298
// note: many extensions go in the main trait if they need to be special-
@@ -394,6 +398,19 @@ fn to_abs_slice(axis_len: usize, slice: Slice) -> (usize, usize, isize) {
394398
(start, end, step)
395399
}
396400

401+
/// This function computes the offset from the logically first element to the first element in
402+
/// memory of the array. The result is always <= 0.
403+
pub fn offset_from_ptr_to_memory<D: Dimension>(dim: &D, strides: &D) -> isize {
404+
let offset = izip!(dim.slice(), strides.slice()).fold(0, |_offset, (d, s)| {
405+
if (*s as isize) < 0 {
406+
_offset + *s as isize * (*d as isize - 1)
407+
} else {
408+
_offset
409+
}
410+
});
411+
offset
412+
}
413+
397414
/// Modify dimension, stride and return data pointer offset
398415
///
399416
/// **Panics** if stride is 0 or if any index is out of bounds.
@@ -693,13 +710,21 @@ mod test {
693710
let dim = (2, 3, 2).into_dimension();
694711
let strides = (5, 2, 1).into_dimension();
695712
assert!(super::dim_stride_overlap(&dim, &strides));
713+
let strides = (-5isize as usize, 2, -1isize as usize).into_dimension();
714+
assert!(super::dim_stride_overlap(&dim, &strides));
696715
let strides = (6, 2, 1).into_dimension();
697716
assert!(!super::dim_stride_overlap(&dim, &strides));
717+
let strides = (6, -2isize as usize, 1).into_dimension();
718+
assert!(!super::dim_stride_overlap(&dim, &strides));
698719
let strides = (6, 0, 1).into_dimension();
699720
assert!(super::dim_stride_overlap(&dim, &strides));
721+
let strides = (-6isize as usize, 0, 1).into_dimension();
722+
assert!(super::dim_stride_overlap(&dim, &strides));
700723
let dim = (2, 2).into_dimension();
701724
let strides = (3, 2).into_dimension();
702725
assert!(!super::dim_stride_overlap(&dim, &strides));
726+
let strides = (3, -2isize as usize).into_dimension();
727+
assert!(!super::dim_stride_overlap(&dim, &strides));
703728
}
704729

705730
#[test]
@@ -736,7 +761,7 @@ mod test {
736761
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1)).unwrap_err();
737762
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err();
738763
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1)).unwrap();
739-
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap_err();
764+
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap();
740765
}
741766

742767
#[test]

src/impl_constructors.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use alloc::vec;
1919
use alloc::vec::Vec;
2020

2121
use crate::dimension;
22+
use crate::dimension::offset_from_ptr_to_memory;
2223
use crate::error::{self, ShapeError};
2324
use crate::extension::nonnull::nonnull_from_vec_data;
2425
use crate::imp_prelude::*;
@@ -30,6 +31,7 @@ use crate::iterators::to_vec_mapped;
3031
use crate::StrideShape;
3132
#[cfg(feature = "std")]
3233
use crate::{geomspace, linspace, logspace};
34+
use rawpointer::PointerExt;
3335

3436

3537
/// # Constructor Methods for Owned Arrays
@@ -442,7 +444,8 @@ where
442444
///
443445
/// 2. The product of non-zero axis lengths must not exceed `isize::MAX`.
444446
///
445-
/// 3. For axes with length > 1, the stride must be nonnegative.
447+
/// 3. For axes with length > 1, the pointer cannot move outside the
448+
/// slice.
446449
///
447450
/// 4. If the array will be empty (any axes are zero-length), the
448451
/// difference between the least address and greatest address accessible
@@ -468,7 +471,7 @@ where
468471
// debug check for issues that indicates wrong use of this constructor
469472
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
470473
ArrayBase {
471-
ptr: nonnull_from_vec_data(&mut v),
474+
ptr: nonnull_from_vec_data(&mut v).offset(-offset_from_ptr_to_memory(&dim, &strides)),
472475
data: DataOwned::new(v),
473476
strides,
474477
dim,
@@ -494,7 +497,7 @@ where
494497
///
495498
/// This constructor is limited to elements where `A: Copy` (no destructors)
496499
/// to avoid users shooting themselves too hard in the foot.
497-
///
500+
///
498501
/// (Also note that the constructors `from_shape_vec` and
499502
/// `from_shape_vec_unchecked` allow the user yet more control, in the sense
500503
/// that Arrays can be created from arbitrary vectors.)

src/impl_methods.rs

+22-14
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ use crate::arraytraits;
1818
use crate::dimension;
1919
use crate::dimension::IntoDimension;
2020
use crate::dimension::{
21-
abs_index, axes_of, do_slice, merge_axes, size_of_shape_checked, stride_offset, Axes,
21+
abs_index, axes_of, do_slice, merge_axes, offset_from_ptr_to_memory, size_of_shape_checked,
22+
stride_offset, Axes,
2223
};
2324
use crate::error::{self, ErrorKind, ShapeError};
2425
use crate::math_cell::MathCell;
@@ -169,12 +170,12 @@ where
169170

170171
/// Return an uniquely owned copy of the array.
171172
///
172-
/// If the input array is contiguous and its strides are positive, then the
173-
/// output array will have the same memory layout. Otherwise, the layout of
174-
/// the output array is unspecified. If you need a particular layout, you
175-
/// can allocate a new array with the desired memory layout and
176-
/// [`.assign()`](#method.assign) the data. Alternatively, you can collect
177-
/// an iterator, like this for a result in standard layout:
173+
/// If the input array is contiguous, then the output array will have the same
174+
/// memory layout. Otherwise, the layout of the output array is unspecified.
175+
/// If you need a particular layout, you can allocate a new array with the
176+
/// desired memory layout and [`.assign()`](#method.assign) the data.
177+
/// Alternatively, you can collectan iterator, like this for a result in
178+
/// standard layout:
178179
///
179180
/// ```
180181
/// # use ndarray::prelude::*;
@@ -1296,9 +1297,6 @@ where
12961297
}
12971298

12981299
/// Return true if the array is known to be contiguous.
1299-
///
1300-
/// Will detect c- and f-contig arrays correctly, but otherwise
1301-
/// There are some false negatives.
13021300
pub(crate) fn is_contiguous(&self) -> bool {
13031301
D::is_contiguous(&self.dim, &self.strides)
13041302
}
@@ -1420,14 +1418,18 @@ where
14201418
///
14211419
/// If this function returns `Some(_)`, then the elements in the slice
14221420
/// have whatever order the elements have in memory.
1423-
///
1424-
/// Implementation notes: Does not yet support negatively strided arrays.
14251421
pub fn as_slice_memory_order(&self) -> Option<&[A]>
14261422
where
14271423
S: Data,
14281424
{
14291425
if self.is_contiguous() {
1430-
unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) }
1426+
let offset = offset_from_ptr_to_memory(&self.dim, &self.strides);
1427+
unsafe {
1428+
Some(slice::from_raw_parts(
1429+
self.ptr.offset(offset).as_ptr(),
1430+
self.len(),
1431+
))
1432+
}
14311433
} else {
14321434
None
14331435
}
@@ -1441,7 +1443,13 @@ where
14411443
{
14421444
if self.is_contiguous() {
14431445
self.ensure_unique();
1444-
unsafe { Some(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) }
1446+
let offset = offset_from_ptr_to_memory(&self.dim, &self.strides);
1447+
unsafe {
1448+
Some(slice::from_raw_parts_mut(
1449+
self.ptr.offset(offset).as_ptr(),
1450+
self.len(),
1451+
))
1452+
}
14451453
} else {
14461454
None
14471455
}

src/impl_raw_views.rs

+12
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ where
6262
/// [`.offset()`] regardless of the starting point due to past offsets.
6363
///
6464
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
65+
///
66+
/// * Strides must be non-negative.
67+
///
68+
/// This function can use debug assertions to check some of these requirements,
69+
/// but it's not a complete check.
6570
///
6671
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
6772
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *const A) -> Self
@@ -73,6 +78,7 @@ where
7378
if cfg!(debug_assertions) {
7479
assert!(!ptr.is_null(), "The pointer must be non-null.");
7580
if let Strides::Custom(strides) = &shape.strides {
81+
dimension::strides_non_negative(strides).unwrap();
7682
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
7783
} else {
7884
dimension::size_of_shape_checked(&dim).unwrap();
@@ -202,6 +208,11 @@ where
202208
/// [`.offset()`] regardless of the starting point due to past offsets.
203209
///
204210
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
211+
///
212+
/// * Strides must be non-negative.
213+
///
214+
/// This function can use debug assertions to check some of these requirements,
215+
/// but it's not a complete check.
205216
///
206217
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
207218
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *mut A) -> Self
@@ -213,6 +224,7 @@ where
213224
if cfg!(debug_assertions) {
214225
assert!(!ptr.is_null(), "The pointer must be non-null.");
215226
if let Strides::Custom(strides) = &shape.strides {
227+
dimension::strides_non_negative(strides).unwrap();
216228
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
217229
} else {
218230
dimension::size_of_shape_checked(&dim).unwrap();

0 commit comments

Comments
 (0)