Skip to content

Commit ecb7643

Browse files
authored
Merge pull request #734 from rust-ndarray/raw-view-cast
Simplify ArrayView construction from NonNull<T> and add RawView .cast() method
2 parents ad3340a + ac55d74 commit ecb7643

File tree

7 files changed

+199
-31
lines changed

7 files changed

+199
-31
lines changed

src/impl_methods.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ where
139139
S: Data,
140140
{
141141
debug_assert!(self.pointer_is_inbounds());
142-
unsafe { ArrayView::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
142+
unsafe { ArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) }
143143
}
144144

145145
/// Return a read-write view of the array
@@ -148,7 +148,7 @@ where
148148
S: DataMut,
149149
{
150150
self.ensure_unique();
151-
unsafe { ArrayViewMut::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
151+
unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) }
152152
}
153153

154154
/// Return an uniquely owned copy of the array.
@@ -1313,7 +1313,7 @@ where
13131313
/// Return a raw view of the array.
13141314
#[inline]
13151315
pub fn raw_view(&self) -> RawArrayView<A, D> {
1316-
unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
1316+
unsafe { RawArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) }
13171317
}
13181318

13191319
/// Return a raw mutable view of the array.
@@ -1323,7 +1323,7 @@ where
13231323
S: RawDataMut,
13241324
{
13251325
self.try_ensure_unique(); // for RcArray
1326-
unsafe { RawArrayViewMut::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
1326+
unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) }
13271327
}
13281328

13291329
/// Return the array’s data as a slice, if it is contiguous and in standard order.
@@ -1620,7 +1620,7 @@ where
16201620
Some(st) => st,
16211621
None => return None,
16221622
};
1623-
unsafe { Some(ArrayView::new_(self.ptr.as_ptr(), dim, broadcast_strides)) }
1623+
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
16241624
}
16251625

16261626
/// Swap axes `ax` and `bx`.

src/impl_raw_views.rs

+75-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use std::mem;
2+
use std::ptr::NonNull;
3+
14
use crate::dimension::{self, stride_offset};
25
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
36
use crate::imp_prelude::*;
@@ -11,16 +14,20 @@ where
1114
///
1215
/// Unsafe because caller is responsible for ensuring that the array will
1316
/// meet all of the invariants of the `ArrayBase` type.
14-
#[inline(always)]
15-
pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
17+
#[inline]
18+
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
1619
RawArrayView {
1720
data: RawViewRepr::new(),
18-
ptr: nonnull_debug_checked_from_ptr(ptr as *mut _),
21+
ptr,
1922
dim,
2023
strides,
2124
}
2225
}
2326

27+
unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
28+
Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides)
29+
}
30+
2431
/// Create an `RawArrayView<A, D>` from shape information and a raw pointer
2532
/// to the elements.
2633
///
@@ -76,7 +83,7 @@ where
7683
/// ensure that all of the data is valid and choose the correct lifetime.
7784
#[inline]
7885
pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> {
79-
ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides)
86+
ArrayView::new(self.ptr, self.dim, self.strides)
8087
}
8188

8289
/// Split the array view along `axis` and return one array pointer strictly
@@ -105,6 +112,32 @@ where
105112

106113
(left, right)
107114
}
115+
116+
/// Cast the raw pointer of the raw array view to a different type
117+
///
118+
/// **Panics** if element size is not compatible.
119+
///
120+
/// Lack of panic does not imply it is a valid cast. The cast works the same
121+
/// way as regular raw pointer casts.
122+
///
123+
/// While this method is safe, for the same reason as regular raw pointer
124+
/// casts are safe, access through the produced raw view is only possible
125+
/// in an unsafe block or function.
126+
pub fn cast<B>(self) -> RawArrayView<B, D> {
127+
assert_eq!(
128+
mem::size_of::<B>(),
129+
mem::size_of::<A>(),
130+
"size mismatch in raw view cast"
131+
);
132+
let ptr = self.ptr.cast::<B>();
133+
debug_assert!(
134+
is_aligned(ptr.as_ptr()),
135+
"alignment mismatch in raw view cast"
136+
);
137+
/* Alignment checked with debug assertion: alignment could be dynamically correct,
138+
* and we don't have a check that compiles out for that. */
139+
unsafe { RawArrayView::new(ptr, self.dim, self.strides) }
140+
}
108141
}
109142

