Skip to content

Commit d9e9c9c

Browse files
jedbrown authored and ZuseZ4 committed
autodiff: no_std support (switch std:: to core::)
I can now do this on a device function and the IR looks okay by eyeball: cargo +enzyme rustc --release --target=nvptx64-nvidia-cuda -Zbuild-std -- --emit=llvm-ir
1 parent efad9fd commit d9e9c9c

6 files changed

+14
-14
lines changed

library/autodiff/src/gen.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream {
107107
res_inputs.push(input.clone());
108108

109109
match (item.header.mode, activity, is_ref_mut(&input)) {
110-
(Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(true)) => {
110+
(Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => {
111111
res_inputs.push(as_ref_mut(&input, "grad", true));
112112
add_inputs.push(as_ref_mut(&input, "grad", true));
113113
}
114-
(Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(false)) => {
114+
(Mode::Forward, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(false)) => {
115115
res_inputs.push(as_ref_mut(&input, "dual", false));
116116
add_inputs.push(as_ref_mut(&input, "dual", false));
117117
out_type.clone().map(|x| outputs.push(x));
@@ -203,9 +203,9 @@ pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream {
203203
};
204204

205205
let body = quote!({
206-
std::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*));
206+
core::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*));
207207

208-
std::hint::black_box(unsafe { std::mem::zeroed() })
208+
core::hint::black_box(unsafe { core::mem::zeroed() })
209209
});
210210
let header = generate_header(&item);
211211

library/autodiff/tests/expand/forward_duplicated.expanded.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ fn square(a: &Vec<f32>, b: &mut f32) {
55
}
66
#[autodiff_into(Forward, Const, Duplicated, Duplicated)]
77
fn d_square(a: &Vec<f32>, dual_a: &Vec<f32>, b: &mut f32, grad_b: &mut f32) {
8-
std::hint::black_box((square(a, b), dual_a, grad_b));
9-
std::hint::black_box(unsafe { std::mem::zeroed() })
8+
core::hint::black_box((square(a, b), dual_a, grad_b));
9+
core::hint::black_box(unsafe { core::mem::zeroed() })
1010
}

library/autodiff/tests/expand/forward_duplicated_return.expanded.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ fn d_square2(
1010
b: &Vec<f32>,
1111
dual_b: &Vec<f32>,
1212
) -> (f32, f32, f32) {
13-
std::hint::black_box((square2(a, b), dual_a, dual_b));
14-
std::hint::black_box(unsafe { std::mem::zeroed() })
13+
core::hint::black_box((square2(a, b), dual_a, dual_b));
14+
core::hint::black_box(unsafe { core::mem::zeroed() })
1515
}

library/autodiff/tests/expand/reverse_duplicated.expanded.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ fn square(a: &Vec<f32>, b: &mut f32) {
55
}
66
#[autodiff_into(Reverse, Const, Duplicated, Duplicated)]
77
fn d_square(a: &Vec<f32>, grad_a: &mut Vec<f32>, b: &mut f32, grad_b: &f32) {
8-
std::hint::black_box((square(a, b), grad_a, grad_b));
9-
std::hint::black_box(unsafe { std::mem::zeroed() })
8+
core::hint::black_box((square(a, b), grad_a, grad_b));
9+
core::hint::black_box(unsafe { core::mem::zeroed() })
1010
}

library/autodiff/tests/expand/reverse_return_array.expanded.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 {
55
}
66
#[autodiff_into(Reverse, Active, Duplicated)]
77
fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) {
8-
std::hint::black_box((array(arr), grad_arr, tang_y));
9-
std::hint::black_box(unsafe { std::mem::zeroed() })
8+
core::hint::black_box((array(arr), grad_arr, tang_y));
9+
core::hint::black_box(unsafe { core::mem::zeroed() })
1010
}

library/autodiff/tests/expand/reverse_return_mixed.expanded.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ fn d_sqrt(
1212
d: f32,
1313
tang_y: f32,
1414
) -> (f32, f32) {
15-
std::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y));
16-
std::hint::black_box(unsafe { std::mem::zeroed() })
15+
core::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y));
16+
core::hint::black_box(unsafe { core::mem::zeroed() })
1717
}

0 commit comments

Comments
 (0)