@@ -197,6 +197,62 @@ fn accurate_mul_f64() {
197
197
}
198
198
199
199
200
+ #[ test]
201
+ fn accurate_mul_with_column_f64 ( ) {
202
+ // pick a few random sizes
203
+ let mut rng = rand:: weak_rng ( ) ;
204
+ for i in 0 ..10 {
205
+ let m = rng. gen_range ( 1 , 350 ) ;
206
+ let k = rng. gen_range ( 1 , 350 ) ;
207
+ let a = gen_f64 ( Ix2 ( m, k) ) ;
208
+ let b_owner = gen_f64 ( Ix2 ( k, k) ) ;
209
+ let b_row_col;
210
+ let b_sq;
211
+
212
+ // pick dense square or broadcasted to square matrix
213
+ match i {
214
+ 0 ... 3 => b_sq = b_owner. view ( ) ,
215
+ 4 ... 7 => {
216
+ b_row_col = b_owner. column ( 0 ) ;
217
+ b_sq = b_row_col. broadcast ( ( k, k) ) . unwrap ( ) ;
218
+ }
219
+ _otherwise => {
220
+ b_row_col = b_owner. row ( 0 ) ;
221
+ b_sq = b_row_col. broadcast ( ( k, k) ) . unwrap ( ) ;
222
+ }
223
+ } ;
224
+
225
+ for j in 0 ..k {
226
+ for & flip in & [ true , false ] {
227
+ let j = j as isize ;
228
+ let b = if flip {
229
+ // one row in 2D
230
+ b_sq. slice ( s ! [ j..j + 1 , ..] ) . reversed_axes ( )
231
+ } else {
232
+ // one column in 2D
233
+ b_sq. slice ( s ! [ .., j..j + 1 ] )
234
+ } ;
235
+ println ! ( "Testing size ({} × {}) by ({} × {})" , a. shape( ) [ 0 ] , a. shape( ) [ 1 ] , b. shape( ) [ 0 ] , b. shape( ) [ 1 ] ) ;
236
+ println ! ( "Strides ({:?}) by ({:?})" , a. strides( ) , b. strides( ) ) ;
237
+ let c = a. dot ( & b) ;
238
+ let reference = reference_mat_mul ( & a, & b) ;
239
+ let diff = ( & c - & reference) . mapv_into ( f64:: abs) ;
240
+
241
+ let rtol = 1e-7 ;
242
+ let atol = 1e-12 ;
243
+ let crtol = c. mapv ( |x| x. abs ( ) * rtol) ;
244
+ let tol = crtol + atol;
245
+ let tol_m_diff = & diff - & tol;
246
+ let maxdiff = * tol_m_diff. max ( ) ;
247
+ println ! ( "diff offset from tolerance level= {:.2e}" , maxdiff) ;
248
+ if maxdiff > 0. {
249
+ panic ! ( "results differ" ) ;
250
+ }
251
+ }
252
+ }
253
+ }
254
+ }
255
+
200
256
201
257
trait Utils {
202
258
type Elem ;
0 commit comments