@@ -584,12 +584,10 @@ fn thin_lto(
584
584
}
585
585
}
586
586
587
- fn enable_autodiff_settings ( ad : & [ config:: AutoDiff ] , module : & mut ModuleCodegen < ModuleLlvm > ) {
587
+ fn enable_autodiff_settings ( ad : & [ config:: AutoDiff ] ) {
588
588
for & val in ad {
589
+ // We intentionally don't use a wildcard, to not forget handling anything new.
589
590
match val {
590
- config:: AutoDiff :: PrintModBefore => {
591
- unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
592
- }
593
591
config:: AutoDiff :: PrintPerf => {
594
592
llvm:: set_print_perf ( true ) ;
595
593
}
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
603
601
llvm:: set_inline ( true ) ;
604
602
}
605
603
config:: AutoDiff :: LooseTypes => {
606
- llvm:: set_loose_types ( false ) ;
604
+ llvm:: set_loose_types ( true ) ;
607
605
}
608
606
config:: AutoDiff :: PrintSteps => {
609
607
llvm:: set_print ( true ) ;
610
608
}
611
- // We handle this below
609
+ // We handle this in the PassWrapper.cpp
610
+ config:: AutoDiff :: PrintPasses => { }
611
+ // We handle this in the PassWrapper.cpp
612
+ config:: AutoDiff :: PrintModBefore => { }
613
+ // We handle this in the PassWrapper.cpp
612
614
config:: AutoDiff :: PrintModAfter => { }
613
- // We handle this below
615
+ // We handle this in the PassWrapper.cpp
614
616
config:: AutoDiff :: PrintModFinal => { }
615
617
// This is required and already checked
616
618
config:: AutoDiff :: Enable => { }
619
+ // We handle this below
620
+ config:: AutoDiff :: NoPostopt => { }
617
621
}
618
622
}
619
623
// This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
647
651
// We then run the llvm_optimize function a second time, to optimize the code which we generated
648
652
// in the enzyme differentiation pass.
649
653
let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
650
- let stage =
651
- if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD } ;
654
+ let stage = if thin {
655
+ write:: AutodiffStage :: PreAD
656
+ } else {
657
+ if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD }
658
+ } ;
652
659
653
660
if enable_ad {
654
- enable_autodiff_settings ( & config. autodiff , module ) ;
661
+ enable_autodiff_settings ( & config. autodiff ) ;
655
662
}
656
663
657
664
unsafe {
658
665
write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
659
666
}
660
667
661
- if cfg ! ( llvm_enzyme) && enable_ad {
662
- // This is the post-autodiff IR, mainly used for testing and educational purposes.
663
- if config. autodiff . contains ( & config:: AutoDiff :: PrintModAfter ) {
664
- unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
665
- }
666
-
668
+ if cfg ! ( llvm_enzyme) && enable_ad && !thin {
667
669
let opt_stage = llvm:: OptStage :: FatLTO ;
668
670
let stage = write:: AutodiffStage :: PostAD ;
669
- unsafe {
670
- write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
671
+ if !config. autodiff . contains ( & config:: AutoDiff :: NoPostopt ) {
672
+ unsafe {
673
+ write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
674
+ }
671
675
}
672
676
673
677
// This is the final IR, so people should be able to inspect the optimized autodiff output,
0 commit comments