Skip to content

Commit ac55d74

Browse files
committed
FEAT: Add methods RawArrayView/Mut::cast
Add methods for raw view casts, this makes it easier to build upon the raw array view functionality. Of course, adding more capabilities for the unsafe APIs opens for users to make mistakes, but we need these capabilities ourselves for developing features for ndarray. The reason for this to exist is for the `ArrayViewMut<T>` → `ArrayView<Cell<T>>` transformation demonstrated in the tests. However, we don't add a method for this directly, at least not yet. The reason is that we'd like to have a type that is exactly like `Cell<T>` but also supports the arithmetic ops.
1 parent 5447af8 commit ac55d74

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed

src/impl_raw_views.rs

+53
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::mem;
12
use std::ptr::NonNull;
23

34
use crate::dimension::{self, stride_offset};
@@ -111,6 +112,32 @@ where
111112

112113
(left, right)
113114
}
115+
116+
/// Cast the raw pointer of the raw array view to a different type
117+
///
118+
/// **Panics** if element size is not compatible.
119+
///
120+
/// Lack of panic does not imply it is a valid cast. The cast works the same
121+
/// way as regular raw pointer casts.
122+
///
123+
/// While this method is safe, for the same reason as regular raw pointer
124+
/// casts are safe, access through the produced raw view is only possible
125+
/// in an unsafe block or function.
126+
pub fn cast<B>(self) -> RawArrayView<B, D> {
127+
assert_eq!(
128+
mem::size_of::<B>(),
129+
mem::size_of::<A>(),
130+
"size mismatch in raw view cast"
131+
);
132+
let ptr = self.ptr.cast::<B>();
133+
debug_assert!(
134+
is_aligned(ptr.as_ptr()),
135+
"alignment mismatch in raw view cast"
136+
);
137+
/* Alignment checked with debug assertion: alignment could be dynamically correct,
138+
* and we don't have a check that compiles out for that. */
139+
unsafe { RawArrayView::new(ptr, self.dim, self.strides) }
140+
}
114141
}
115142

116143
impl<A, D> RawArrayViewMut<A, D>
@@ -222,4 +249,30 @@ where
222249
)
223250
}
224251
}
252+
253+
/// Cast the raw pointer of the raw array view to a different type
254+
///
255+
/// **Panics** if element size is not compatible.
256+
///
257+
/// Lack of panic does not imply it is a valid cast. The cast works the same
258+
/// way as regular raw pointer casts.
259+
///
260+
/// While this method is safe, for the same reason as regular raw pointer
261+
/// casts are safe, access through the produced raw view is only possible
262+
/// in an unsafe block or function.
263+
pub fn cast<B>(self) -> RawArrayViewMut<B, D> {
264+
assert_eq!(
265+
mem::size_of::<B>(),
266+
mem::size_of::<A>(),
267+
"size mismatch in raw view cast"
268+
);
269+
let ptr = self.ptr.cast::<B>();
270+
debug_assert!(
271+
is_aligned(ptr.as_ptr()),
272+
"alignment mismatch in raw view cast"
273+
);
274+
/* Alignment checked with debug assertion: alignment could be dynamically correct,
275+
* and we don't have a check that compiles out for that. */
276+
unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) }
277+
}
225278
}

tests/raw_views.rs

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
use ndarray::prelude::*;
2+
use ndarray::Zip;
3+
4+
use std::cell::Cell;
5+
#[cfg(debug_assertions)]
6+
use std::mem;
7+
8+
#[test]
9+
fn raw_view_cast_cell() {
10+
// Test .cast() by creating an ArrayView<Cell<f32>>
11+
12+
let mut a = Array::from_shape_fn((10, 5), |(i, j)| (i * j) as f32);
13+
let answer = &a + 1.;
14+
15+
{
16+
let raw_cell_view = a.raw_view_mut().cast::<Cell<f32>>();
17+
let cell_view = unsafe { raw_cell_view.deref_into_view() };
18+
19+
Zip::from(cell_view).apply(|elt| elt.set(elt.get() + 1.));
20+
}
21+
assert_eq!(a, answer);
22+
}
23+
24+
#[test]
25+
fn raw_view_cast_reinterpret() {
26+
// Test .cast() by reinterpreting u16 as [u8; 2]
27+
let a = Array::from_shape_fn((5, 5).f(), |(i, j)| (i as u16) << 8 | j as u16);
28+
let answer = a.mapv(u16::to_ne_bytes);
29+
30+
let raw_view = a.raw_view().cast::<[u8; 2]>();
31+
let view = unsafe { raw_view.deref_into_view() };
32+
assert_eq!(view, answer);
33+
}
34+
35+
#[test]
36+
fn raw_view_cast_zst() {
37+
struct Zst;
38+
39+
let a = Array::<(), _>::default((250, 250));
40+
let b: RawArrayView<Zst, _> = a.raw_view().cast::<Zst>();
41+
assert_eq!(a.shape(), b.shape());
42+
assert_eq!(a.as_ptr() as *const u8, b.as_ptr() as *const u8);
43+
}
44+
45+
#[test]
46+
#[should_panic]
47+
fn raw_view_invalid_size_cast() {
48+
let data = [0i32; 16];
49+
ArrayView::from(&data[..]).raw_view().cast::<i64>();
50+
}
51+
52+
#[test]
53+
#[should_panic]
54+
fn raw_view_mut_invalid_size_cast() {
55+
let mut data = [0i32; 16];
56+
ArrayViewMut::from(&mut data[..])
57+
.raw_view_mut()
58+
.cast::<i64>();
59+
}
60+
61+
#[test]
62+
#[cfg(debug_assertions)]
63+
#[should_panic = "alignment mismatch"]
64+
fn raw_view_invalid_align_cast() {
65+
#[derive(Copy, Clone, Debug)]
66+
#[repr(transparent)]
67+
struct A([u8; 16]);
68+
#[derive(Copy, Clone, Debug)]
69+
#[repr(transparent)]
70+
struct B([f64; 2]);
71+
72+
unsafe {
73+
const LEN: usize = 16;
74+
let mut buffer = [0u8; mem::size_of::<A>() * (LEN + 1)];
75+
// Take out a slice of buffer as &[A] which is misaligned for B
76+
let mut ptr = buffer.as_mut_ptr();
77+
if ptr as usize % mem::align_of::<B>() == 0 {
78+
ptr = ptr.add(1);
79+
}
80+
81+
let view = RawArrayViewMut::from_shape_ptr(LEN, ptr as *mut A);
82+
83+
// misaligned cast - test debug assertion
84+
view.cast::<B>();
85+
}
86+
}

0 commit comments

Comments
 (0)