Skip to content

Commit 5b333fb

Browse files
committed
Remove some allocations
1 parent 9cd7348 commit 5b333fb

File tree

1 file changed

+15
-22
lines changed

1 file changed

+15
-22
lines changed

src/svm/svc.rs

+15-22
Original file line numberDiff line numberDiff line change
@@ -322,16 +322,18 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
322322
let (n, _) = x.shape();
323323
let mut y_hat: Vec<TX> = Array1::zeros(n);
324324

325+
let mut row = Vec::with_capacity(n);
325326
for i in 0..n {
326-
let row_pred: TX =
327-
self.predict_for_row(Vec::from_iterator(x.get_row(i).iterator(0).copied(), n));
327+
row.clear();
328+
row.extend(x.get_row(i).iterator(0).copied());
329+
let row_pred: TX = self.predict_for_row(&row);
328330
y_hat.set(i, row_pred);
329331
}
330332

331333
Ok(y_hat)
332334
}
333335

334-
fn predict_for_row(&self, x: Vec<TX>) -> TX {
336+
fn predict_for_row(&self, x: &[TX]) -> TX {
335337
let mut f = self.b.unwrap();
336338

337339
for i in 0..self.instances.as_ref().unwrap().len() {
@@ -472,14 +474,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
472474
let tol = self.parameters.tol;
473475
let good_enough = TX::from_i32(1000).unwrap();
474476

477+
let mut x = Vec::with_capacity(n);
475478
for _ in 0..self.parameters.epoch {
476479
for i in self.permutate(n) {
477-
self.process(
478-
i,
479-
Vec::from_iterator(self.x.get_row(i).iterator(0).copied(), n),
480-
*self.y.get(i),
481-
&mut cache,
482-
);
480+
x.clear();
481+
x.extend(self.x.get_row(i).iterator(0).take(n).copied());
482+
self.process(i, &x, *self.y.get(i), &mut cache);
483483
loop {
484484
self.reprocess(tol, &mut cache);
485485
self.find_min_max_gradient();
@@ -511,24 +511,17 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
511511
let mut cp = 0;
512512
let mut cn = 0;
513513

514+
let mut x = Vec::with_capacity(n);
514515
for i in self.permutate(n) {
516+
x.clear();
517+
x.extend(self.x.get_row(i).iterator(0).take(n).copied());
515518
if *self.y.get(i) == TY::one() && cp < few {
516-
if self.process(
517-
i,
518-
Vec::from_iterator(self.x.get_row(i).iterator(0).copied(), n),
519-
*self.y.get(i),
520-
cache,
521-
) {
519+
if self.process(i, &x, *self.y.get(i), cache) {
522520
cp += 1;
523521
}
524522
} else if *self.y.get(i) == TY::from(-1).unwrap()
525523
&& cn < few
526-
&& self.process(
527-
i,
528-
Vec::from_iterator(self.x.get_row(i).iterator(0).copied(), n),
529-
*self.y.get(i),
530-
cache,
531-
)
524+
&& self.process(i, &x, *self.y.get(i), cache)
532525
{
533526
cn += 1;
534527
}
@@ -539,7 +532,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
539532
}
540533
}
541534

542-
fn process(&mut self, i: usize, x: Vec<TX>, y: TY, cache: &mut Cache<TX, TY, X, Y>) -> bool {
535+
fn process(&mut self, i: usize, x: &[TX], y: TY, cache: &mut Cache<TX, TY, X, Y>) -> bool {
543536
for j in 0..self.sv.len() {
544537
if self.sv[j].index == i {
545538
return true;

0 commit comments

Comments
 (0)