Skip to content

Commit 2f84a57

Browse files
committed
add batching frontend
1 parent 687a6a7 commit 2f84a57

File tree

14 files changed

+817
-2
lines changed

14 files changed

+817
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2+
//! we create an [`BatchItem`] which contains the source and target function names. The source
3+
//! is the function to which the autodiff attribute is applied, and the target is the function
4+
//! getting generated by us (with a name given by the user as the first autodiff arg).
5+
6+
use std::fmt::{self, Display, Formatter};
7+
use std::str::FromStr;
8+
9+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
10+
11+
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
12+
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
13+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
14+
pub enum BatchMode {
15+
/// No vectorization is applied (used during error handling).
16+
Error,
17+
/// The primal function which we will vectorize.
18+
Source,
19+
}
20+
21+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
22+
pub enum BatchActivity {
23+
/// Don't batch this argument
24+
Const,
25+
/// Unsafe.
26+
Leaf,
27+
/// Just receive this argument N times.
28+
Vector,
29+
}
30+
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
31+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
32+
pub struct BatchItem {
33+
/// The name of the function getting differentiated
34+
pub source: String,
35+
/// The name of the function being generated
36+
pub target: String,
37+
pub attrs: BatchAttrs,
38+
}
39+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
40+
pub struct BatchAttrs {
41+
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
42+
/// e.g. in the [JAX
43+
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
44+
pub mode: BatchMode,
45+
pub width: usize,
46+
pub input_activity: Vec<BatchActivity>,
47+
}
48+
49+
impl Display for BatchMode {
50+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
51+
match self {
52+
BatchMode::Error => write!(f, "Error"),
53+
BatchMode::Source => write!(f, "Source"),
54+
}
55+
}
56+
}
57+
58+
impl Display for BatchActivity {
59+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60+
match self {
61+
BatchActivity::Const => write!(f, "Const"),
62+
BatchActivity::Leaf => write!(f, "Leaf"),
63+
BatchActivity::Vector => write!(f, "Vector"),
64+
}
65+
}
66+
}
67+
68+
impl FromStr for BatchMode {
69+
type Err = ();
70+
71+
fn from_str(s: &str) -> Result<BatchMode, ()> {
72+
match s {
73+
"Error" => Ok(BatchMode::Error),
74+
"Source" => Ok(BatchMode::Source),
75+
_ => Err(()),
76+
}
77+
}
78+
}
79+
impl FromStr for BatchActivity {
80+
type Err = ();
81+
82+
fn from_str(s: &str) -> Result<BatchActivity, ()> {
83+
match s {
84+
"Const" => Ok(BatchActivity::Const),
85+
"Leaf" => Ok(BatchActivity::Leaf),
86+
"Vector" => Ok(BatchActivity::Vector),
87+
_ => Err(()),
88+
}
89+
}
90+
}
91+
92+
impl BatchAttrs {
93+
94+
pub fn error() -> Self {
95+
BatchAttrs {
96+
mode: BatchMode::Error,
97+
width: 0,
98+
input_activity: Vec::new(),
99+
}
100+
}
101+
pub fn source() -> Self {
102+
BatchAttrs {
103+
mode: BatchMode::Source,
104+
width: 0,
105+
input_activity: Vec::new(),
106+
}
107+
}
108+
109+
pub fn is_active(&self) -> bool {
110+
self.mode != BatchMode::Error
111+
}
112+
113+
pub fn is_source(&self) -> bool {
114+
self.mode == BatchMode::Source
115+
}
116+
pub fn apply_batch(&self) -> bool {
117+
!matches!(self.mode, BatchMode::Error | BatchMode::Source)
118+
}
119+
120+
pub fn into_item(self, source: String, target: String) -> BatchItem {
121+
BatchItem { source, target, attrs: self }
122+
}
123+
}
124+
125+
impl fmt::Display for BatchItem {
126+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127+
write!(f, "Batching {} -> {}", self.source, self.target)?;
128+
write!(f, " with attributes: {:?}", self.attrs)
129+
}
130+
}

compiler/rustc_ast/src/expand/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::MetaItem;
88

99
pub mod allocator;
1010
pub mod autodiff_attrs;
11+
pub mod batch_attrs;
1112
pub mod typetree;
1213

1314
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]

compiler/rustc_builtin_macros/messages.ftl

+11-2
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,24 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as
6969
builtin_macros_assert_requires_expression = macro requires an expression as an argument
7070
.suggestion = try removing semicolon
7171
72-
builtin_macros_autodiff = autodiff must be applied to function
72+
builtin_macros_batch = batch must be applied to a function
73+
builtin_macros_batch_missing_config = batch requires at least a name and mode
74+
builtin_macros_batch_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
75+
builtin_macros_batch_mode_activity = {$act} can not be used in {$mode} Mode
76+
builtin_macros_batch_not_build = this rustc version does not support batch
77+
builtin_macros_batch_number_activities = expected {$expected} activities, but found {$found}
78+
builtin_macros_batch_ty_activity = {$act} can not be used for this type
79+
builtin_macros_batch_unknown_activity = did not recognize Activity: `{$act}`
80+
81+
builtin_macros_autodiff = autodiff must be applied to a function
7382
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
7483
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
7584
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
7685
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
7786
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
7887
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
79-
8088
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
89+
8190
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
8291
.label = not applicable here
8392
.label2 = not a `struct`, `enum` or `union`

0 commit comments

Comments
 (0)