Skip to content

Commit e6a4f10

Browse files
committed
API: Add requirement of non-negative strides to raw ptr constructors
For raw views and array views, require non-negative strides and check this with a debug assertion.
1 parent 0d8b965 commit e6a4f10

File tree

4 files changed

+50
-0
lines changed

4 files changed

+50
-0
lines changed

src/dimension/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,19 @@ pub fn stride_offset_checked(dim: &[Ix], strides: &[Ix], index: &[Ix]) -> Option
280280
Some(offset)
281281
}
282282

283+
/// Checks if strides are non-negative.
284+
pub fn strides_non_negative<D>(strides: &D) -> Result<(), ShapeError>
285+
where
286+
D: Dimension,
287+
{
288+
for &stride in strides.slice() {
289+
if (stride as isize) < 0 {
290+
return Err(from_kind(ErrorKind::Unsupported));
291+
}
292+
}
293+
Ok(())
294+
}
295+
283296
/// Implementation-specific extensions to `Dimension`
284297
pub trait DimensionExt {
285298
// note: many extensions go in the main trait if they need to be special-

src/impl_raw_views.rs

+12
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ where
6262
/// [`.offset()`] regardless of the starting point due to past offsets.
6363
///
6464
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
65+
///
66+
/// * Strides must be non-negative.
67+
///
68+
/// This function can use debug assertions to check some of these requirements,
69+
/// but it's not a complete check.
6570
///
6671
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
6772
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *const A) -> Self
@@ -73,6 +78,7 @@ where
7378
if cfg!(debug_assertions) {
7479
assert!(!ptr.is_null(), "The pointer must be non-null.");
7580
if let Strides::Custom(strides) = &shape.strides {
81+
dimension::strides_non_negative(strides).unwrap();
7682
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
7783
} else {
7884
dimension::size_of_shape_checked(&dim).unwrap();
@@ -202,6 +208,11 @@ where
202208
/// [`.offset()`] regardless of the starting point due to past offsets.
203209
///
204210
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
211+
///
212+
/// * Strides must be non-negative.
213+
///
214+
/// This function can use debug assertions to check some of these requirements,
215+
/// but it's not a complete check.
205216
///
206217
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
207218
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *mut A) -> Self
@@ -213,6 +224,7 @@ where
213224
if cfg!(debug_assertions) {
214225
assert!(!ptr.is_null(), "The pointer must be non-null.");
215226
if let Strides::Custom(strides) = &shape.strides {
227+
dimension::strides_non_negative(strides).unwrap();
216228
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
217229
} else {
218230
dimension::size_of_shape_checked(&dim).unwrap();

src/impl_views/constructors.rs

+10
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ where
9696
///
9797
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
9898
///
99+
/// * Strides must be non-negative.
100+
///
101+
/// This function can use debug assertions to check some of these requirements,
102+
/// but it's not a complete check.
103+
///
99104
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
100105
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *const A) -> Self
101106
where
@@ -188,6 +193,11 @@ where
188193
///
189194
/// * The product of non-zero axis lengths must not exceed `isize::MAX`.
190195
///
196+
/// * Strides must be non-negative.
197+
///
198+
/// This function can use debug assertions to check some of these requirements,
199+
/// but it's not a complete check.
200+
///
191201
/// [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset
192202
pub unsafe fn from_shape_ptr<Sh>(shape: Sh, ptr: *mut A) -> Self
193203
where

tests/raw_views.rs

+15
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,18 @@ fn raw_view_deref_into_view_misaligned() {
8181
let data: [u16; 2] = [0x0011, 0x2233];
8282
misaligned_deref(&data);
8383
}
84+
85+
#[test]
86+
#[cfg(debug_assertions)]
87+
#[should_panic = "Unsupported"]
88+
fn raw_view_negative_strides() {
89+
fn misaligned_deref(data: &[u16; 2]) -> ArrayView1<'_, u16> {
90+
let ptr: *const u16 = data.as_ptr();
91+
unsafe {
92+
let raw_view = RawArrayView::from_shape_ptr(1.strides((-1isize) as usize), ptr);
93+
raw_view.deref_into_view()
94+
}
95+
}
96+
let data: [u16; 2] = [0x0011, 0x2233];
97+
misaligned_deref(&data);
98+
}

0 commit comments

Comments
 (0)