Skip to content

Commit 11c3bae

Browse files
committed
update rustc_codegen_llvm
1 parent c1db4dc commit 11c3bae

File tree

13 files changed

+631
-29
lines changed

13 files changed

+631
-29
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+3-16
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9-
use crate::expand::typetree::TypeTree;
109
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1110
use crate::ptr::P;
1211
use crate::{Ty, TyKind};
@@ -79,10 +78,6 @@ pub struct AutoDiffItem {
7978
/// The name of the function being generated
8079
pub target: String,
8180
pub attrs: AutoDiffAttrs,
82-
/// Describe the memory layout of input types
83-
pub inputs: Vec<TypeTree>,
84-
/// Describe the memory layout of the output type
85-
pub output: TypeTree,
8681
}
8782
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8883
pub struct AutoDiffAttrs {
@@ -262,22 +257,14 @@ impl AutoDiffAttrs {
262257
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
263258
}
264259

265-
pub fn into_item(
266-
self,
267-
source: String,
268-
target: String,
269-
inputs: Vec<TypeTree>,
270-
output: TypeTree,
271-
) -> AutoDiffItem {
272-
AutoDiffItem { source, target, inputs, output, attrs: self }
260+
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
261+
AutoDiffItem { source, target, attrs: self }
273262
}
274263
}
275264

276265
impl fmt::Display for AutoDiffItem {
277266
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278267
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
279-
write!(f, " with attributes: {:?}", self.attrs)?;
280-
write!(f, " with inputs: {:?}", self.inputs)?;
281-
write!(f, " with output: {:?}", self.output)
268+
write!(f, " with attributes: {:?}", self.attrs)
282269
}
283270
}

compiler/rustc_codegen_llvm/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
5151
codegen_llvm_run_passes = failed to run LLVM passes
5252
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
5353
54+
codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error}
55+
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
56+
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
57+
5458
codegen_llvm_sanitizer_memtag_requires_mte =
5559
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
5660

compiler/rustc_codegen_llvm/src/back/lto.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,11 @@ pub(crate) fn run_pass_manager(
616616
}
617617
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
618618
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
619-
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?;
619+
620+
// We will run this again with different values in the context of automatic differentiation.
621+
let first_run = true;
622+
debug!("running llvm pm opt pipeline");
623+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
620624
}
621625
debug!("lto done");
622626
Ok(())

compiler/rustc_codegen_llvm/src/back/write.rs

+201-8
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ use std::path::{Path, PathBuf};
44
use std::sync::Arc;
55
use std::{fs, slice, str};
66

