Skip to content

Commit aed3567

Browse files
committed
update tests
1 parent 762a99a commit aed3567

File tree

5 files changed

+70
-57
lines changed

5 files changed

+70
-57
lines changed

tests/ui/batching/batch_const.rs

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Problem: The user might want to pass either [f64; 4], (f64, f64, f64, f64), or S to the
2+
// function. All of these are valid (modulo we have to force the user to set the right repr).
3+
// Our current design doesn't allow users to specify those, so we will want at least one iteration.
4+
// However, for the sake of similarity to the current autodiff (where we'd also want a change),
5+
// leave it as is.
6+
7+
struct _S {
8+
x1: f64,
9+
x2: f64,
10+
x3: f64,
11+
x4: f64,
12+
}
13+
14+
#[batch(bsquare4, 4, Const, Leaf(8))]
15+
#[batch(vsquare4, 4, Const, Vector)]
16+
fn square(multiplier: f64, x: f64) -> f64 {
17+
x * x * multiplier
18+
}
19+
20+
fn main() {
21+
let vals = [23.1, 10.0, 100.0, 3.14];
22+
let expected = [square(3.14, vals[0]), square(3.14, vals[1]), square(3.14, vals[2]), square(3.14, vals[3])];
23+
let result1 = bsquare4(3.14, vals[0], vals[1], vals[2], vals[3]);
24+
let result2 = vsquare4(3.14, vals);
25+
assert_eq!(result.x1, expected[0]);
26+
assert_eq!(result.x2, expected[1]);
27+
assert_eq!(result.x3, expected[2]);
28+
assert_eq!(result.x4, expected[3]);
29+
assert_eq!(result2.x1, expected[0]);
30+
assert_eq!(result2.x2, expected[1]);
31+
assert_eq!(result2.x3, expected[2]);
32+
assert_eq!(result2.x4, expected[3]);
33+
}

tests/ui/batching/slice.rs

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// We want a batch size of 4.
2+
// The original function processes 2 elements a 64 bit, so for our vfoo we have an offset of 16 bytes.
3+
// Both vfoo and bfoo return [f64; 4].
4+
5+
#[batch(vfoo, 4, Leaf(16))]
6+
#[batch(bfoo, 4, Batch)]
7+
fn foo(x: &[f64]) -> f64 {
8+
assert!(x.len() == 2);
9+
x.iter().map(|&x| x * x).sum()
10+
}
11+
12+
fn main() {
13+
// 8 elements
14+
let x1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
15+
16+
let x2 = vec![1.0, 2.0];
17+
let x3 = vec![3.0, 4.0];
18+
let x4 = vec![5.0, 6.0];
19+
let x5 = vec![7.0, 8.0];
20+
21+
let mut res1 = [0.0;4];
22+
for i in 0..4 {
23+
res1[i] = foo(&x1[i..i + 1]);
24+
}
25+
26+
let res2: [f64; 4] = bfoo(&x2, &x3, &x4, &x5);
27+
28+
let res3: [f64; 4] = vfoo(&x1);
29+
30+
for i in 0..4 {
31+
assert_eq!(res1[i], res2[i]);
32+
assert_eq!(res1[i], res3[i]);
33+
}
34+
}

tests/ui/vectorize/vector_char-ptr.rs renamed to tests/ui/batching/vector_char-ptr.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// Showcasing a slightly more complex type.
2+
13

24
#[repr(C, packed)]
35
struct Foo {
@@ -15,7 +17,7 @@ struct Foo {
1517
// double res;
1618
//};
1719

18-
#[vectorize(df, Vector, 4)]
20+
#[batch(df, 4, Vector)]
1921
unsafe fn f(foo: *mut i32) {
2022
let xptr = foo.add(3) as *mut f64;
2123
let yptr = foo.add(5) as *mut f32;
@@ -25,10 +27,6 @@ unsafe fn f(foo: *mut i32) {
2527
*resptr = x * y;
2628
}
2729

28-
//void df(char *dfoo1, char *dfoo2, char *dfoo3, char *dfoo4) {
29-
// __enzyme_batch((void *)f, enzyme_width, 4, dfoo1, dfoo2, dfoo3, dfoo4);
30-
//}
31-
3230
fn main() {
3331
let foo1: Foo = Foo { [0,0,0], 10.0, 9.0, 0.0 };
3432
let foo2: Foo = Foo { [0,0,0], 99.0, 7.0, 0.0 };

tests/ui/vectorize/batch_sqnorm.rs

-20
This file was deleted.

tests/ui/vectorize/vector_square.rs

-32
This file was deleted.

0 commit comments

Comments
 (0)