@@ -322,19 +322,26 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
322
322
let ( n, _) = x. shape ( ) ;
323
323
let mut y_hat: Vec < TX > = Array1 :: zeros ( n) ;
324
324
325
+ let mut row = Vec :: with_capacity ( n) ;
325
326
for i in 0 ..n {
326
- let row_pred: TX =
327
- self . predict_for_row ( Vec :: from_iterator ( x. get_row ( i) . iterator ( 0 ) . copied ( ) , n) ) ;
327
+ row. clear ( ) ;
328
+ row. extend ( x. get_row ( i) . iterator ( 0 ) . copied ( ) ) ;
329
+ let row_pred: TX = self . predict_for_row ( & row) ;
328
330
y_hat. set ( i, row_pred) ;
329
331
}
330
332
331
333
Ok ( y_hat)
332
334
}
333
335
334
- fn predict_for_row ( & self , x : Vec < TX > ) -> TX {
336
+ fn predict_for_row ( & self , x : & [ TX ] ) -> TX {
335
337
let mut f = self . b . unwrap ( ) ;
336
338
339
+ let xi: Vec < _ > = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
337
340
for i in 0 ..self . instances . as_ref ( ) . unwrap ( ) . len ( ) {
341
+ let xj: Vec < _ > = self . instances . as_ref ( ) . unwrap ( ) [ i]
342
+ . iter ( )
343
+ . map ( |e| e. to_f64 ( ) . unwrap ( ) )
344
+ . collect ( ) ;
338
345
f += self . w . as_ref ( ) . unwrap ( ) [ i]
339
346
* TX :: from (
340
347
self . parameters
@@ -343,13 +350,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
343
350
. kernel
344
351
. as_ref ( )
345
352
. unwrap ( )
346
- . apply (
347
- & x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
348
- & self . instances . as_ref ( ) . unwrap ( ) [ i]
349
- . iter ( )
350
- . map ( |e| e. to_f64 ( ) . unwrap ( ) )
351
- . collect ( ) ,
352
- )
353
+ . apply ( & xi, & xj)
353
354
. unwrap ( ) ,
354
355
)
355
356
. unwrap ( ) ;
@@ -472,14 +473,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
472
473
let tol = self . parameters . tol ;
473
474
let good_enough = TX :: from_i32 ( 1000 ) . unwrap ( ) ;
474
475
476
+ let mut x = Vec :: with_capacity ( n) ;
475
477
for _ in 0 ..self . parameters . epoch {
476
478
for i in self . permutate ( n) {
477
- self . process (
478
- i,
479
- Vec :: from_iterator ( self . x . get_row ( i) . iterator ( 0 ) . copied ( ) , n) ,
480
- * self . y . get ( i) ,
481
- & mut cache,
482
- ) ;
479
+ x. clear ( ) ;
480
+ x. extend ( self . x . get_row ( i) . iterator ( 0 ) . take ( n) . copied ( ) ) ;
481
+ self . process ( i, & x, * self . y . get ( i) , & mut cache) ;
483
482
loop {
484
483
self . reprocess ( tol, & mut cache) ;
485
484
self . find_min_max_gradient ( ) ;
@@ -511,24 +510,17 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
511
510
let mut cp = 0 ;
512
511
let mut cn = 0 ;
513
512
513
+ let mut x = Vec :: with_capacity ( n) ;
514
514
for i in self . permutate ( n) {
515
+ x. clear ( ) ;
516
+ x. extend ( self . x . get_row ( i) . iterator ( 0 ) . take ( n) . copied ( ) ) ;
515
517
if * self . y . get ( i) == TY :: one ( ) && cp < few {
516
- if self . process (
517
- i,
518
- Vec :: from_iterator ( self . x . get_row ( i) . iterator ( 0 ) . copied ( ) , n) ,
519
- * self . y . get ( i) ,
520
- cache,
521
- ) {
518
+ if self . process ( i, & x, * self . y . get ( i) , cache) {
522
519
cp += 1 ;
523
520
}
524
521
} else if * self . y . get ( i) == TY :: from ( -1 ) . unwrap ( )
525
522
&& cn < few
526
- && self . process (
527
- i,
528
- Vec :: from_iterator ( self . x . get_row ( i) . iterator ( 0 ) . copied ( ) , n) ,
529
- * self . y . get ( i) ,
530
- cache,
531
- )
523
+ && self . process ( i, & x, * self . y . get ( i) , cache)
532
524
{
533
525
cn += 1 ;
534
526
}
@@ -539,7 +531,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
539
531
}
540
532
}
541
533
542
- fn process ( & mut self , i : usize , x : Vec < TX > , y : TY , cache : & mut Cache < TX , TY , X , Y > ) -> bool {
534
+ fn process ( & mut self , i : usize , x : & [ TX ] , y : TY , cache : & mut Cache < TX , TY , X , Y > ) -> bool {
543
535
for j in 0 ..self . sv . len ( ) {
544
536
if self . sv [ j] . index == i {
545
537
return true ;
@@ -551,15 +543,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
551
543
let mut cache_values: Vec < ( ( usize , usize ) , TX ) > = Vec :: new ( ) ;
552
544
553
545
for v in self . sv . iter ( ) {
546
+ let xi: Vec < _ > = v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
547
+ let xj: Vec < _ > = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
554
548
let k = self
555
549
. parameters
556
550
. kernel
557
551
. as_ref ( )
558
552
. unwrap ( )
559
- . apply (
560
- & v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
561
- & x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
562
- )
553
+ . apply ( & xi, & xj)
563
554
. unwrap ( ) ;
564
555
cache_values. push ( ( ( i, v. index ) , TX :: from ( k) . unwrap ( ) ) ) ;
565
556
g -= v. alpha * k;
@@ -578,7 +569,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
578
569
cache. insert ( v. 0 , v. 1 . to_f64 ( ) . unwrap ( ) ) ;
579
570
}
580
571
581
- let x_f64 = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
572
+ let x_f64: Vec < _ > = x. iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
582
573
let k_v = self
583
574
. parameters
584
575
. kernel
@@ -701,8 +692,10 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
701
692
let km = sv1. k ;
702
693
let gm = sv1. grad ;
703
694
let mut best = 0f64 ;
695
+ let xi: Vec < _ > = sv1. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
704
696
for i in 0 ..self . sv . len ( ) {
705
697
let v = & self . sv [ i] ;
698
+ let xj: Vec < _ > = v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
706
699
let z = v. grad - gm;
707
700
let k = cache. get (
708
701
sv1,
@@ -711,10 +704,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
711
704
. kernel
712
705
. as_ref ( )
713
706
. unwrap ( )
714
- . apply (
715
- & sv1. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
716
- & v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
717
- )
707
+ . apply ( & xi, & xj)
718
708
. unwrap ( ) ,
719
709
) ;
720
710
let mut curv = km + v. k - 2f64 * k;
@@ -732,6 +722,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
732
722
}
733
723
}
734
724
725
+ let xi: Vec < _ > = self . sv [ idx_1]
726
+ . x
727
+ . iter ( )
728
+ . map ( |e| e. to_f64 ( ) . unwrap ( ) )
729
+ . collect :: < Vec < _ > > ( ) ;
730
+
735
731
idx_2. map ( |idx_2| {
736
732
(
737
733
idx_1,
@@ -742,16 +738,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
742
738
. as_ref ( )
743
739
. unwrap ( )
744
740
. apply (
745
- & self . sv [ idx_1]
746
- . x
747
- . iter ( )
748
- . map ( |e| e. to_f64 ( ) . unwrap ( ) )
749
- . collect ( ) ,
741
+ & xi,
750
742
& self . sv [ idx_2]
751
743
. x
752
744
. iter ( )
753
745
. map ( |e| e. to_f64 ( ) . unwrap ( ) )
754
- . collect ( ) ,
746
+ . collect :: < Vec < _ > > ( ) ,
755
747
)
756
748
. unwrap ( )
757
749
} ) ,
@@ -765,8 +757,11 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
765
757
let km = sv2. k ;
766
758
let gm = sv2. grad ;
767
759
let mut best = 0f64 ;
760
+
761
+ let xi: Vec < _ > = sv2. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
768
762
for i in 0 ..self . sv . len ( ) {
769
763
let v = & self . sv [ i] ;
764
+ let xj: Vec < _ > = v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
770
765
let z = gm - v. grad ;
771
766
let k = cache. get (
772
767
sv2,
@@ -775,10 +770,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
775
770
. kernel
776
771
. as_ref ( )
777
772
. unwrap ( )
778
- . apply (
779
- & sv2. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
780
- & v. x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
781
- )
773
+ . apply ( & xi, & xj)
782
774
. unwrap ( ) ,
783
775
) ;
784
776
let mut curv = km + v. k - 2f64 * k;
@@ -797,6 +789,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
797
789
}
798
790
}
799
791
792
+ let xj: Vec < _ > = self . sv [ idx_2]
793
+ . x
794
+ . iter ( )
795
+ . map ( |e| e. to_f64 ( ) . unwrap ( ) )
796
+ . collect ( ) ;
797
+
800
798
idx_1. map ( |idx_1| {
801
799
(
802
800
idx_1,
@@ -811,12 +809,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
811
809
. x
812
810
. iter ( )
813
811
. map ( |e| e. to_f64 ( ) . unwrap ( ) )
814
- . collect ( ) ,
815
- & self . sv [ idx_2]
816
- . x
817
- . iter ( )
818
- . map ( |e| e. to_f64 ( ) . unwrap ( ) )
819
- . collect ( ) ,
812
+ . collect :: < Vec < _ > > ( ) ,
813
+ & xj,
820
814
)
821
815
. unwrap ( )
822
816
} ) ,
@@ -835,12 +829,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
835
829
. x
836
830
. iter ( )
837
831
. map ( |e| e. to_f64 ( ) . unwrap ( ) )
838
- . collect ( ) ,
832
+ . collect :: < Vec < _ > > ( ) ,
839
833
& self . sv [ idx_2]
840
834
. x
841
835
. iter ( )
842
836
. map ( |e| e. to_f64 ( ) . unwrap ( ) )
843
- . collect ( ) ,
837
+ . collect :: < Vec < _ > > ( ) ,
844
838
)
845
839
. unwrap ( ) ,
846
840
) ) ,
@@ -895,18 +889,18 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
895
889
self . sv [ v1] . alpha -= step. to_f64 ( ) . unwrap ( ) ;
896
890
self . sv [ v2] . alpha += step. to_f64 ( ) . unwrap ( ) ;
897
891
892
+ let xi_v1: Vec < _ > = self . sv [ v1] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
893
+ let xi_v2: Vec < _ > = self . sv [ v2] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
898
894
for i in 0 ..self . sv . len ( ) {
895
+ let xj: Vec < _ > = self . sv [ i] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ;
899
896
let k2 = cache. get (
900
897
& self . sv [ v2] ,
901
898
& self . sv [ i] ,
902
899
self . parameters
903
900
. kernel
904
901
. as_ref ( )
905
902
. unwrap ( )
906
- . apply (
907
- & self . sv [ v2] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
908
- & self . sv [ i] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
909
- )
903
+ . apply ( & xi_v2, & xj)
910
904
. unwrap ( ) ,
911
905
) ;
912
906
let k1 = cache. get (
@@ -916,10 +910,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
916
910
. kernel
917
911
. as_ref ( )
918
912
. unwrap ( )
919
- . apply (
920
- & self . sv [ v1] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
921
- & self . sv [ i] . x . iter ( ) . map ( |e| e. to_f64 ( ) . unwrap ( ) ) . collect ( ) ,
922
- )
913
+ . apply ( & xi_v1, & xj)
923
914
. unwrap ( ) ,
924
915
) ;
925
916
self . sv [ i] . grad -= step. to_f64 ( ) . unwrap ( ) * ( k2 - k1) ;
0 commit comments