@@ -13,22 +13,35 @@ pub struct Shape<D> {
13
13
}
14
14
15
15
#[ derive( Copy , Clone , Debug ) ]
16
- pub ( crate ) enum Contiguous { }
16
+ pub ( crate ) enum Contiguous { }
17
17
18
18
impl < D > Shape < D > {
19
19
pub ( crate ) fn is_c ( & self ) -> bool {
20
20
matches ! ( self . strides, Strides :: C )
21
21
}
22
22
}
23
23
24
-
25
24
/// An array shape of n dimensions in c-order, f-order or custom strides.
26
25
#[ derive( Copy , Clone , Debug ) ]
27
26
pub struct StrideShape < D > {
28
27
pub ( crate ) dim : D ,
29
28
pub ( crate ) strides : Strides < D > ,
30
29
}
31
30
31
+ impl < D > StrideShape < D >
32
+ where
33
+ D : Dimension ,
34
+ {
35
+ /// Return a reference to the dimension
36
+ pub fn raw_dim ( & self ) -> & D {
37
+ & self . dim
38
+ }
39
+ /// Return the size of the shape in number of elements
40
+ pub fn size ( & self ) -> usize {
41
+ self . dim . size ( )
42
+ }
43
+ }
44
+
32
45
/// Stride description
33
46
#[ derive( Copy , Clone , Debug ) ]
34
47
pub ( crate ) enum Strides < D > {
@@ -37,21 +50,26 @@ pub(crate) enum Strides<D> {
37
50
/// Column-major ("F"-order)
38
51
F ,
39
52
/// Custom strides
40
- Custom ( D )
53
+ Custom ( D ) ,
41
54
}
42
55
43
56
impl < D > Strides < D > {
44
57
/// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
45
58
pub ( crate ) fn strides_for_dim ( self , dim : & D ) -> D
46
- where D : Dimension
59
+ where
60
+ D : Dimension ,
47
61
{
48
62
match self {
49
63
Strides :: C => dim. default_strides ( ) ,
50
64
Strides :: F => dim. fortran_strides ( ) ,
51
65
Strides :: Custom ( c) => {
52
- debug_assert_eq ! ( c. ndim( ) , dim. ndim( ) ,
66
+ debug_assert_eq ! (
67
+ c. ndim( ) ,
68
+ dim. ndim( ) ,
53
69
"Custom strides given with {} dimensions, expected {}" ,
54
- c. ndim( ) , dim. ndim( ) ) ;
70
+ c. ndim( ) ,
71
+ dim. ndim( )
72
+ ) ;
55
73
c
56
74
}
57
75
}
@@ -94,11 +112,7 @@ where
94
112
{
95
113
fn from ( value : T ) -> Self {
96
114
let shape = value. into_shape ( ) ;
97
- let st = if shape. is_c ( ) {
98
- Strides :: C
99
- } else {
100
- Strides :: F
101
- } ;
115
+ let st = if shape. is_c ( ) { Strides :: C } else { Strides :: F } ;
102
116
StrideShape {
103
117
strides : st,
104
118
dim : shape. dim ,
@@ -161,8 +175,10 @@ impl<D> Shape<D>
161
175
where
162
176
D : Dimension ,
163
177
{
164
- // Return a reference to the dimension
165
- //pub fn dimension(&self) -> &D { &self.dim }
178
+ /// Return a reference to the dimension
179
+ pub fn raw_dim ( & self ) -> & D {
180
+ & self . dim
181
+ }
166
182
/// Return the size of the shape in number of elements
167
183
pub fn size ( & self ) -> usize {
168
184
self . dim . size ( )
0 commit comments