Skip to content

Commit 04d4978

Browse files
committed
Single commit with all changes
1 parent 730d5d4 commit 04d4978

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+4194
-56
lines changed

.gitmodules

+4
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,7 @@
4747
path = src/tools/rustc-perf
4848
url = https://github.com/rust-lang/rustc-perf.git
4949
shallow = true
50+
[submodule "src/tools/enzyme"]
51+
path = src/tools/enzyme
52+
url = [email protected]:EnzymeAD/Enzyme.git
53+
shallow = true

Cargo.lock

+2
Original file line numberDiff line numberDiff line change
@@ -4148,6 +4148,7 @@ dependencies = [
41484148
name = "rustc_monomorphize"
41494149
version = "0.0.0"
41504150
dependencies = [
4151+
"rustc_ast",
41514152
"rustc_data_structures",
41524153
"rustc_errors",
41534154
"rustc_fluent_macro",
@@ -4156,6 +4157,7 @@ dependencies = [
41564157
"rustc_middle",
41574158
"rustc_session",
41584159
"rustc_span",
4160+
"rustc_symbol_mangling",
41594161
"rustc_target",
41604162
"serde",
41614163
"serde_json",

compiler/rustc_ast/src/ast.rs

+7
Original file line numberDiff line numberDiff line change
@@ -2729,6 +2729,13 @@ impl FnRetTy {
27292729
FnRetTy::Ty(ty) => ty.span,
27302730
}
27312731
}
2732+
2733+
pub fn has_ret(&self) -> bool {
2734+
match self {
2735+
FnRetTy::Default(_) => false,
2736+
FnRetTy::Ty(_) => true,
2737+
}
2738+
}
27322739
}
27332740

27342741
#[derive(Clone, Copy, PartialEq, Encodable, Decodable, Debug)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
use std::fmt::{self, Display, Formatter};
2+
use std::str::FromStr;
3+
4+
use crate::expand::typetree::TypeTree;
5+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
6+
use crate::ptr::P;
7+
use crate::{Ty, TyKind};
8+
9+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
10+
pub enum DiffMode {
11+
Inactive,
12+
Source,
13+
Forward,
14+
Reverse,
15+
ForwardFirst,
16+
ReverseFirst,
17+
}
18+
19+
pub fn is_rev(mode: DiffMode) -> bool {
20+
match mode {
21+
DiffMode::Reverse | DiffMode::ReverseFirst => true,
22+
_ => false,
23+
}
24+
}
25+
pub fn is_fwd(mode: DiffMode) -> bool {
26+
match mode {
27+
DiffMode::Forward | DiffMode::ForwardFirst => true,
28+
_ => false,
29+
}
30+
}
31+
32+
impl Display for DiffMode {
33+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
34+
match self {
35+
DiffMode::Inactive => write!(f, "Inactive"),
36+
DiffMode::Source => write!(f, "Source"),
37+
DiffMode::Forward => write!(f, "Forward"),
38+
DiffMode::Reverse => write!(f, "Reverse"),
39+
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
40+
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
41+
}
42+
}
43+
}
44+
45+
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
46+
if activity == DiffActivity::None {
47+
// Only valid if primal returns (), but we can't check that here.
48+
return true;
49+
}
50+
match mode {
51+
DiffMode::Inactive => false,
52+
DiffMode::Source => false,
53+
DiffMode::Forward | DiffMode::ForwardFirst => {
54+
activity == DiffActivity::Dual
55+
|| activity == DiffActivity::DualOnly
56+
|| activity == DiffActivity::Const
57+
}
58+
DiffMode::Reverse | DiffMode::ReverseFirst => {
59+
activity == DiffActivity::Const
60+
|| activity == DiffActivity::Active
61+
|| activity == DiffActivity::ActiveOnly
62+
}
63+
}
64+
}
65+
fn is_ptr_or_ref(ty: &Ty) -> bool {
66+
match ty.kind {
67+
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
68+
_ => false,
69+
}
70+
}
71+
// TODO We should make this more robust to also
72+
// accept aliases of f32 and f64
73+
//fn is_float(ty: &Ty) -> bool {
74+
// false
75+
//}
76+
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
77+
if is_ptr_or_ref(ty) {
78+
return activity == DiffActivity::Dual
79+
|| activity == DiffActivity::DualOnly
80+
|| activity == DiffActivity::Duplicated
81+
|| activity == DiffActivity::DuplicatedOnly
82+
|| activity == DiffActivity::Const;
83+
}
84+
true
85+
//if is_scalar_ty(&ty) {
86+
// return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
87+
// activity == DiffActivity::Const;
88+
//}
89+
}
90+
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
91+
return match mode {
92+
DiffMode::Inactive => false,
93+
DiffMode::Source => false,
94+
DiffMode::Forward | DiffMode::ForwardFirst => {
95+
// These are the only valid cases
96+
activity == DiffActivity::Dual
97+
|| activity == DiffActivity::DualOnly
98+
|| activity == DiffActivity::Const
99+
}
100+
DiffMode::Reverse | DiffMode::ReverseFirst => {
101+
// These are the only valid cases
102+
activity == DiffActivity::Active
103+
|| activity == DiffActivity::ActiveOnly
104+
|| activity == DiffActivity::Const
105+
|| activity == DiffActivity::Duplicated
106+
|| activity == DiffActivity::DuplicatedOnly
107+
}
108+
};
109+
}
110+
pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> Option<usize> {
111+
for i in 0..activity_vec.len() {
112+
if !valid_input_activity(mode, activity_vec[i]) {
113+
return Some(i);
114+
}
115+
}
116+
None
117+
}
118+
119+
#[allow(dead_code)]
120+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
121+
pub enum DiffActivity {
122+
None,
123+
Const,
124+
Active,
125+
ActiveOnly,
126+
Dual,
127+
DualOnly,
128+
Duplicated,
129+
DuplicatedOnly,
130+
FakeActivitySize,
131+
}
132+
133+
impl Display for DiffActivity {
134+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
135+
match self {
136+
DiffActivity::None => write!(f, "None"),
137+
DiffActivity::Const => write!(f, "Const"),
138+
DiffActivity::Active => write!(f, "Active"),
139+
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
140+
DiffActivity::Dual => write!(f, "Dual"),
141+
DiffActivity::DualOnly => write!(f, "DualOnly"),
142+
DiffActivity::Duplicated => write!(f, "Duplicated"),
143+
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
144+
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
145+
}
146+
}
147+
}
148+
149+
impl FromStr for DiffMode {
150+
type Err = ();
151+
152+
fn from_str(s: &str) -> Result<DiffMode, ()> {
153+
match s {
154+
"Inactive" => Ok(DiffMode::Inactive),
155+
"Source" => Ok(DiffMode::Source),
156+
"Forward" => Ok(DiffMode::Forward),
157+
"Reverse" => Ok(DiffMode::Reverse),
158+
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
159+
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
160+
_ => Err(()),
161+
}
162+
}
163+
}
164+
impl FromStr for DiffActivity {
165+
type Err = ();
166+
167+
fn from_str(s: &str) -> Result<DiffActivity, ()> {
168+
match s {
169+
"None" => Ok(DiffActivity::None),
170+
"Active" => Ok(DiffActivity::Active),
171+
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
172+
"Const" => Ok(DiffActivity::Const),
173+
"Dual" => Ok(DiffActivity::Dual),
174+
"DualOnly" => Ok(DiffActivity::DualOnly),
175+
"Duplicated" => Ok(DiffActivity::Duplicated),
176+
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
177+
_ => Err(()),
178+
}
179+
}
180+
}
181+
182+
#[allow(dead_code)]
183+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
184+
pub struct AutoDiffAttrs {
185+
pub mode: DiffMode,
186+
pub ret_activity: DiffActivity,
187+
pub input_activity: Vec<DiffActivity>,
188+
}
189+
190+
impl AutoDiffAttrs {
191+
pub fn has_ret_activity(&self) -> bool {
192+
match self.ret_activity {
193+
DiffActivity::None => false,
194+
_ => true,
195+
}
196+
}
197+
pub fn has_active_only_ret(&self) -> bool {
198+
match self.ret_activity {
199+
DiffActivity::ActiveOnly => true,
200+
_ => false,
201+
}
202+
}
203+
}
204+
205+
impl AutoDiffAttrs {
206+
pub fn inactive() -> Self {
207+
AutoDiffAttrs {
208+
mode: DiffMode::Inactive,
209+
ret_activity: DiffActivity::None,
210+
input_activity: Vec::new(),
211+
}
212+
}
213+
pub fn source() -> Self {
214+
AutoDiffAttrs {
215+
mode: DiffMode::Source,
216+
ret_activity: DiffActivity::None,
217+
input_activity: Vec::new(),
218+
}
219+
}
220+
221+
pub fn is_active(&self) -> bool {
222+
match self.mode {
223+
DiffMode::Inactive => false,
224+
_ => true,
225+
}
226+
}
227+
228+
pub fn is_source(&self) -> bool {
229+
match self.mode {
230+
DiffMode::Source => true,
231+
_ => false,
232+
}
233+
}
234+
pub fn apply_autodiff(&self) -> bool {
235+
match self.mode {
236+
DiffMode::Inactive => false,
237+
DiffMode::Source => false,
238+
_ => true,
239+
}
240+
}
241+
242+
pub fn into_item(
243+
self,
244+
source: String,
245+
target: String,
246+
inputs: Vec<TypeTree>,
247+
output: TypeTree,
248+
) -> AutoDiffItem {
249+
AutoDiffItem { source, target, inputs, output, attrs: self }
250+
}
251+
}
252+
253+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
254+
pub struct AutoDiffItem {
255+
pub source: String,
256+
pub target: String,
257+
pub attrs: AutoDiffAttrs,
258+
pub inputs: Vec<TypeTree>,
259+
pub output: TypeTree,
260+
}
261+
262+
impl fmt::Display for AutoDiffItem {
263+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264+
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
265+
write!(f, " with attributes: {:?}", self.attrs)?;
266+
write!(f, " with inputs: {:?}", self.inputs)?;
267+
write!(f, " with output: {:?}", self.output)
268+
}
269+
}