110143
impl<A, D> RawArrayViewMut<A, D>
@@ -115,16 +148,20 @@ where
115148
///
116149
/// Unsafe because caller is responsible for ensuring that the array will
117150
/// meet all of the invariants of the `ArrayBase` type.
118-
#[inline(always)]
119-
pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
151+
#[inline]
152+
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
120153
RawArrayViewMut {
121154
data: RawViewRepr::new(),
122-
ptr: nonnull_debug_checked_from_ptr(ptr),
155+
ptr,
123156
dim,
124157
strides,
125158
}
126159
}
127160

161+
unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
162+
Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides)
163+
}
164+
128165
/// Create an `RawArrayViewMut<A, D>` from shape information and a raw
129166
/// pointer to the elements.
130167
///
@@ -176,7 +213,7 @@ where
176213
/// Converts to a non-mutable `RawArrayView`.
177214
#[inline]
178215
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
179-
unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) }
216+
unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) }
180217
}
181218

182219
/// Converts to a read-only view of the array.
@@ -186,7 +223,7 @@ where
186223
/// ensure that all of the data is valid and choose the correct lifetime.
187224
#[inline]
188225
pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> {
189-
ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides)
226+
ArrayView::new(self.ptr, self.dim, self.strides)
190227
}
191228

192229
/// Converts to a mutable view of the array.
@@ -196,7 +233,7 @@ where
196233
/// ensure that all of the data is valid and choose the correct lifetime.
197234
#[inline]
198235
pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> {
199-
ArrayViewMut::new_(self.ptr.as_ptr(), self.dim, self.strides)
236+
ArrayViewMut::new(self.ptr, self.dim, self.strides)
200237
}
201238

202239
/// Split the array view along `axis` and return one array pointer strictly
@@ -207,9 +244,35 @@ where
207244
let (left, right) = self.into_raw_view().split_at(axis, index);
208245
unsafe {
209246
(
210-
Self::new_(left.ptr.as_ptr(), left.dim, left.strides),
211-
Self::new_(right.ptr.as_ptr(), right.dim, right.strides),
247+
Self::new(left.ptr, left.dim, left.strides),
248+
Self::new(right.ptr, right.dim, right.strides),
212249
)
213250
}
214251
}
252+
253+
/// Cast the raw pointer of the raw array view to a different type
254+
///
255+
/// **Panics** if element size is not compatible.
256+
///
257+
/// Lack of panic does not imply it is a valid cast. The cast works the same
258+
/// way as regular raw pointer casts.
259+
///
260+
/// While this method is safe, for the same reason as regular raw pointer
261+
/// casts are safe, access through the produced raw view is only possible
262+
/// in an unsafe block or function.
263+
pub fn cast<B>(self) -> RawArrayViewMut<B, D> {
264+
assert_eq!(
265+
mem::size_of::<B>(),
266+
mem::size_of::<A>(),
267+
"size mismatch in raw view cast"
268+
);
269+
let ptr = self.ptr.cast::<B>();
270+
debug_assert!(
271+
is_aligned(ptr.as_ptr()),
272+
"alignment mismatch in raw view cast"
273+
);
274+
/* Alignment checked with debug assertion: alignment could be dynamically correct,
275+
* and we don't have a check that compiles out for that. */
276+
unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) }
277+
}
215278
}

src/impl_views/constructors.rs

+27-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88

9+
use std::ptr::NonNull;
10+
911
use crate::dimension;
1012
use crate::error::ShapeError;
1113
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
@@ -200,11 +202,11 @@ where
200202

