Skip to content

Commit 273c2b5

Browse files
committed
use propper rustc error handler for type/act check
1 parent f8c263b commit 273c2b5

File tree

4 files changed

+41
-24
lines changed

4 files changed

+41
-24
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::str::FromStr;
33
use thin_vec::ThinVec;
44
use std::fmt::{Display, Formatter};
55
use crate::NestedMetaItem;
6+
use crate::ptr::P;
7+
use crate::{Ty, TyKind};
68

79
#[allow(dead_code)]
810
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -40,7 +42,29 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
4042
}
4143
}
4244
}
43-
45+
fn is_ptr_or_ref(ty: &Ty) -> bool {
46+
match ty.kind {
47+
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
48+
_ => false,
49+
}
50+
}
51+
// TODO We should make this more robust to also
52+
// accept aliases of f32 and f64
53+
//fn is_float(ty: &Ty) -> bool {
54+
// false
55+
//}
56+
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
57+
if is_ptr_or_ref(ty) {
58+
return activity == DiffActivity::Dual || activity == DiffActivity::DualOnly ||
59+
activity == DiffActivity::Duplicated || activity == DiffActivity::DuplicatedOnly ||
60+
activity == DiffActivity::Const;
61+
}
62+
true
63+
//if is_scalar_ty(&ty) {
64+
// return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
65+
// activity == DiffActivity::Const;
66+
//}
67+
}
4468
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
4569
return match mode {
4670
DiffMode::Inactive => false,

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ builtin_macros_alloc_must_statics = allocators must be statics
44
builtin_macros_autodiff = autodiff must be applied to function
55
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
66
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
7+
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
78
89
builtin_macros_asm_clobber_abi = clobber_abi
910
builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//use crate::util::check_autodiff;
44

55
use crate::errors;
6-
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity};
6+
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity};
77
use rustc_ast::ptr::P;
88
use rustc_ast::token::{Token, TokenKind};
99
use rustc_ast::tokenstream::*;
@@ -175,25 +175,6 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
175175
ty
176176
}
177177

178-
// TODO We should make this more robust to also
179-
// accept aliases of f32 and f64
180-
#[cfg(llvm_enzyme)]
181-
fn is_float(ty: &ast::Ty) -> bool {
182-
match ty.kind {
183-
TyKind::Path(_, ref path) => {
184-
let last = path.segments.last().unwrap();
185-
last.ident.name == sym::f32 || last.ident.name == sym::f64
186-
}
187-
_ => false,
188-
}
189-
}
190-
#[cfg(llvm_enzyme)]
191-
fn is_ptr_or_ref(ty: &ast::Ty) -> bool {
192-
match ty.kind {
193-
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
194-
_ => false,
195-
}
196-
}
197178

198179
// The body of our generated functions will consist of two black_Box calls.
199180
// The first will call the primal function with the original arguments.
@@ -277,7 +258,7 @@ fn gen_primal_call(
277258
// zero-initialized by Enzyme). Active arguments are not handled yet.
278259
// Each argument of the primal function (and the return type if existing) must be annotated with an
279260
// activity.
280-
#[cfg(llvm_enzyme)]
261+
//#[cfg(llvm_enzyme)]
281262
fn gen_enzyme_decl(
282263
ecx: &ExtCtxt<'_>,
283264
sig: &ast::FnSig,
@@ -301,13 +282,17 @@ fn gen_enzyme_decl(
301282
act: activity.to_string()
302283
});
303284
}
285+
if !valid_ty_for_activity(&arg.ty, *activity) {
286+
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidTypeForActivity {
287+
span: arg.ty.span,
288+
act: activity.to_string()
289+
});
290+
}
304291
match activity {
305292
DiffActivity::Active => {
306-
assert!(is_float(&arg.ty));
307293
act_ret.push(arg.ty.clone());
308294
}
309295
DiffActivity::Duplicated => {
310-
assert!(is_ptr_or_ref(&arg.ty));
311296
let mut shadow_arg = arg.clone();
312297
// We += into the shadow in reverse mode.
313298
shadow_arg.ty = P(assure_mut_ref(&arg.ty));

compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ pub(crate) struct AllocMustStatics {
163163
#[primary_span]
164164
pub(crate) span: Span,
165165
}
166+
#[derive(Diagnostic)]
167+
#[diag(builtin_macros_autodiff_ty_activity)]
168+
pub(crate) struct AutoDiffInvalidTypeForActivity {
169+
#[primary_span]
170+
pub(crate) span: Span,
171+
pub(crate) act: String,
172+
}
166173

167174
#[derive(Diagnostic)]
168175
#[diag(builtin_macros_autodiff_mode_activity)]

0 commit comments

Comments
 (0)