forked from coreylowman/dfdx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfill.rs
84 lines (74 loc) · 2.19 KB
/
fill.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
use super::{AllocateZeros, Cpu};
use crate::arrays::CountElements;
use std::boxed::Box;
/// Fills every scalar element of a `T` by running a caller-supplied closure over it.
///
/// Implementations are expected to invoke the closure once per element, handing it a
/// mutable reference so it can either overwrite the value (a constant, a random sample, ...)
/// or update it in place.
pub trait FillElements<T: CountElements>: Sized + AllocateZeros {
    /// Applies `f` to each element of `out`, mutating `out` in place.
    fn fill<F: FnMut(&mut T::Dtype)>(out: &mut T, f: &mut F);

    /// Allocates a fresh boxed `T` via `Self::zeros()` (from `AllocateZeros`), then runs
    /// [`FillElements::fill`] over it with `f` and returns the box.
    fn filled<F: FnMut(&mut T::Dtype)>(f: &mut F) -> Box<T> {
        let mut zeroed: Box<T> = Self::zeros();
        // Explicitly deref the box so `fill` receives `&mut T` rather than `&mut Box<T>`.
        Self::fill(&mut *zeroed, f);
        zeroed
    }
}
/// Scalar base case: an `f32` is a single element, so the closure is applied to it once.
impl FillElements<f32> for Cpu {
    fn fill<F: FnMut(&mut f32)>(out: &mut f32, f: &mut F) {
        (*f)(out);
    }
}
/// Recursive case: an array `[T; M]` is filled by filling each of its `M` entries in index
/// order, delegating to the `FillElements<T>` impl for the entry type. Nested arrays are
/// therefore visited outermost-index-first, down to the scalar elements.
impl<T: CountElements, const M: usize> FillElements<[T; M]> for Cpu
where
    Self: FillElements<T>,
{
    fn fill<F: FnMut(&mut T::Dtype)>(out: &mut [T; M], f: &mut F) {
        // `f` is reborrowed for each entry, so the same closure state carries across entries.
        out.iter_mut().for_each(|entry| Self::fill(entry, f));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrays::ZeroElements;
    use rand::{thread_rng, Rng};

    #[test]
    fn test_fill_rng() {
        let mut rng = thread_rng();
        let mut t: [f32; 5] = ZeroElements::ZEROS;
        // Sample from a range that excludes the initial 0.0, so a `fill` that never
        // invokes the closure cannot pass this test.
        Cpu::fill(&mut t, &mut |v| *v = rng.gen_range(1.0..2.0));
        for &item in t.iter() {
            assert!((1.0..2.0).contains(&item));
        }
    }

    #[test]
    fn test_0d_fill() {
        let mut t: f32 = ZeroElements::ZEROS;
        Cpu::fill(&mut t, &mut |v| *v = 1.0);
        assert_eq!(t, 1.0);
        Cpu::fill(&mut t, &mut |v| *v = 2.0);
        assert_eq!(t, 2.0);
    }

    #[test]
    fn test_1d_fill() {
        let mut t: [f32; 5] = ZeroElements::ZEROS;
        Cpu::fill(&mut t, &mut |v| *v = 1.0);
        assert_eq!(t, [1.0; 5]);
        Cpu::fill(&mut t, &mut |v| *v = 2.0);
        assert_eq!(t, [2.0; 5]);
    }

    #[test]
    fn test_2d_fill() {
        let mut t: [[f32; 3]; 5] = ZeroElements::ZEROS;
        Cpu::fill(&mut t, &mut |v| *v = 1.0);
        assert_eq!(t, [[1.0; 3]; 5]);
        Cpu::fill(&mut t, &mut |v| *v = 2.0);
        assert_eq!(t, [[2.0; 3]; 5]);
    }

    #[test]
    fn test_3d_fill() {
        let mut t: [[[f32; 2]; 3]; 5] = ZeroElements::ZEROS;
        Cpu::fill(&mut t, &mut |v| *v = 1.0);
        assert_eq!(t, [[[1.0; 2]; 3]; 5]);
        Cpu::fill(&mut t, &mut |v| *v = 2.0);
        assert_eq!(t, [[[2.0; 2]; 3]; 5]);
    }

    /// `fill` must call the closure exactly once per element, outermost index first.
    #[test]
    fn test_fill_visits_each_element_once_in_order() {
        let mut t: [[f32; 3]; 2] = ZeroElements::ZEROS;
        let mut counter = 0.0;
        Cpu::fill(&mut t, &mut |v| {
            *v = counter;
            counter += 1.0;
        });
        assert_eq!(t, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    }

    /// The default `filled` method allocates a new value and fills every element.
    #[test]
    fn test_filled() {
        let t: Box<[[f32; 3]; 2]> = Cpu::filled(&mut |v: &mut f32| *v = 3.0);
        assert_eq!(*t, [[3.0; 3]; 2]);
    }
}