@@ -4,10 +4,9 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
4
4
use rustc_codegen_ssa:: ModuleCodegen ;
5
5
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6
6
use rustc_errors:: FatalError ;
7
- use rustc_session:: config:: Lto ;
8
7
use tracing:: { debug, trace} ;
9
8
10
- use crate :: back:: write:: { llvm_err, llvm_optimize } ;
9
+ use crate :: back:: write:: llvm_err;
11
10
use crate :: builder:: SBuilder ;
12
11
use crate :: context:: SimpleCx ;
13
12
use crate :: declare:: declare_simple_fn;
@@ -53,8 +52,6 @@ fn generate_enzyme_call<'ll>(
53
52
let mut ad_name: String = match attrs. mode {
54
53
DiffMode :: Forward => "__enzyme_fwddiff" ,
55
54
DiffMode :: Reverse => "__enzyme_autodiff" ,
56
- DiffMode :: ForwardFirst => "__enzyme_fwddiff" ,
57
- DiffMode :: ReverseFirst => "__enzyme_autodiff" ,
58
55
_ => panic ! ( "logic bug in autodiff, unrecognized mode" ) ,
59
56
}
60
57
. to_string ( ) ;
@@ -153,7 +150,7 @@ fn generate_enzyme_call<'ll>(
153
150
_ => { }
154
151
}
155
152
156
- trace ! ( "matching autodiff arguments" ) ;
153
+ debug ! ( "matching autodiff arguments" ) ;
157
154
// We now handle the issue that Rust level arguments do not always match the llvm-ir level
158
155
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
159
156
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
@@ -164,10 +161,10 @@ fn generate_enzyme_call<'ll>(
164
161
let mut activity_pos = 0 ;
165
162
let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn) ;
166
163
while activity_pos < inputs. len ( ) {
167
- let activity = inputs[ activity_pos as usize ] ;
164
+ let diff_activity = inputs[ activity_pos as usize ] ;
168
165
// Duplicated arguments received a shadow argument, into which enzyme will write the
169
166
// gradient.
170
- let ( activity, duplicated) : ( & Metadata , bool ) = match activity {
167
+ let ( activity, duplicated) : ( & Metadata , bool ) = match diff_activity {
171
168
DiffActivity :: None => panic ! ( "not a valid input activity" ) ,
172
169
DiffActivity :: Const => ( enzyme_const, false ) ,
173
170
DiffActivity :: Active => ( enzyme_out, false ) ,
@@ -222,7 +219,15 @@ fn generate_enzyme_call<'ll>(
222
219
// A duplicated pointer will have the following two outer_fn arguments:
223
220
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
224
221
// (..., metadata! enzyme_dup, ptr, ptr, ...).
225
- assert ! ( llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Pointer ) ;
222
+ if matches ! (
223
+ diff_activity,
224
+ DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly
225
+ ) {
226
+ assert ! (
227
+ llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Pointer
228
+ ) ;
229
+ }
230
+ // In the case of Dual we make no assumption about the argument's type, e.g. f32 would be valid.
226
231
args. push ( next_outer_arg) ;
227
232
outer_pos += 2 ;
228
233
activity_pos += 1 ;
@@ -277,7 +282,7 @@ pub(crate) fn differentiate<'ll>(
277
282
module : & ' ll ModuleCodegen < ModuleLlvm > ,
278
283
cgcx : & CodegenContext < LlvmCodegenBackend > ,
279
284
diff_items : Vec < AutoDiffItem > ,
280
- config : & ModuleConfig ,
285
+ _config : & ModuleConfig ,
281
286
) -> Result < ( ) , FatalError > {
282
287
for item in & diff_items {
283
288
trace ! ( "{}" , item) ;
@@ -318,29 +323,6 @@ pub(crate) fn differentiate<'ll>(
318
323
319
324
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
320
325
321
- if let Some ( opt_level) = config. opt_level {
322
- let opt_stage = match cgcx. lto {
323
- Lto :: Fat => llvm:: OptStage :: PreLinkFatLTO ,
324
- Lto :: Thin | Lto :: ThinLocal => llvm:: OptStage :: PreLinkThinLTO ,
325
- _ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
326
- _ => llvm:: OptStage :: PreLinkNoLTO ,
327
- } ;
328
- // This is our second opt call, so now we run all opts,
329
- // to make sure we get the best performance.
330
- let skip_size_increasing_opts = false ;
331
- trace ! ( "running Module Optimization after differentiation" ) ;
332
- unsafe {
333
- llvm_optimize (
334
- cgcx,
335
- diag_handler. handle ( ) ,
336
- module,
337
- config,
338
- opt_level,
339
- opt_stage,
340
- skip_size_increasing_opts,
341
- ) ?
342
- } ;
343
- }
344
326
trace ! ( "done with differentiate()" ) ;
345
327
346
328
Ok ( ( ) )
0 commit comments