|
| 1 | +/// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute, |
| 2 | +/// we create an `AutoDiffItem` which contains the source and target function names. The source |
| 3 | +/// is the function to which the autodiff attribute is applied, and the target is the function |
| 4 | +/// getting generated by us (with a name given by the user as the first autodiff arg). |
| 5 | +use crate::expand::typetree::TypeTree; |
| 6 | +use crate::expand::{Decodable, Encodable, HashStable_Generic}; |
| 7 | + |
| 8 | +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] |
| 9 | +pub enum DiffMode { |
| 10 | + /// No autodiff is applied (usually used during error handling). |
| 11 | + Inactive, |
| 12 | + /// The primal function which we will differentiate. |
| 13 | + Source, |
| 14 | + /// The target function, to be created using forward mode AD. |
| 15 | + Forward, |
| 16 | + /// The target function, to be created using reverse mode AD. |
| 17 | + Reverse, |
| 18 | + /// The target function, to be created using forward mode AD. |
| 19 | + /// This target function will also be used as a source for higher order derivatives, |
| 20 | + /// so compute it before all Forward/Reverse targets and optimize it through llvm. |
| 21 | + ForwardFirst, |
| 22 | + /// The target function, to be created using reverse mode AD. |
| 23 | + /// This target function will also be used as a source for higher order derivatives, |
| 24 | + /// so compute it before all Forward/Reverse targets and optimize it through llvm. |
| 25 | + ReverseFirst, |
| 26 | +} |
| 27 | + |
| 28 | +/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. |
| 29 | +/// However, under forward mode we overwrite the previous shadow value, while for reverse mode |
| 30 | +/// we add to the previous shadow value. To not surprise users, we picked different names. |
| 31 | +/// Dual numbers is also a quite well known name for forward mode AD types. |
| 32 | +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] |
| 33 | +pub enum DiffActivity { |
| 34 | + /// Implicit or Explicit () return type, so a special case of Const. |
| 35 | + None, |
| 36 | + /// Don't compute derivatives with respect to this input/output. |
| 37 | + Const, |
| 38 | + /// Reverse Mode, Compute derivatives for this scalar input/output. |
| 39 | + Active, |
| 40 | + /// Reverse Mode, Compute derivatives for this scalar output, but don't compute |
| 41 | + /// the original return value. |
| 42 | + ActiveOnly, |
| 43 | + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument |
| 44 | + /// with it. |
| 45 | + Dual, |
| 46 | + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument |
| 47 | + /// with it. Drop the code which updates the original input/output for maximum performance. |
| 48 | + DualOnly, |
| 49 | + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. |
| 50 | + Duplicated, |
| 51 | + /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. |
| 52 | + /// Drop the code which updates the original input for maximum performance. |
| 53 | + DuplicatedOnly, |
| 54 | + /// All Integers must be Const, but these are used to mark the integer which represents the |
| 55 | + /// length of a slice/vec. This is used for safety checks on slices. |
| 56 | + FakeActivitySize, |
| 57 | +} |
| 58 | +/// We generate one of these structs for each `#[autodiff(...)]` attribute. |
| 59 | +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] |
| 60 | +pub struct AutoDiffItem { |
| 61 | + /// The name of the function getting differentiated |
| 62 | + pub source: String, |
| 63 | + /// The name of the function being generated |
| 64 | + pub target: String, |
| 65 | + pub attrs: AutoDiffAttrs, |
| 66 | + /// Despribe the memory layout of input types |
| 67 | + pub inputs: Vec<TypeTree>, |
| 68 | + /// Despribe the memory layout of the output type |
| 69 | + pub output: TypeTree, |
| 70 | +} |
| 71 | +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] |
| 72 | +pub struct AutoDiffAttrs { |
| 73 | + /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and |
| 74 | + /// e.g. in the [JAX |
| 75 | + /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). |
| 76 | + pub mode: DiffMode, |
| 77 | + pub ret_activity: DiffActivity, |
| 78 | + pub input_activity: Vec<DiffActivity>, |
| 79 | +} |
0 commit comments