Skip to content

Commit 4e31d2f

Browse files
authored
Merge pull request #978 from stokhos/getter_dim_in_shape
Add getter for dim in Shape and StrideShape
2 parents 47626e4 + bed9c63 commit 4e31d2f

File tree

1 file changed

+29
-13
lines changed

1 file changed

+29
-13
lines changed

src/shape_builder.rs

+29-13
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,35 @@ pub struct Shape<D> {
1313
}
1414

1515
#[derive(Copy, Clone, Debug)]
16-
pub(crate) enum Contiguous { }
16+
pub(crate) enum Contiguous {}
1717

1818
impl<D> Shape<D> {
1919
pub(crate) fn is_c(&self) -> bool {
2020
matches!(self.strides, Strides::C)
2121
}
2222
}
2323

24-
2524
/// An array shape of n dimensions in c-order, f-order or custom strides.
2625
#[derive(Copy, Clone, Debug)]
2726
pub struct StrideShape<D> {
2827
pub(crate) dim: D,
2928
pub(crate) strides: Strides<D>,
3029
}
3130

31+
impl<D> StrideShape<D>
32+
where
33+
D: Dimension,
34+
{
35+
/// Return a reference to the dimension
36+
pub fn raw_dim(&self) -> &D {
37+
&self.dim
38+
}
39+
/// Return the size of the shape in number of elements
40+
pub fn size(&self) -> usize {
41+
self.dim.size()
42+
}
43+
}
44+
3245
/// Stride description
3346
#[derive(Copy, Clone, Debug)]
3447
pub(crate) enum Strides<D> {
@@ -37,21 +50,26 @@ pub(crate) enum Strides<D> {
3750
/// Column-major ("F"-order)
3851
F,
3952
/// Custom strides
40-
Custom(D)
53+
Custom(D),
4154
}
4255

4356
impl<D> Strides<D> {
4457
/// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
4558
pub(crate) fn strides_for_dim(self, dim: &D) -> D
46-
where D: Dimension
59+
where
60+
D: Dimension,
4761
{
4862
match self {
4963
Strides::C => dim.default_strides(),
5064
Strides::F => dim.fortran_strides(),
5165
Strides::Custom(c) => {
52-
debug_assert_eq!(c.ndim(), dim.ndim(),
66+
debug_assert_eq!(
67+
c.ndim(),
68+
dim.ndim(),
5369
"Custom strides given with {} dimensions, expected {}",
54-
c.ndim(), dim.ndim());
70+
c.ndim(),
71+
dim.ndim()
72+
);
5573
c
5674
}
5775
}
@@ -94,11 +112,7 @@ where
94112
{
95113
fn from(value: T) -> Self {
96114
let shape = value.into_shape();
97-
let st = if shape.is_c() {
98-
Strides::C
99-
} else {
100-
Strides::F
101-
};
115+
let st = if shape.is_c() { Strides::C } else { Strides::F };
102116
StrideShape {
103117
strides: st,
104118
dim: shape.dim,
@@ -161,8 +175,10 @@ impl<D> Shape<D>
161175
where
162176
D: Dimension,
163177
{
164-
// Return a reference to the dimension
165-
//pub fn dimension(&self) -> &D { &self.dim }
178+
/// Return a reference to the dimension
179+
pub fn raw_dim(&self) -> &D {
180+
&self.dim
181+
}
166182
/// Return the size of the shape in number of elements
167183
pub fn size(&self) -> usize {
168184
self.dim.size()

0 commit comments

Comments
 (0)