Skip to content

Commit e797a77

Browse files
committed
Move upcast broadcast helper to dimension mod
This will allow upcast to be reused in other functions, such as the upcoming ArrayView::broadcast_ref.
1 parent 0740695 commit e797a77

File tree

2 files changed

+52
-51
lines changed

2 files changed

+52
-51
lines changed

src/dimension/broadcast.rs

+51
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,57 @@ where
3434
Ok(out)
3535
}
3636

37+
/// Return new stride when trying to grow `from` into shape `to`
38+
///
39+
/// Broadcasting works by returning a "fake stride" where elements
40+
/// to repeat are in axes with 0 stride, so that several indexes point
41+
/// to the same element.
42+
///
43+
/// **Note:** Cannot be used for mutable iterators, since repeating
44+
/// elements would create aliasing pointers.
45+
pub(crate) fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
46+
// Make sure the product of non-zero axis lengths does not exceed
47+
// `isize::MAX`. This is the only safety check we need to perform
48+
// because all the other constraints of `ArrayBase` are guaranteed
49+
// to be met since we're starting from a valid `ArrayBase`.
50+
let _ = size_of_shape_checked(to).ok()?;
51+
52+
let mut new_stride = to.clone();
53+
// begin at the back (the least significant dimension)
54+
// size of the axis has to either agree or `from` has to be 1
55+
if to.ndim() < from.ndim() {
56+
return None;
57+
}
58+
59+
{
60+
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
61+
for ((er, es), dr) in from
62+
.slice()
63+
.iter()
64+
.rev()
65+
.zip(stride.slice().iter().rev())
66+
.zip(new_stride_iter.by_ref())
67+
{
68+
/* update strides */
69+
if *dr == *er {
70+
/* keep stride */
71+
*dr = *es;
72+
} else if *er == 1 {
73+
/* dead dimension, zero stride */
74+
*dr = 0
75+
} else {
76+
return None;
77+
}
78+
}
79+
80+
/* set remaining strides to zero */
81+
for dr in new_stride_iter {
82+
*dr = 0;
83+
}
84+
}
85+
Some(new_stride)
86+
}
87+
3788
pub trait DimMax<Other: Dimension> {
3889
/// The resulting dimension type after broadcasting.
3990
type Output: Dimension;

src/impl_methods.rs

+1-51
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::dimension::{
2222
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
2323
offset_from_low_addr_ptr_to_logical_ptr, size_of_shape_checked, stride_offset, Axes,
2424
};
25-
use crate::dimension::broadcast::co_broadcast;
25+
use crate::dimension::broadcast::{co_broadcast, upcast};
2626
use crate::dimension::reshape_dim;
2727
use crate::error::{self, ErrorKind, ShapeError, from_kind};
2828
use crate::math_cell::MathCell;
@@ -2036,56 +2036,6 @@ where
20362036
E: IntoDimension,
20372037
S: Data,
20382038
{
2039-
/// Return new stride when trying to grow `from` into shape `to`
2040-
///
2041-
/// Broadcasting works by returning a "fake stride" where elements
2042-
/// to repeat are in axes with 0 stride, so that several indexes point
2043-
/// to the same element.
2044-
///
2045-
/// **Note:** Cannot be used for mutable iterators, since repeating
2046-
/// elements would create aliasing pointers.
2047-
fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
2048-
// Make sure the product of non-zero axis lengths does not exceed
2049-
// `isize::MAX`. This is the only safety check we need to perform
2050-
// because all the other constraints of `ArrayBase` are guaranteed
2051-
// to be met since we're starting from a valid `ArrayBase`.
2052-
let _ = size_of_shape_checked(to).ok()?;
2053-
2054-
let mut new_stride = to.clone();
2055-
// begin at the back (the least significant dimension)
2056-
// size of the axis has to either agree or `from` has to be 1
2057-
if to.ndim() < from.ndim() {
2058-
return None;
2059-
}
2060-
2061-
{
2062-
let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
2063-
for ((er, es), dr) in from
2064-
.slice()
2065-
.iter()
2066-
.rev()
2067-
.zip(stride.slice().iter().rev())
2068-
.zip(new_stride_iter.by_ref())
2069-
{
2070-
/* update strides */
2071-
if *dr == *er {
2072-
/* keep stride */
2073-
*dr = *es;
2074-
} else if *er == 1 {
2075-
/* dead dimension, zero stride */
2076-
*dr = 0
2077-
} else {
2078-
return None;
2079-
}
2080-
}
2081-
2082-
/* set remaining strides to zero */
2083-
for dr in new_stride_iter {
2084-
*dr = 0;
2085-
}
2086-
}
2087-
Some(new_stride)
2088-
}
20892039
let dim = dim.into_dimension();
20902040

20912041
// Note: zero strides are safe precisely because we return an read-only view

0 commit comments

Comments
 (0)