@@ -586,6 +586,42 @@ fn thin_lto(
586
586
}
587
587
}
588
588
589
+ fn enable_autodiff_settings ( ad : & [ config:: AutoDiff ] , module : & mut ModuleCodegen < ModuleLlvm > ) {
590
+ for & val in ad {
591
+ match val {
592
+ config:: AutoDiff :: PrintModBefore => {
593
+ unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
594
+ }
595
+ config:: AutoDiff :: PrintPerf => {
596
+ llvm:: set_print_perf ( true ) ;
597
+ }
598
+ config:: AutoDiff :: PrintAA => {
599
+ llvm:: set_print_activity ( true ) ;
600
+ }
601
+ config:: AutoDiff :: PrintTA => {
602
+ llvm:: set_print_type ( true ) ;
603
+ }
604
+ config:: AutoDiff :: Inline => {
605
+ llvm:: set_inline ( true ) ;
606
+ }
607
+ config:: AutoDiff :: LooseTypes => {
608
+ llvm:: set_loose_types ( false ) ;
609
+ }
610
+ config:: AutoDiff :: PrintSteps => {
611
+ llvm:: set_print ( true ) ;
612
+ }
613
+ // We handle this below
614
+ config:: AutoDiff :: PrintModAfter => { }
615
+ // This is required and already checked
616
+ config:: AutoDiff :: Enable => { }
617
+ }
618
+ }
619
+ // This helps with handling enums for now.
620
+ llvm:: set_strict_aliasing ( false ) ;
621
+ // FIXME(ZuseZ4): Test this, since it was added a long time ago.
622
+ llvm:: set_rust_rules ( true ) ;
623
+ }
624
+
589
625
pub ( crate ) fn run_pass_manager (
590
626
cgcx : & CodegenContext < LlvmCodegenBackend > ,
591
627
dcx : DiagCtxtHandle < ' _ > ,
@@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager(
604
640
let opt_stage = if thin { llvm:: OptStage :: ThinLTO } else { llvm:: OptStage :: FatLTO } ;
605
641
let opt_level = config. opt_level . unwrap_or ( config:: OptLevel :: No ) ;
606
642
607
- // If this rustc version was build with enzyme/autodiff enabled, and if users applied the
608
- // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609
- debug ! ( "running llvm pm opt pipeline" ) ;
643
+ // The PostAD behavior is the same that we would have if no autodiff was used.
644
+ // It will run the default optimization pipeline. If AD is enabled we select
645
+ // the DuringAD stage, which will disable vectorization and loop unrolling, and
646
+ // schedule two autodiff optimization + differentiation passes.
647
+ // We then run the llvm_optimize function a second time, to optimize the code which we generated
648
+ // in the enzyme differentiation pass.
649
+ let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
650
+ let stage =
651
+ if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD } ;
652
+
653
+ if enable_ad {
654
+ enable_autodiff_settings ( & config. autodiff , module) ;
655
+ }
656
+
610
657
unsafe {
611
- write:: llvm_optimize (
612
- cgcx,
613
- dcx,
614
- module,
615
- config,
616
- opt_level,
617
- opt_stage,
618
- write:: AutodiffStage :: DuringAD ,
619
- ) ?;
658
+ write:: llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, stage) ?;
620
659
}
621
- // FIXME(ZuseZ4): Make this more granular
622
- if cfg ! ( llvm_enzyme) && !thin {
660
+
661
+ if cfg ! ( llvm_enzyme) && enable_ad {
662
+ let opt_stage = llvm:: OptStage :: FatLTO ;
663
+ let stage = write:: AutodiffStage :: PostAD ;
623
664
unsafe {
624
- write:: llvm_optimize (
625
- cgcx,
626
- dcx,
627
- module,
628
- config,
629
- opt_level,
630
- llvm:: OptStage :: FatLTO ,
631
- write:: AutodiffStage :: PostAD ,
632
- ) ?;
665
+ write:: llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, stage) ?;
666
+ }
667
+
668
+ // This is the final IR, so people should be able to inspect the optimized autodiff output.
669
+ if config. autodiff . contains ( & config:: AutoDiff :: PrintModAfter ) {
670
+ unsafe { llvm:: LLVMDumpModule ( module. module_llvm . llmod ( ) ) } ;
633
671
}
634
672
}
673
+
635
674
debug ! ( "lto done" ) ;
636
675
Ok ( ( ) )
637
676
}
0 commit comments