2
2
// Copyright (C) 2015-2016 Lumol's contributors — BSD license
3
3
4
4
//! Multi-dimensional arrays based on ndarray
5
- use ndarray:: { Array , Ix } ;
5
+ use ndarray;
6
6
7
- use std:: ops:: { Index , IndexMut } ;
7
+ use std:: ops:: { Index , IndexMut , Deref , DerefMut } ;
8
8
use types:: Zero ;
9
9
10
10
/// Two dimensional tensors, based on ndarray.
@@ -23,21 +23,7 @@ use types::Zero;
23
23
/// assert_eq!(a[(0, 4)], 7.0);
24
24
/// ```
25
25
#[ derive( Debug , Clone , PartialEq ) ]
26
- pub struct Array2 < T > ( Array < T , ( Ix , Ix ) > ) ;
27
-
28
- impl < T > Array2 < T > {
29
- /// Get the shape of the array
30
- /// # Examples
31
- /// ```
32
- /// # use lumol::types::Array2;
33
- /// let a: Array2<f64> = Array2::zeros((3, 5));
34
- /// assert_eq!(a.shape(), (3, 5));
35
- /// ```
36
- pub fn shape ( & self ) -> ( Ix , Ix ) {
37
- let shape = self . 0 . shape ( ) ;
38
- ( shape[ 0 ] , shape[ 1 ] )
39
- }
40
- }
26
+ pub struct Array2 < T > ( ndarray:: Array2 < T > ) ;
41
27
42
28
impl < T : Zero + Clone > Array2 < T > {
43
29
/// Create a new `Array2` of the specified `size` filled with the
@@ -50,8 +36,8 @@ impl<T: Zero + Clone> Array2<T> {
50
36
/// let a: Array2<f64> = Array2::zeros((8, 5));
51
37
/// assert_eq!(a[(6, 2)], 0.0);
52
38
/// ```
53
- pub fn zeros ( size : ( Ix , Ix ) ) -> Array2 < T > {
54
- Array2 ( Array :: < T , ( Ix , Ix ) > :: zeros ( size) )
39
+ pub fn zeros ( size : ( usize , usize ) ) -> Array2 < T > {
40
+ Array2 ( ndarray :: Array2 :: zeros ( size) )
55
41
}
56
42
57
43
/// Resize the array if the current size is not `size`, and fill the
@@ -73,28 +59,13 @@ impl<T: Zero + Clone> Array2<T> {
73
59
/// a.resize_if_different((8, 9));
74
60
/// assert_eq!(a[(3, 3)], 0.0);
75
61
/// ```
76
- pub fn resize_if_different ( & mut self , size : ( Ix , Ix ) ) {
77
- if self . 0 . shape ( ) != & [ size. 0 , size . 1 ] {
62
+ pub fn resize_if_different ( & mut self , size : ( usize , usize ) ) {
63
+ if self . dim ( ) != size {
78
64
* self = Array2 :: zeros ( size) ;
79
65
}
80
66
}
81
67
}
82
68
83
- impl < T : Clone > Array2 < T > {
84
- /// Assign the given scalar to all entries in this array
85
- /// # Examples
86
- /// ```
87
- /// # use lumol::types::Array2;
88
- /// let mut a = Array2::zeros((8, 5));
89
- /// a.assign(33.0);
90
- ///
91
- /// assert_eq!(a[(3, 4)], 33.0);
92
- /// ```
93
- pub fn assign ( & mut self , value : T ) {
94
- self . 0 . assign_scalar ( & value) ;
95
- }
96
- }
97
-
98
69
impl < T : Default > Array2 < T > {
99
70
/// Create a new `Array2` of the specified `size` filled with the
100
71
/// `Default::default` return value.
@@ -108,30 +79,44 @@ impl<T: Default> Array2<T> {
108
79
///
109
80
/// assert_eq!(a, b);
110
81
/// ```
111
- pub fn default ( size : ( Ix , Ix ) ) -> Array2 < T > {
112
- Array2 ( Array :: < T , ( Ix , Ix ) > :: default ( size) )
82
+ pub fn default ( size : ( usize , usize ) ) -> Array2 < T > {
83
+ Array2 ( ndarray :: Array2 :: default ( size) )
113
84
}
114
85
}
115
86
116
- impl < T > Index < ( Ix , Ix ) > for Array2 < T > {
87
+ impl < T > Index < ( usize , usize ) > for Array2 < T > {
117
88
type Output = T ;
118
- fn index ( & self , index : ( Ix , Ix ) ) -> & T {
89
+ fn index ( & self , index : ( usize , usize ) ) -> & T {
119
90
unsafe {
120
91
// ndarray does the check for us in debug builds
121
92
self . 0 . uget ( index)
122
93
}
123
94
}
124
95
}
125
96
126
- impl < T > IndexMut < ( Ix , Ix ) > for Array2 < T > {
127
- fn index_mut ( & mut self , index : ( Ix , Ix ) ) -> & mut T {
97
+ impl < T > IndexMut < ( usize , usize ) > for Array2 < T > {
98
+ fn index_mut ( & mut self , index : ( usize , usize ) ) -> & mut T {
128
99
unsafe {
129
100
// ndarray does the check for us in debug builds
130
101
self . 0 . uget_mut ( index)
131
102
}
132
103
}
133
104
}
134
105
106
+ impl < T > Deref for Array2 < T > {
107
+ type Target = ndarray:: Array2 < T > ;
108
+
109
+ fn deref ( & self ) -> & ndarray:: Array2 < T > {
110
+ & self . 0
111
+ }
112
+ }
113
+
114
+ impl < T > DerefMut for Array2 < T > {
115
+ fn deref_mut ( & mut self ) -> & mut ndarray:: Array2 < T > {
116
+ & mut self . 0
117
+ }
118
+ }
119
+
135
120
/******************************************************************************/
136
121
137
122
/// Three dimensional tensors, based on ndarray
@@ -150,23 +135,9 @@ impl<T> IndexMut<(Ix, Ix)> for Array2<T> {
150
135
/// assert_eq!(a[(0, 4, 1)], 7.0);
151
136
/// ```
152
137
#[ derive( Debug , Clone , PartialEq ) ]
153
- pub struct Array3 < T > ( Array < T , ( Ix , Ix , Ix ) > ) ;
138
+ pub struct Array3 < T > ( ndarray :: Array3 < T > ) ;
154
139
155
140
impl < T > Array3 < T > {
156
- /// Get the shape of the array.
157
- /// # Examples
158
- /// ```
159
- /// # use lumol::types::Array3;
160
- /// let a: Array3<f64> = Array3::zeros((3, 5, 7));
161
- /// assert_eq!(a.shape(), (3, 5, 7));
162
- /// ```
163
- pub fn shape ( & self ) -> ( Ix , Ix , Ix ) {
164
- let shape = self . 0 . shape ( ) ;
165
- ( shape[ 0 ] , shape[ 1 ] , shape[ 2 ] )
166
- }
167
- }
168
-
169
- impl < T : Zero + Clone > Array3 < T > {
170
141
/// Create a new `Array3` of the specified `size` filled with the
171
142
/// `Zero::zero` return value.
172
143
///
@@ -177,8 +148,8 @@ impl<T: Zero + Clone> Array3<T> {
177
148
/// let a: Array3<f64> = Array3::zeros((8, 5, 2));
178
149
/// assert_eq!(a[(6, 2, 0)], 0.0);
179
150
/// ```
180
- pub fn zeros ( size : ( Ix , Ix , Ix ) ) -> Array3 < T > {
181
- Array3 ( Array :: < T , ( Ix , Ix , Ix ) > :: zeros ( size) )
151
+ pub fn zeros ( size : ( usize , usize , usize ) ) -> Array3 < T > where T : Zero + Clone {
152
+ Array3 ( ndarray :: Array3 :: zeros ( size) )
182
153
}
183
154
184
155
/// Resize the array if the current size is not `size`, and fill the
@@ -200,21 +171,12 @@ impl<T: Zero + Clone> Array3<T> {
200
171
/// a.resize_if_different((8, 5, 6));
201
172
/// assert_eq!(a[(3, 3, 3)], 0.0);
202
173
/// ```
203
- pub fn resize_if_different ( & mut self , size : ( Ix , Ix , Ix ) ) {
174
+ pub fn resize_if_different ( & mut self , size : ( usize , usize , usize ) ) where T : Zero + Clone {
204
175
if self . 0 . shape ( ) != & [ size. 0 , size. 1 , size. 2 ] {
205
176
* self = Array3 :: zeros ( size) ;
206
177
}
207
178
}
208
- }
209
-
210
- impl < T : Clone > Array3 < T > {
211
- /// Assign the given scalar to all entries in this array
212
- pub fn assign ( & mut self , value : T ) {
213
- self . 0 . assign_scalar ( & value) ;
214
- }
215
- }
216
179
217
- impl < T : Default > Array3 < T > {
218
180
/// Create a new `Array3` of the specified `size` filled with the
219
181
/// `Default::default` return value.
220
182
/// `Default::default` return value.
@@ -228,30 +190,44 @@ impl<T: Default> Array3<T> {
228
190
///
229
191
/// assert_eq!(a, b);
230
192
/// ```
231
- pub fn default ( size : ( Ix , Ix , Ix ) ) -> Array3 < T > {
232
- Array3 ( Array :: < T , ( Ix , Ix , Ix ) > :: default ( size) )
193
+ pub fn default ( size : ( usize , usize , usize ) ) -> Array3 < T > where T : Default {
194
+ Array3 ( ndarray :: Array3 :: default ( size) )
233
195
}
234
196
}
235
197
236
- impl < T > Index < ( Ix , Ix , Ix ) > for Array3 < T > {
198
+ impl < T > Index < ( usize , usize , usize ) > for Array3 < T > {
237
199
type Output = T ;
238
- fn index ( & self , index : ( Ix , Ix , Ix ) ) -> & T {
200
+ fn index ( & self , index : ( usize , usize , usize ) ) -> & T {
239
201
unsafe {
240
202
// ndarray does the check for us in debug builds
241
203
self . 0 . uget ( index)
242
204
}
243
205
}
244
206
}
245
207
246
- impl < T > IndexMut < ( Ix , Ix , Ix ) > for Array3 < T > {
247
- fn index_mut ( & mut self , index : ( Ix , Ix , Ix ) ) -> & mut T {
208
+ impl < T > IndexMut < ( usize , usize , usize ) > for Array3 < T > {
209
+ fn index_mut ( & mut self , index : ( usize , usize , usize ) ) -> & mut T {
248
210
unsafe {
249
211
// ndarray does the check for us in debug builds
250
212
self . 0 . uget_mut ( index)
251
213
}
252
214
}
253
215
}
254
216
217
+ impl < T > Deref for Array3 < T > {
218
+ type Target = ndarray:: Array3 < T > ;
219
+
220
+ fn deref ( & self ) -> & ndarray:: Array3 < T > {
221
+ & self . 0
222
+ }
223
+ }
224
+
225
+ impl < T > DerefMut for Array3 < T > {
226
+ fn deref_mut ( & mut self ) -> & mut ndarray:: Array3 < T > {
227
+ & mut self . 0
228
+ }
229
+ }
230
+
255
231
#[ cfg( test) ]
256
232
mod tests {
257
233
mod array2 {
@@ -283,11 +259,11 @@ mod tests {
283
259
#[ test]
284
260
fn resize ( ) {
285
261
let mut a: Array2 < f64 > = Array2 :: zeros ( ( 3 , 4 ) ) ;
286
- assert_eq ! ( a. shape ( ) , ( 3 , 4 ) ) ;
262
+ assert_eq ! ( a. dim ( ) , ( 3 , 4 ) ) ;
287
263
a[ ( 1 , 1 ) ] = 42.0 ;
288
264
289
265
a. resize_if_different ( ( 7 , 90 ) ) ;
290
- assert_eq ! ( a. shape ( ) , ( 7 , 90 ) ) ;
266
+ assert_eq ! ( a. dim ( ) , ( 7 , 90 ) ) ;
291
267
assert_eq ! ( a[ ( 1 , 1 ) ] , 0.0 ) ;
292
268
293
269
a[ ( 1 , 1 ) ] = 42.0 ;
@@ -355,11 +331,11 @@ mod tests {
355
331
#[ test]
356
332
fn resize ( ) {
357
333
let mut a: Array3 < f64 > = Array3 :: zeros ( ( 3 , 4 , 5 ) ) ;
358
- assert_eq ! ( a. shape ( ) , ( 3 , 4 , 5 ) ) ;
334
+ assert_eq ! ( a. dim ( ) , ( 3 , 4 , 5 ) ) ;
359
335
a[ ( 1 , 1 , 1 ) ] = 42.0 ;
360
336
361
337
a. resize_if_different ( ( 7 , 90 , 8 ) ) ;
362
- assert_eq ! ( a. shape ( ) , ( 7 , 90 , 8 ) ) ;
338
+ assert_eq ! ( a. dim ( ) , ( 7 , 90 , 8 ) ) ;
363
339
assert_eq ! ( a[ ( 1 , 1 , 1 ) ] , 0.0 ) ;
364
340
365
341
a[ ( 1 , 1 , 1 ) ] = 42.0 ;
0 commit comments