Skip to content

Commit 5863b42

Browse files
authored
Rollup merge of #139465 - EnzymeAD:autodiff-sret, r=oli-obk
add sret handling for scalar autodiff r? `@oli-obk` Fixing one of the todo's which I left in my previous batching PR. This one handles sret for scalar autodiff. `sret` mostly shows up when we try to return a lot of scalar floats. People often start testing autodiff which toy functions which just use a few scalars as inputs and outputs, and those were the most likely to be affected by this issue. So this fix should make learning/teaching hopefully a bit easier. Tracking: - #124509
2 parents 0178254 + ca5bea3 commit 5863b42

File tree

5 files changed

+73
-2
lines changed

5 files changed

+73
-2
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ pub struct AutoDiffAttrs {
9292
pub input_activity: Vec<DiffActivity>,
9393
}
9494

95+
impl AutoDiffAttrs {
96+
pub fn has_primal_ret(&self) -> bool {
97+
matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
98+
}
99+
}
100+
95101
impl DiffMode {
96102
pub fn is_rev(&self) -> bool {
97103
matches!(self, DiffMode::Reverse)

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+22-2
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,23 @@ fn compute_enzyme_fn_ty<'ll>(
201201
}
202202

203203
if attrs.width == 1 {
204-
todo!("Handle sret for scalar ad");
204+
// Enzyme returns a struct of style:
205+
// `{ original_ret(if requested), float, float, ... }`
206+
let mut struct_elements = vec![];
207+
if attrs.has_primal_ret() {
208+
struct_elements.push(inner_ret_ty);
209+
}
210+
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
211+
// and therefore part of the return struct.
212+
let param_tys = cx.func_params_types(fn_ty);
213+
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
214+
if matches!(act, DiffActivity::Active) {
215+
// Now find the float type at position i based on the fn_ty,
216+
// to know what (f16/f32/f64/...) to add to the struct.
217+
struct_elements.push(param_ty);
218+
}
219+
}
220+
ret_ty = cx.type_struct(&struct_elements, false);
205221
} else {
206222
// First we check if we also have to deal with the primal return.
207223
match attrs.mode {
@@ -388,7 +404,11 @@ fn generate_enzyme_call<'ll>(
388404
// now store the result of the enzyme call into the sret pointer.
389405
let sret_ptr = outer_args[0];
390406
let call_ty = cx.val_ty(call);
391-
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
407+
if attrs.width == 1 {
408+
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
409+
} else {
410+
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
411+
}
392412
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
393413
}
394414
builder.ret_void();
File renamed without changes.
File renamed without changes.

tests/codegen/autodiff/sret.rs

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
5+
// This test is almost identical to the scalar.rs one,
6+
// but we intentionally add a few more floats.
7+
// `df` would ret `{ f64, f32, f32 }`, but is lowered as an sret.
8+
// We therefore use this test to verify some of our sret handling.
9+
10+
#![feature(autodiff)]
11+
12+
use std::autodiff::autodiff;
13+
14+
#[no_mangle]
15+
#[autodiff(df, Reverse, Active, Active, Active)]
16+
fn primal(x: f32, y: f32) -> f64 {
17+
(x * x * y) as f64
18+
}
19+
20+
// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y)
21+
// CHECK-NEXT:start:
22+
// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y)
23+
// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0
24+
// CHECK-NEXT: store double %.elt, ptr %_0, align 8
25+
// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8
26+
// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1
27+
// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8
28+
// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12
29+
// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2
30+
// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4
31+
// CHECK-NEXT: ret void
32+
// CHECK-NEXT:}
33+
34+
fn main() {
35+
let x = std::hint::black_box(3.0);
36+
let y = std::hint::black_box(2.5);
37+
let scalar = std::hint::black_box(1.0);
38+
let (r1, r2, r3) = df(x, y, scalar);
39+
// 3*3*1.5 = 22.5
40+
assert_eq!(r1, 22.5);
41+
// 2*x*y = 2*3*2.5 = 15.0
42+
assert_eq!(r2, 15.0);
43+
// x*x*1 = 3*3 = 9
44+
assert_eq!(r3, 9.0);
45+
}

0 commit comments

Comments
 (0)