diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index ad4bd05a08b46..248b01a0a3374 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -69,7 +69,7 @@ use rustc_data_structures::small_c_str::SmallCStr; use rustc_errors::{DiagCtxt, FatalError, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; use rustc_middle::ty::TyCtxt; -use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath}; +use rustc_session::config::{self, AutoDiff, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath}; use rustc_session::Session; use rustc_span::symbol::sym; use rustc_span::InnerSpan; @@ -707,7 +707,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> { unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, - llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) { + llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize], ad: &[AutoDiff]) { // first, remove all calls from fnc let bb = LLVMGetFirstBasicBlock(tgt); @@ -729,12 +729,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool, let last_inst = LLVMRustGetLastInstruction(bb).unwrap(); LLVMPositionBuilderAtEnd(builder, bb); - let safety_run_checks; - if std::env::var("ENZYME_NO_SAFETY_CHECKS").is_ok() { - safety_run_checks = false; - } else { - safety_run_checks = true; - } + let safety_run_checks = !ad.contains(&AutoDiff::NoSafetyChecks); if inner_param_num == outer_param_num { call_args = outer_args; @@ -951,6 +946,7 @@ pub(crate) unsafe fn enzyme_ad( diag_handler: &DiagCtxt, item: AutoDiffItem, logic_ref: EnzymeLogicRef, + ad: &[AutoDiff], ) -> Result<(), FatalError> { let autodiff_mode = item.attrs.mode; let rust_name = item.source; @@ -1010,16 +1006,16 @@ pub(crate) unsafe fn enzyme_ad( llvm::set_strict_aliasing(false); - if std::env::var("ENZYME_PRINT_TA").is_ok() { + if ad.contains(&AutoDiff::PrintTA) { llvm::set_print_type(true); } - if std::env::var("ENZYME_PRINT_AA").is_ok() { - llvm::set_print_activity(true); + if ad.contains(&AutoDiff::PrintTA) { + llvm::set_print_type(true); } - if std::env::var("ENZYME_PRINT_PERF").is_ok() { + if ad.contains(&AutoDiff::PrintPerf) { llvm::set_print_perf(true); } - if std::env::var("ENZYME_PRINT").is_ok() { + if ad.contains(&AutoDiff::Print) { llvm::set_print(true); } @@ -1062,7 +1058,7 @@ pub(crate) unsafe fn enzyme_ad( let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); let rev_mode = item.attrs.mode == DiffMode::Reverse; - create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions); + create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions, ad); // TODO: implement drop for wrapper type? FreeTypeAnalysis(type_analysis); @@ -1087,7 +1083,9 @@ pub(crate) unsafe fn differentiate( llvm::set_strict_aliasing(false); - if std::env::var("ENZYME_LOOSE_TYPES").is_ok() { + let ad = &config.autodiff; + + if ad.contains(&AutoDiff::LooseTypes) { dbg!("Setting loose types to true"); llvm::set_loose_types(true); } @@ -1110,41 +1108,42 @@ pub(crate) unsafe fn differentiate( // trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary. // This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in // Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions? - if std::env::var("ENZYME_OPT").is_ok() { + if ad.contains(&AutoDiff::OPT) { dbg!("Enable extra debug helper to debug Enzyme through the opt plugin"); crate::builder::add_opt_dbg_helper(llmod, llcx, fn_def, item.attrs.clone(), i); } } - if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() || std::env::var("ENZYME_OPT").is_ok(){ + if ad.contains(&AutoDiff::PrintModBefore) || ad.contains(&AutoDiff::OPT) { unsafe { LLVMDumpModule(llmod); } } - if std::env::var("ENZYME_INLINE").is_ok() { + if ad.contains(&AutoDiff::Inline) { dbg!("Setting inline to true"); llvm::set_inline(true); } - if std::env::var("ENZYME_TT_DEPTH").is_ok() { - let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); - let depth = depth.parse::().unwrap(); - assert!(depth >= 1); - llvm::set_max_int_offset(depth); - } - if std::env::var("ENZYME_TT_WIDTH").is_ok() { - let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); - let width = width.parse::().unwrap(); - assert!(width >= 1); - llvm::set_max_type_offset(width); - } - - if std::env::var("ENZYME_RUNTIME_ACTIVITY").is_ok() { + if ad.contains(&AutoDiff::RuntimeActivity) { dbg!("Setting runtime activity check to true"); llvm::set_runtime_activity_check(true); } + for val in ad { + match &val { + AutoDiff::TTDepth(depth) => { + assert!(*depth >= 1); + llvm::set_max_int_offset(*depth); + } + AutoDiff::TTWidth(width) => { + assert!(*width >= 1); + llvm::set_max_type_offset(*width); + } + _ => {}, + } + }; + let differentiate = !diff_items.is_empty(); let mut first_order_items: Vec = vec![]; let mut higher_order_items: Vec = vec![]; @@ -1157,29 +1156,29 @@ pub(crate) unsafe fn differentiate( } } - let mut fnc_opt = false; - if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() { - dbg!("Enable extra optimizations for Enzyme"); - fnc_opt = true; - } + + let fnc_opt = ad.contains(&AutoDiff::EnableFncOpt); // If a function is a base for some higher order ad, always optimize let fnc_opt_base = true; let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt_base as u8); for item in first_order_items { - let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt); + let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt, ad); assert!(res.is_ok()); } // For the rest, follow the user choice on debug vs release. // Reuse the opt one if possible for better compile time (Enzyme internal caching). let logic_ref = match fnc_opt { - true => logic_ref_opt, + true => { + dbg!("Enable extra optimizations for Enzyme"); + logic_ref_opt + } false => CreateEnzymeLogic(fnc_opt as u8), }; for item in higher_order_items { - let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref); + let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref, ad); assert!(res.is_ok()); } @@ -1212,14 +1211,14 @@ pub(crate) unsafe fn differentiate( break; } } - if std::env::var("ENZYME_PRINT_MOD_AFTER_ENZYME").is_ok() { + if ad.contains(&AutoDiff::PrintModAfterEnzyme) { unsafe { LLVMDumpModule(llmod); } } - if std::env::var("ENZYME_NO_MOD_OPT_AFTER").is_ok() || !differentiate { + if ad.contains(&AutoDiff::NoModOptAfter) || !differentiate { trace!("Skipping module optimization after automatic differentiation"); } else { if let Some(opt_level) = config.opt_level { @@ -1231,18 +1230,18 @@ pub(crate) unsafe fn differentiate( }; let mut first_run = false; dbg!("Running Module Optimization after differentiation"); - if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() { + if ad.contains(&AutoDiff::NoVecUnroll) { // disables vectorization and loop unrolling first_run = true; } - if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + if ad.contains(&AutoDiff::AltPipeline) { dbg!("Running first postAD optimization"); first_run = true; } let noop = false; llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?; } - if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { + if ad.contains(&AutoDiff::AltPipeline) { dbg!("Running Second postAD optimization"); if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { @@ -1253,7 +1252,7 @@ pub(crate) unsafe fn differentiate( }; let mut first_run = false; dbg!("Running Module Optimization after differentiation"); - if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() { + if ad.contains(&AutoDiff::NoVecUnroll) { // enables vectorization and loop unrolling first_run = false; } @@ -1263,7 +1262,7 @@ pub(crate) unsafe fn differentiate( } } - if std::env::var("ENZYME_PRINT_MOD_AFTER_OPTS").is_ok() { + if ad.contains(&AutoDiff::PrintModAfterOpts) { unsafe { LLVMDumpModule(llmod); } @@ -1341,15 +1340,16 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; + // Second run only relevant for AD let first_run = true; - let noop; - if std::env::var("ENZYME_ALT_PIPELINE").is_ok() { - noop = true; - dbg!("Skipping PreAD optimization"); - } else { - noop = false; - } + let noop = false; + //if ad.contains(&AutoDiff::AltPipeline) { + // noop = true; + // dbg!("Skipping PreAD optimization"); + //} else { + // noop = false; + //} return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop); } Ok(()) diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index fda8330ea8f11..f2d57060ed13d 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -118,6 +118,7 @@ pub struct ModuleConfig { pub inline_threshold: Option, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub autodiff: Vec, } impl ModuleConfig { @@ -259,6 +260,7 @@ impl ModuleConfig { inline_threshold: sess.opts.cg.inline_threshold, emit_lifetime_markers: sess.emit_lifetime_markers(), llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]), + autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]), } } diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 04a7714d4137e..6e27eafd63f59 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -729,6 +729,7 @@ fn test_unstable_options_tracking_hash() { // Make sure that changing a [TRACKED] option changes the hash. // tidy-alphabetical-start + tracked!(autodiff, vec![String::from("ad_flags")]); tracked!(allow_features, Some(vec![String::from("lang_items")])); tracked!(always_encode_mir, true); tracked!(asm_comments, true); diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 2219fd5e951a8..1ae563fa3222c 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -174,6 +174,53 @@ pub enum InstrumentCoverage { Off, } +/// The different settings that the `-Z ad` flag can have. +#[derive(Clone, Copy, PartialEq, Hash, Debug)] +pub enum AutoDiff { + /// Print TypeAnalysis information + PrintTA, + /// Print ActivityAnalysis Information + PrintAA, + /// Print Performance Warnings from Enzyme + PrintPerf, + /// Combines the three print flags above. + Print, + /// Print the whole module, before running opts. + PrintModBefore, + /// Print the whole module just before we pass it to Enzyme. + /// For Debug purpose, prefer the OPT flag below + PrintModAfterOpts, + /// Print the module after Enzyme differentiated everything. + PrintModAfterEnzyme, + + /// Enzyme's loose type debug helper (can cause incorrect gradients) + LooseTypes, + /// Output a Module using __enzyme calls to prepare it for opt + enzyme pass usage + OPT, + + /// TypeTree options + /// TODO: Figure out how to let users construct these, + /// or whether we want to leave this option in the first place. + TTWidth(u64), + TTDepth(u64), + + /// More flags + NoModOptAfter, + /// Tell Enzyme to run LLVM Opts on each function it generated. By default off, + /// since we already optimize the whole module after Enzyme is done. + EnableFncOpt, + NoVecUnroll, + /// Obviously unsafe, disable the length checks that we have for shadow args. + NoSafetyChecks, + RuntimeActivity, + /// Runs Enzyme specific Inlining + Inline, + /// Runs Optimization twice after AD, and zero times after. + /// This is mainly for Benchmarking purpose to show that + /// compiler based AD has a performance benefit. TODO: fix + AltPipeline, +} + /// Settings for `-Z instrument-xray` flag. #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct InstrumentXRay { @@ -3229,8 +3276,9 @@ pub(crate) mod dep_tracking { LinkerPluginLto, LocationDetail, LtoCli, NextSolverConfig, OomStrategy, OptLevel, OutFileName, OutputType, OutputTypes, Polonius, RemapPathScopeComponents, ResolveDocLinks, SourceFileHashAlgorithm, SplitDwarfKind, SwitchWithOptPath, SymbolManglingVersion, - TrimmedDefPaths, WasiExecModel, + TrimmedDefPaths, WasiExecModel, AutoDiff, }; + //use crate::config::AutoDiff; use crate::lint; use crate::utils::NativeLib; use rustc_data_structures::fx::FxIndexMap; @@ -3285,6 +3333,7 @@ pub(crate) mod dep_tracking { } impl_dep_tracking_hash_via_hash!( + AutoDiff, bool, usize, NonZeroUsize, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 10a4bdb94d46f..b18c20f47cbe9 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -358,6 +358,7 @@ fn build_options( #[allow(non_upper_case_globals)] mod desc { + pub const parse_autodiff: &str = "various values"; pub const parse_no_flag: &str = "no value"; pub const parse_bool: &str = "one of: `y`, `yes`, `on`, `true`, `n`, `no`, `off` or `false`"; pub const parse_opt_bool: &str = parse_bool; @@ -917,6 +918,42 @@ mod parse { } } + pub(crate) fn parse_autodiff( + slot: &mut Vec, + v: Option<&str>, + ) -> bool { + + let Some(v) = v else { + *slot = vec![]; + return true; + }; + let mut v: Vec<&str> = v.split(",").collect(); + v.sort_unstable(); + for &val in v.iter() { + let variant = match val { + "PrintTA" => AutoDiff::PrintTA, + "PrintAA" => AutoDiff::PrintAA, + "PrintPerf" => AutoDiff::PrintPerf, + "Print" => AutoDiff::Print, + "PrintModBefore" => AutoDiff::PrintModBefore, + "PrintModAfterOpts" => AutoDiff::PrintModAfterOpts, + "PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme, + "LooseTypes" => AutoDiff::LooseTypes, + "OPT" => AutoDiff::OPT, + "NoModOptAfter" => AutoDiff::NoModOptAfter, + "EnableFncOpt" => AutoDiff::EnableFncOpt, + "NoVecUnroll" => AutoDiff::NoVecUnroll, + "NoSafetyChecks" => AutoDiff::NoSafetyChecks, + "Inline" => AutoDiff::Inline, + "AltPipeline" => AutoDiff::AltPipeline, + _ => return false, + }; + slot.push(variant); + } + + true + } + pub(crate) fn parse_instrument_coverage( slot: &mut InstrumentCoverage, v: Option<&str>, @@ -1544,6 +1581,8 @@ options! { either `loaded` or `not-loaded`."), assume_incomplete_release: bool = (false, parse_bool, [TRACKED], "make cfg(version) treat the current version as incomplete (default: no)"), + autodiff: Vec = (Vec::new(), parse_autodiff, [TRACKED], + "a list autodiff flags to enable (comma separated)"), #[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")] binary_dep_depinfo: bool = (false, parse_bool, [TRACKED], "include artifacts (sysroot, crate dependencies) used during compilation in dep-info \