Skip to content

Commit 8726c99

Browse files
committed
only include rustc_codegen_llvm autodiff changes
1 parent fb4aebd commit 8726c99

File tree

24 files changed

+2283
-27
lines changed

24 files changed

+2283
-27
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2+
/// we create an `AutoDiffItem` which contains the source and target function names. The source
3+
/// is the function to which the autodiff attribute is applied, and the target is the function
4+
/// getting generated by us (with a name given by the user as the first autodiff arg).
5+
use crate::expand::typetree::TypeTree;
6+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
7+
8+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
9+
pub enum DiffMode {
10+
/// No autodiff is applied (usually used during error handling).
11+
Inactive,
12+
/// The primal function which we will differentiate.
13+
Source,
14+
/// The target function, to be created using forward mode AD.
15+
Forward,
16+
/// The target function, to be created using reverse mode AD.
17+
Reverse,
18+
/// The target function, to be created using forward mode AD.
19+
/// This target function will also be used as a source for higher order derivatives,
20+
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
21+
ForwardFirst,
22+
/// The target function, to be created using reverse mode AD.
23+
/// This target function will also be used as a source for higher order derivatives,
24+
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
25+
ReverseFirst,
26+
}
27+
28+
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
29+
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
30+
/// we add to the previous shadow value. To not surprise users, we picked different names.
31+
/// Dual numbers is also a quite well known name for forward mode AD types.
32+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
33+
pub enum DiffActivity {
34+
/// Implicit or Explicit () return type, so a special case of Const.
35+
None,
36+
/// Don't compute derivatives with respect to this input/output.
37+
Const,
38+
/// Reverse Mode, Compute derivatives for this scalar input/output.
39+
Active,
40+
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
41+
/// the original return value.
42+
ActiveOnly,
43+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
44+
/// with it.
45+
Dual,
46+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
47+
/// with it. Drop the code which updates the original input/output for maximum performance.
48+
DualOnly,
49+
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
50+
Duplicated,
51+
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
52+
/// Drop the code which updates the original input for maximum performance.
53+
DuplicatedOnly,
54+
/// All Integers must be Const, but these are used to mark the integer which represents the
55+
/// length of a slice/vec. This is used for safety checks on slices.
56+
FakeActivitySize,
57+
}
58+
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
59+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
60+
pub struct AutoDiffItem {
61+
/// The name of the function getting differentiated
62+
pub source: String,
63+
/// The name of the function being generated
64+
pub target: String,
65+
pub attrs: AutoDiffAttrs,
66+
/// Despribe the memory layout of input types
67+
pub inputs: Vec<TypeTree>,
68+
/// Despribe the memory layout of the output type
69+
pub output: TypeTree,
70+
}
71+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
72+
pub struct AutoDiffAttrs {
73+
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
74+
/// e.g. in the [JAX
75+
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
76+
pub mode: DiffMode,
77+
pub ret_activity: DiffActivity,
78+
pub input_activity: Vec<DiffActivity>,
79+
}

compiler/rustc_ast/src/expand/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
77
use crate::MetaItem;
88

99
pub mod allocator;
10+
pub mod autodiff_attrs;
11+
pub mod typetree;
1012

1113
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
1214
pub struct StrippedCfgItem<ModId = DefId> {
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
use std::fmt;
2+
3+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
4+
5+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
6+
pub enum Kind {
7+
Anything,
8+
Integer,
9+
Pointer,
10+
Half,
11+
Float,
12+
Double,
13+
Unknown,
14+
}
15+
16+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
17+
pub struct TypeTree(pub Vec<Type>);
18+
19+
impl TypeTree {
20+
pub fn new() -> Self {
21+
Self(Vec::new())
22+
}
23+
pub fn all_ints() -> Self {
24+
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
25+
}
26+
pub fn int(size: usize) -> Self {
27+
let mut ints = Vec::with_capacity(size);
28+
for i in 0..size {
29+
ints.push(Type {
30+
offset: i as isize,
31+
size: 1,
32+
kind: Kind::Integer,
33+
child: TypeTree::new(),
34+
});
35+
}
36+
Self(ints)
37+
}
38+
}
39+
40+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
41+
pub struct FncTree {
42+
pub args: Vec<TypeTree>,
43+
pub ret: TypeTree,
44+
}
45+
46+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
47+
pub struct Type {
48+
pub offset: isize,
49+
pub size: usize,
50+
pub kind: Kind,
51+
pub child: TypeTree,
52+
}
53+
54+
impl Type {
55+
pub fn add_offset(self, add: isize) -> Self {
56+
let offset = match self.offset {
57+
-1 => add,
58+
x => add + x,
59+
};
60+
61+
Self { size: self.size, kind: self.kind, child: self.child, offset }
62+
}
63+
}
64+
65+
impl fmt::Display for Type {
66+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67+
<Self as fmt::Debug>::fmt(self, f)
68+
}
69+
}

compiler/rustc_codegen_llvm/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
5454
codegen_llvm_run_passes = failed to run LLVM passes
5555
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
5656
57+
codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error}
58+
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
59+
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
60+
5761
codegen_llvm_sanitizer_memtag_requires_mte =
5862
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
5963

compiler/rustc_codegen_llvm/src/attributes.rs

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
//! Set and unset common attributes on LLVM values.
22
33
use rustc_attr::{InlineAttr, InstructionSetAttr, OptimizeAttr};
4+
// FIXME(ZuseZ4): Re-enable once the middle-end is merged.
5+
//use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
46
use rustc_codegen_ssa::traits::*;
57
use rustc_hir::def_id::DefId;
68
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, PatchableFunctionEntry};
@@ -333,6 +335,8 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>(
333335
instance: ty::Instance<'tcx>,
334336
) {
335337
let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id());
338+
// FIXME(ZuseZ4): Re-enable once the middle-end is merged.
339+
//let autodiff_attrs: &AutoDiffAttrs = cx.tcx.autodiff_attrs(instance.def_id());
336340

337341
let mut to_add = SmallVec::<[_; 16]>::new();
338342

@@ -350,6 +354,9 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>(
350354
let inline =
351355
if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) {
352356
InlineAttr::Hint
357+
// FIXME(ZuseZ4): re-enable once the middle-end is merged.
358+
//} else if autodiff_attrs.is_active() {
359+
// InlineAttr::Never
353360
} else {
354361
codegen_fn_attrs.inline
355362
};

compiler/rustc_codegen_llvm/src/back/lto.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,12 @@ pub(crate) fn run_pass_manager(
616616
}
617617
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
618618
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
619-
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?;
619+
620+
// We will run this again with different values in the context of automatic differentiation.
621+
let first_run = true;
622+
let noop = false;
623+
debug!("running llvm pm opt pipeline");
624+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop)?;
620625
}
621626
debug!("lto done");
622627
Ok(())
@@ -723,7 +728,12 @@ pub(crate) unsafe fn optimize_thin_module(
723728
let llcx = unsafe { llvm::LLVMRustContextCreate(cgcx.fewer_names) };
724729
let llmod_raw = parse_module(llcx, module_name, thin_module.data(), dcx)? as *const _;
725730
let mut module = ModuleCodegen {
726-
module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) },
731+
module_llvm: ModuleLlvm {
732+
llmod_raw,
733+
llcx,
734+
tm: ManuallyDrop::new(tm),
735+
typetrees: Default::default(),
736+
},
727737
name: thin_module.name().to_string(),
728738
kind: ModuleKind::Regular,
729739
};

0 commit comments

Comments
 (0)