201203
/// Convert the view into an `ArrayViewMut<'b, A, D>` where `'b` is a lifetime
202204
/// outlived by `'a'`.
203-
pub fn reborrow<'b>(mut self) -> ArrayViewMut<'b, A, D>
205+
pub fn reborrow<'b>(self) -> ArrayViewMut<'b, A, D>
204206
where
205207
'a: 'b,
206208
{
207-
unsafe { ArrayViewMut::new_(self.as_mut_ptr(), self.dim, self.strides) }
209+
unsafe { ArrayViewMut::new(self.ptr, self.dim, self.strides) }
208210
}
209211
}
210212

@@ -217,14 +219,24 @@ where
217219
///
218220
/// Unsafe because: `ptr` must be valid for the given dimension and strides.
219221
#[inline(always)]
220-
pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
222+
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
223+
if cfg!(debug_assertions) {
224+
assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned.");
225+
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
226+
}
221227
ArrayView {
222228
data: ViewRepr::new(),
223-
ptr: nonnull_debug_checked_from_ptr(ptr as *mut A),
229+
ptr,
224230
dim,
225231
strides,
226232
}
227233
}
234+
235+
/// Unsafe because: `ptr` must be valid for the given dimension and strides.
236+
#[inline]
237+
pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
238+
Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides)
239+
}
228240
}
229241

230242
impl<'a, A, D> ArrayViewMut<'a, A, D>
@@ -235,17 +247,24 @@ where
235247
///
236248
/// Unsafe because: `ptr` must be valid for the given dimension and strides.
237249
#[inline(always)]
238-
pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
250+
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
239251
if cfg!(debug_assertions) {
240-
assert!(!ptr.is_null(), "The pointer must be non-null.");
241-
assert!(is_aligned(ptr), "The pointer must be aligned.");
252+
assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned.");
242253
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
243254
}
244255
ArrayViewMut {
245256
data: ViewRepr::new(),
246-
ptr: nonnull_debug_checked_from_ptr(ptr),
257+
ptr,
247258
dim,
248259
strides,
249260
}
250261
}
262+
263+
/// Create a new `ArrayView`
264+
///
265+
/// Unsafe because: `ptr` must be valid for the given dimension and strides.
266+
#[inline(always)]
267+
pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
268+
Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides)
269+
}
251270
}

src/impl_views/conversions.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ where
2626
where
2727
'a: 'b,
2828
{
29-
unsafe { ArrayView::new_(self.as_ptr(), self.dim, self.strides) }
29+
unsafe { ArrayView::new(self.ptr, self.dim, self.strides) }
3030
}
3131

3232
/// Return the array’s data as a slice, if it is contiguous and in standard order.
@@ -53,7 +53,7 @@ where
5353

5454
/// Converts to a raw array view.
5555
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
56-
unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) }
56+
unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) }
5757
}
5858
}
5959

@@ -161,12 +161,12 @@ where
161161
{
162162
// Convert into a read-only view
163163
pub(crate) fn into_view(self) -> ArrayView<'a, A, D> {
164-
unsafe { ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) }
164+
unsafe { ArrayView::new(self.ptr, self.dim, self.strides) }
165165
}
166166

167167
/// Converts to a mutable raw array view.
168168
pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut<A, D> {
169-
unsafe { RawArrayViewMut::new_(self.ptr.as_ptr(), self.dim, self.strides) }
169+
unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) }
170170
}
171171

172172
#[inline]

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1517,7 +1517,7 @@ where
15171517
let ptr = self.ptr;
15181518
let mut strides = dim.clone();
15191519
strides.slice_mut().copy_from_slice(self.strides.slice());
1520-
unsafe { ArrayView::new_(ptr.as_ptr(), dim, strides) }
1520+
unsafe { ArrayView::new(ptr, dim, strides) }
15211521
}
15221522

15231523
fn raw_strides(&self) -> D {

src/zip/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ where
7373
type Output = ArrayView<'a, A, E::Dim>;
7474
fn broadcast_unwrap(self, shape: E) -> Self::Output {
7575
let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
76-
unsafe { ArrayView::new_(res.ptr.as_ptr(), res.dim, res.strides) }
76+
unsafe { ArrayView::new(res.ptr, res.dim, res.strides) }
7777
}
7878
private_impl! {}
7979
}

0 commit comments

Comments
 (0)