7-
use libc::{c_char, c_int, c_void, size_t};
7+
use libc::{c_char, c_int, c_uint, c_void, size_t};
88
use llvm::{
99
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
1010
};
11+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1112
use rustc_codegen_ssa::back::link::ensure_removed;
1213
use rustc_codegen_ssa::back::write::{
1314
BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig,
@@ -27,7 +28,7 @@ use rustc_session::config::{
2728
use rustc_span::InnerSpan;
2829
use rustc_span::symbol::sym;
2930
use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
30-
use tracing::debug;
31+
use tracing::{debug, trace};
3132

3233
use crate::back::lto::ThinBuffer;
3334
use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -39,7 +40,13 @@ use crate::errors::{
3940
WithLlvmError, WriteBytecode,
4041
};
4142
use crate::llvm::diagnostic::OptimizationDiagnosticKind::*;
42-
use crate::llvm::{self, DiagnosticInfo, PassManager};
43+
use crate::llvm::{
44+
self, AttributeKind, DiagnosticInfo, LLVMCreateStringAttribute, LLVMGetFirstFunction,
45+
LLVMGetNextFunction, LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, LLVMIsStringAttribute,
46+
LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex,
47+
LLVMRustAddFunctionAttributes, LLVMRustGetEnumAttributeAtIndex,
48+
LLVMRustRemoveEnumAttributeAtIndex, PassManager,
49+
};
4350
use crate::type_::Type;
4451
use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util};
4552

@@ -515,9 +522,34 @@ pub(crate) unsafe fn llvm_optimize(
515522
config: &ModuleConfig,
516523
opt_level: config::OptLevel,
517524
opt_stage: llvm::OptStage,
525+
skip_size_increasing_opts: bool,
518526
) -> Result<(), FatalError> {
519-
let unroll_loops =
520-
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
527+
// Enzyme:
528+
// The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
529+
// source code. However, benchmarks show that optimizations increasing the code size
530+
// tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
531+
// and finally re-optimize the module, now with all optimizations available.
532+
// TODO: In a future update we could figure out how to only optimize functions getting
533+
// differentiated.
534+
535+
let unroll_loops;
536+
let vectorize_slp;
537+
let vectorize_loop;
538+
539+
if skip_size_increasing_opts {
540+
unroll_loops = false;
541+
vectorize_slp = false;
542+
vectorize_loop = false;
543+
} else {
544+
unroll_loops =
545+
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
546+
vectorize_slp = config.vectorize_slp;
547+
vectorize_loop = config.vectorize_loop;
548+
}
549+
trace!(
550+
"Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
551+
unroll_loops, vectorize_slp, vectorize_loop
552+
);
521553
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
522554
let pgo_gen_path = get_pgo_gen_path(config);
523555
let pgo_use_path = get_pgo_use_path(config);
@@ -581,8 +613,8 @@ pub(crate) unsafe fn llvm_optimize(
581613
using_thin_buffers,
582614
config.merge_functions,
583615
unroll_loops,
584-
config.vectorize_slp,
585-
config.vectorize_loop,
616+
vectorize_slp,
617+
vectorize_loop,
586618
config.no_builtins,
587619
config.emit_lifetime_markers,
588620
sanitizer_options.as_ref(),
@@ -605,6 +637,113 @@ pub(crate) unsafe fn llvm_optimize(
605637
result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
606638
}
607639

640+
pub(crate) fn differentiate(
641+
module: &ModuleCodegen<ModuleLlvm>,
642+
cgcx: &CodegenContext<LlvmCodegenBackend>,
643+
diff_items: Vec<AutoDiffItem>,
644+
config: &ModuleConfig,
645+
) -> Result<(), FatalError> {
646+
for item in &diff_items {
647+
trace!("{}", item);
648+
}
649+
650+
let llmod = module.module_llvm.llmod();
651+
let llcx = &module.module_llvm.llcx;
652+
let diag_handler = cgcx.create_dcx();
653+
654+
// Before dumping the module, we want all the tt to become part of the module.
655+
for item in diff_items.iter() {
656+
let name = CString::new(item.source.clone()).unwrap();
657+
let fn_def: Option<&llvm::Value> =
658+
unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
659+
let fn_def = match fn_def {
660+
Some(x) => x,
661+
None => {
662+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
663+
src: item.source.clone(),
664+
target: item.target.clone(),
665+
error: "could not find source function".to_owned(),
666+
}));
667+
}
668+
};
669+
let tgt_name = CString::new(item.target.clone()).unwrap();
670+
dbg!("Target name: {:?}", &tgt_name);
671+
let fn_target: Option<&llvm::Value> =
672+
unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) };
673+
let fn_target = match fn_target {
674+
Some(x) => x,
675+
None => {
676+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
677+
src: item.source.clone(),
678+
target: item.target.clone(),
679+
error: "could not find target function".to_owned(),
680+
}));
681+
}
682+
};
683+
684+
crate::builder::add_opt_dbg_helper2(llmod, llcx, fn_def, fn_target, item.attrs.clone());
685+
}
686+
687+
// We needed the SanitizeHWAddress attribute to prevent LLVM from optimizing enums in a way
688+
// which Enzyme doesn't understand.
689+
unsafe {
690+
let mut f = LLVMGetFirstFunction(llmod);
691+
loop {
692+
if let Some(lf) = f {
693+
f = LLVMGetNextFunction(lf);
694+
let myhwattr = "enzyme_hw";
695+
let attr = LLVMGetStringAttributeAtIndex(
696+
lf,
697+
c_uint::MAX,
698+
myhwattr.as_ptr() as *const c_char,
699+
myhwattr.as_bytes().len() as c_uint,
700+
);
701+
if LLVMIsStringAttribute(attr) {
702+
LLVMRemoveStringAttributeAtIndex(
703+
lf,
704+
c_uint::MAX,
705+
myhwattr.as_ptr() as *const c_char,
706+
myhwattr.as_bytes().len() as c_uint,
707+
);
708+
} else {
709+
LLVMRustRemoveEnumAttributeAtIndex(
710+
lf,
711+
c_uint::MAX,
712+
AttributeKind::SanitizeHWAddress,
713+
);
714+
}
715+
} else {
716+
break;
717+
}
718+
}
719+
}
720+
721+
if let Some(opt_level) = config.opt_level {
722+
let opt_stage = match cgcx.lto {
723+
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
724+
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
725+
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
726+
_ => llvm::OptStage::PreLinkNoLTO,
727+
};
728+
let skip_size_increasing_opts = false;
729+
dbg!("Running Module Optimization after differentiation");
730+
unsafe {
731+
llvm_optimize(
732+
cgcx,
733+
diag_handler.handle(),
734+
module,
735+
config,
736+
opt_level,
737+
opt_stage,
738+
skip_size_increasing_opts,
739+
)?
740+
};
741+
}
742+
dbg!("Done with differentiate()");
743+
744+
Ok(())
745+
}
746+
608747
// Unsafe due to LLVM calls.
609748
pub(crate) unsafe fn optimize(
610749
cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -627,14 +766,68 @@ pub(crate) unsafe fn optimize(
627766
unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
628767
}
629768

769+
// This code enables Enzyme to differentiate code containing Rust enums.
770+
// By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing
771+
// away the enums and allows Enzyme to understand why a value can be of different types in
772+
// different code sections. We remove this attribute after Enzyme is done, to not affect the
773+
// rest of the compilation.
774+
#[cfg(llvm_enzyme)]
775+
unsafe {
776+
let mut f = LLVMGetFirstFunction(llmod);
777+
loop {
778+
if let Some(lf) = f {
779+
f = LLVMGetNextFunction(lf);
780+
let myhwattr = "enzyme_hw";
781+
let myhwv = "";
782+
let prevattr = LLVMRustGetEnumAttributeAtIndex(
783+
lf,
784+
c_uint::MAX,
785+
AttributeKind::SanitizeHWAddress,
786+
);
787+
if LLVMIsEnumAttribute(prevattr) {
788+
let attr = LLVMCreateStringAttribute(
789+
llcx,
790+
myhwattr.as_ptr() as *const c_char,
791+
myhwattr.as_bytes().len() as c_uint,
792+
myhwv.as_ptr() as *const c_char,
793+
myhwv.as_bytes().len() as c_uint,
794+
);
795+
LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1);
796+
} else {
797+
LLVMRustAddEnumAttributeAtIndex(
798+
llcx,
799+
lf,
800+
c_uint::MAX,
801+
AttributeKind::SanitizeHWAddress,
802+
);
803+
}
804+
} else {
805+
break;
806+
}
807+
}
808+
}
809+
630810
if let Some(opt_level) = config.opt_level {
631811
let opt_stage = match cgcx.lto {
632812
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
633813
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
634814
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
635815
_ => llvm::OptStage::PreLinkNoLTO,
636816
};
637-
return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
817+
818+
// If we know that we will later run AD, then we disable vectorization and loop unrolling
819+
let skip_size_increasing_opts = cfg!(llvm_enzyme);
820+
return unsafe {
821+
llvm_optimize(
822+
cgcx,
823+
dcx,
824+
module,
825+
config,
826+
opt_level,
827+
opt_stage,
828+
skip_size_increasing_opts,
829+
)
830+
};
638831
}
639832
Ok(())
640833
}

0 commit comments

Comments
 (0)