compiler/rustc_ast/src/expand/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
77
use crate::MetaItem;
88

99
pub mod allocator;
10+
pub mod autodiff_attrs;
11+
pub mod typetree;
1012

1113
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
1214
pub struct StrippedCfgItem<ModId = DefId> {
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
use std::fmt;
2+
3+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
4+
5+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
6+
pub enum Kind {
7+
Anything,
8+
Integer,
9+
Pointer,
10+
Half,
11+
Float,
12+
Double,
13+
Unknown,
14+
}
15+
16+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
17+
pub struct TypeTree(pub Vec<Type>);
18+
19+
impl TypeTree {
20+
pub fn new() -> Self {
21+
Self(Vec::new())
22+
}
23+
pub fn all_ints() -> Self {
24+
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
25+
}
26+
pub fn int(size: usize) -> Self {
27+
let mut ints = Vec::with_capacity(size);
28+
for i in 0..size {
29+
ints.push(Type {
30+
offset: i as isize,
31+
size: 1,
32+
kind: Kind::Integer,
33+
child: TypeTree::new(),
34+
});
35+
}
36+
Self(ints)
37+
}
38+
}
39+
40+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
41+
pub struct FncTree {
42+
pub args: Vec<TypeTree>,
43+
pub ret: TypeTree,
44+
}
45+
46+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
47+
pub struct Type {
48+
pub offset: isize,
49+
pub size: usize,
50+
pub kind: Kind,
51+
pub child: TypeTree,
52+
}
53+
54+
impl Type {
55+
pub fn add_offset(self, add: isize) -> Self {
56+
let offset = match self.offset {
57+
-1 => add,
58+
x => add + x,
59+
};
60+
61+
Self { size: self.size, kind: self.kind, child: self.child, offset }
62+
}
63+
}
64+
65+
impl fmt::Display for Type {
66+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67+
<Self as fmt::Debug>::fmt(self, f)
68+
}
69+
}

compiler/rustc_builtin_macros/Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
33
version = "0.0.0"
44
edition = "2021"
55

6+
7+
[lints.rust]
8+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
9+
610
[lib]
711
doctest = false
812

0 commit comments

Comments
 (0)