@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::{fs, slice, str};

-use libc::{c_char, c_int, c_void, size_t};
+use libc::{c_char, c_int, c_uint, c_void, size_t};
 use llvm::{
     LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
 };
+use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
 use rustc_codegen_ssa::back::link::ensure_removed;
 use rustc_codegen_ssa::back::write::{
     BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig,
@@ -27,7 +28,7 @@ use rustc_session::config::{
 use rustc_span::InnerSpan;
 use rustc_span::symbol::sym;
 use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
-use tracing::debug;
+use tracing::{debug, trace};

 use crate::back::lto::ThinBuffer;
 use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -39,7 +40,13 @@ use crate::errors::{
     WithLlvmError, WriteBytecode,
 };
 use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
-use crate::llvm::{self, DiagnosticInfo, PassManager};
+use crate::llvm::{
+    self, AttributeKind, DiagnosticInfo, LLVMCreateStringAttribute, LLVMGetFirstFunction,
+    LLVMGetNextFunction, LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, LLVMIsStringAttribute,
+    LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex,
+    LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex,
+    LLVMRustRemoveEnumAttributeAtIndex, PassManager,
+};
 use crate::type_::Type;
 use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};

@@ -515,9 +522,34 @@ pub(crate) unsafe fn llvm_optimize(
     config: &ModuleConfig,
     opt_level: config::OptLevel,
     opt_stage: llvm::OptStage,
+    skip_size_increasing_opts: bool,
 ) -> Result<(), FatalError> {
-    let unroll_loops =
-        opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+    // Enzyme:
+    // The whole point of compiler-based AD is to differentiate optimized IR instead of
+    // unoptimized source code. However, benchmarks show that optimizations which increase the
+    // code size tend to reduce AD performance. Therefore we deactivate them before AD,
+    // differentiate the code, and finally re-optimize the module, now with all optimizations
+    // available.
+    // TODO: In a future update we could figure out how to only optimize the functions getting
+    // differentiated.
+
+    let unroll_loops;
+    let vectorize_slp;
+    let vectorize_loop;
+
+    if skip_size_increasing_opts {
+        unroll_loops = false;
+        vectorize_slp = false;
+        vectorize_loop = false;
+    } else {
+        unroll_loops =
+            opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
+        vectorize_slp = config.vectorize_slp;
+        vectorize_loop = config.vectorize_loop;
+    }
+    trace!(
+        "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
+        unroll_loops, vectorize_slp, vectorize_loop
+    );
     let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
     let pgo_gen_path = get_pgo_gen_path(config);
     let pgo_use_path = get_pgo_use_path(config);
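For orientation, the flag added above supports a two-phase flow around Enzyme. A condensed sketch, reusing only names from this diff and omitting the surrounding unsafe blocks and error handling (illustrative, not a literal excerpt of the final control flow):

    // Phase 1: optimize() passes skip_size_increasing_opts = cfg!(llvm_enzyme), so loop
    // unrolling and both vectorizers stay off while producing the IR that Enzyme will see.
    llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, /* skip */ true)?;
    // Phase 2: differentiate() prepares the source/target pairs for Enzyme, then calls
    // llvm_optimize() again with skip_size_increasing_opts = false to re-optimize in full.
    differentiate(module, cgcx, diff_items, config)?;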
@@ -581,8 +613,8 @@ pub(crate) unsafe fn llvm_optimize(
         using_thin_buffers,
         config.merge_functions,
         unroll_loops,
-        config.vectorize_slp,
-        config.vectorize_loop,
+        vectorize_slp,
+        vectorize_loop,
         config.no_builtins,
         config.emit_lifetime_markers,
         sanitizer_options.as_ref(),
@@ -605,6 +637,113 @@ pub(crate) unsafe fn llvm_optimize(
     result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
 }

+pub(crate) fn differentiate(
+    module: &ModuleCodegen<ModuleLlvm>,
+    cgcx: &CodegenContext<LlvmCodegenBackend>,
+    diff_items: Vec<AutoDiffItem>,
+    config: &ModuleConfig,
+) -> Result<(), FatalError> {
+    for item in &diff_items {
+        trace!("{}", item);
+    }
+
+    let llmod = module.module_llvm.llmod();
+    let llcx = &module.module_llvm.llcx;
+    let diag_handler = cgcx.create_dcx();
+
+    // Before dumping the module, we want all the tt to become part of the module.
+    for item in diff_items.iter() {
+        let name = CString::new(item.source.clone()).unwrap();
+        let fn_def: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
+        let fn_def = match fn_def {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find source function".to_owned(),
+                }));
+            }
+        };
+        let tgt_name = CString::new(item.target.clone()).unwrap();
+        dbg!("Target name: {:?}", &tgt_name);
+        let fn_target: Option<&llvm::Value> =
+            unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) };
+        let fn_target = match fn_target {
+            Some(x) => x,
+            None => {
+                return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
+                    src: item.source.clone(),
+                    target: item.target.clone(),
+                    error: "could not find target function".to_owned(),
+                }));
+            }
+        };
+
+        crate::builder::add_opt_dbg_helper2(llmod, llcx, fn_def, fn_target, item.attrs.clone());
+    }
+
+    // We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
+    // which Enzyme doesn't understand.
+    unsafe {
+        let mut f = LLVMGetFirstFunction(llmod);
+        loop {
+            if let Some(lf) = f {
+                f = LLVMGetNextFunction(lf);
+                let myhwattr = "enzyme_hw";
+                let attr = LLVMGetStringAttributeAtIndex(
+                    lf,
+                    c_uint::MAX,
+                    myhwattr.as_ptr() as *const c_char,
+                    myhwattr.as_bytes().len() as c_uint,
+                );
+                if LLVMIsStringAttribute(attr) {
+                    LLVMRemoveStringAttributeAtIndex(
+                        lf,
+                        c_uint::MAX,
+                        myhwattr.as_ptr() as *const c_char,
+                        myhwattr.as_bytes().len() as c_uint,
+                    );
+                } else {
+                    LLVMRustRemoveEnumAttributeAtIndex(
+                        lf,
+                        c_uint::MAX,
+                        AttributeKind::SanitizeHWAddress,
+                    );
+                }
+            } else {
+                break;
+            }
+        }
+    }
+
+    if let Some(opt_level) = config.opt_level {
+        let opt_stage = match cgcx.lto {
+            Lto::Fat => llvm::OptStage::PreLinkFatLTO,
+            Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
+            _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
+            _ => llvm::OptStage::PreLinkNoLTO,
+        };
+        let skip_size_increasing_opts = false;
+        dbg!("Running Module Optimization after differentiation");
+        unsafe {
+            llvm_optimize(
+                cgcx,
+                diag_handler.handle(),
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )?
+        };
+    }
+    dbg!("Done with differentiate()");
+
+    Ok(())
+}
+
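The loop above relies on three fields of each AutoDiffItem: source and target name two functions that must already exist in the module (either lookup failing returns a PrepareAutoDiff error), and attrs carries the autodiff configuration forwarded to add_opt_dbg_helper2. A hypothetical call site, shown only to make the data flow concrete; the real driver lives elsewhere in the codegen backend and is not part of this hunk:

    // `items` would be the AutoDiffItems collected for this module by the autodiff frontend
    // (hypothetical wiring, not part of this diff):
    // differentiate(&module, &cgcx, items, &config)?;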
 // Unsafe due to LLVM calls.
 pub(crate) unsafe fn optimize(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -627,14 +766,68 @@ pub(crate) unsafe fn optimize(
         unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
     }

+    // This code enables Enzyme to differentiate code containing Rust enums.
+    // By adding the SanitizeHWAddress attribute we prevent LLVM from optimizing away the enums,
+    // which allows Enzyme to understand why a value can be of different types in different code
+    // sections. We remove this attribute after Enzyme is done, so it does not affect the rest of
+    // the compilation.
+    #[cfg(llvm_enzyme)]
+    unsafe {
+        let mut f = LLVMGetFirstFunction(llmod);
+        loop {
+            if let Some(lf) = f {
+                f = LLVMGetNextFunction(lf);
+                let myhwattr = "enzyme_hw";
+                let myhwv = "";
+                let prevattr = LLVMRustGetEnumAttributeAtIndex(
+                    lf,
+                    c_uint::MAX,
+                    AttributeKind::SanitizeHWAddress,
+                );
+                if LLVMIsEnumAttribute(prevattr) {
+                    let attr = LLVMCreateStringAttribute(
+                        llcx,
+                        myhwattr.as_ptr() as *const c_char,
+                        myhwattr.as_bytes().len() as c_uint,
+                        myhwv.as_ptr() as *const c_char,
+                        myhwv.as_bytes().len() as c_uint,
+                    );
+                    LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1);
+                } else {
+                    LLVMRustAddEnumAttributeAtIndex(
+                        llcx,
+                        lf,
+                        c_uint::MAX,
+                        AttributeKind::SanitizeHWAddress,
+                    );
+                }
+            } else {
+                break;
+            }
+        }
+    }
+
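Read together with the removal pass inside differentiate(), the intended per-function round trip looks like this (a summary of the two unsafe blocks, not additional code in this change):

    // optimize():      SanitizeHWAddress already present -> remember that fact via the
    //                   "enzyme_hw" string attribute; otherwise add SanitizeHWAddress so
    //                   Enzyme sees enums that LLVM has not optimized away.
    // differentiate(): "enzyme_hw" marker present -> drop only the marker and keep the
    //                   pre-existing SanitizeHWAddress; otherwise drop the SanitizeHWAddress
    //                   that was added here.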
     if let Some(opt_level) = config.opt_level {
         let opt_stage = match cgcx.lto {
             Lto::Fat => llvm::OptStage::PreLinkFatLTO,
             Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
             _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
             _ => llvm::OptStage::PreLinkNoLTO,
         };
-        return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
+
+        // If we know that we will later run AD, then we disable vectorization and loop unrolling.
+        let skip_size_increasing_opts = cfg!(llvm_enzyme);
+        return unsafe {
+            llvm_optimize(
+                cgcx,
+                dcx,
+                module,
+                config,
+                opt_level,
+                opt_stage,
+                skip_size_increasing_opts,
+            )
+        };
     }
     Ok(())
 }