|
| 1 | +use std::ptr; |
| 2 | + |
| 3 | +use rustc_ast::expand::batch_attrs::{BatchAttrs, BatchItem, BatchActivity}; |
| 4 | +use rustc_codegen_ssa::ModuleCodegen; |
| 5 | +use rustc_codegen_ssa::back::write::ModuleConfig; |
| 6 | +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; |
| 7 | +use rustc_errors::FatalError; |
| 8 | +use rustc_middle::ty::TyCtxt; |
| 9 | +use rustc_session::config::Lto; |
| 10 | +use tracing::{debug, trace}; |
| 11 | + |
| 12 | +use crate::back::write::{llvm_err, llvm_optimize}; |
| 13 | +use crate::builder::Builder; |
| 14 | +use crate::declare::declare_raw_fn; |
| 15 | +use crate::errors::LlvmError; |
| 16 | +use crate::llvm::AttributePlace::Function; |
| 17 | +use crate::llvm::{Metadata, True}; |
| 18 | +use crate::value::Value; |
| 19 | +use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, context, llvm}; |
| 20 | + |
| 21 | +fn get_params(fnc: &Value) -> Vec<&Value> { |
| 22 | + unsafe { |
| 23 | + let param_num = llvm::LLVMCountParams(fnc) as usize; |
| 24 | + let mut fnc_args: Vec<&Value> = vec![]; |
| 25 | + fnc_args.reserve(param_num); |
| 26 | + llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); |
| 27 | + fnc_args.set_len(param_num); |
| 28 | + fnc_args |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another |
| 33 | +/// function with expected naming and calling conventions[^1] which will be |
| 34 | +/// discovered by the enzyme LLVM pass and its body populated with the differentiated |
| 35 | +/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated |
| 36 | +/// function and handle the differences between the Rust calling convention and |
| 37 | +/// Enzyme. |
| 38 | +/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/> |
| 39 | +// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to |
| 40 | +// cover some assumptions of enzyme/batch, which could lead to UB otherwise. |
| 41 | +fn generate_enzyme_call<'ll, 'tcx>( |
| 42 | + cx: &context::CodegenCx<'ll, 'tcx>, |
| 43 | + fn_to_diff: &'ll Value, |
| 44 | + outer_fn: &'ll Value, |
| 45 | + attrs: BatchAttrs, |
| 46 | +) { |
| 47 | + let inputs = attrs.input_activity; |
| 48 | + let width = attrs.width; |
| 49 | + let mut ad_name: String = "__enzyme_batch".to_string(); |
| 50 | + |
| 51 | + // add outer_fn name to ad_name to make it unique, in case users apply batch to multiple |
| 52 | + // functions. Unwrap will only panic, if LLVM gave us an invalid string. |
| 53 | + let name = llvm::get_value_name(outer_fn); |
| 54 | + let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap(); |
| 55 | + ad_name.push_str(outer_fn_name.to_string().as_str()); |
| 56 | + |
| 57 | + // Let us assume the user wrote the following function square: |
| 58 | + // |
| 59 | + // ```llvm |
| 60 | + // define double @square(double %x) { |
| 61 | + // entry: |
| 62 | + // %0 = fmul double %x, %x |
| 63 | + // ret double %0 |
| 64 | + // } |
| 65 | + // ``` |
| 66 | + // |
| 67 | + // The user now applies batching to the function square, in which case fn_to_diff will be `square`. |
| 68 | + // Our macro generates the following placeholder code (slightly simplified): |
| 69 | + // |
| 70 | + // ```llvm |
| 71 | + // define double @dsquare(double %x) { |
| 72 | + // ; placeholder code |
| 73 | + // return 0.0; |
| 74 | + // } |
| 75 | + // ``` |
| 76 | + // |
| 77 | + // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder |
| 78 | + // code and inserts an batching call. We also add a declaration for the __enzyme_batch call. |
| 79 | + // Again, the arguments to all functions are slightly simplified. |
| 80 | + // ```llvm |
| 81 | + // declare double @__enzyme_batch_square(...) |
| 82 | + // |
| 83 | + // define double @dsquare(double %x0, double %x1, double %x2, double %x3) { |
| 84 | + // entry: |
| 85 | + // %0 = tail call double (...) @__enzyme_batch_square(double (double)* nonnull @square, metadata !"enzyme_width", i64 4, |
| 86 | + // metadata !"enzyme_vector", double %x0, double %x1, double %x2, double %x3) |
| 87 | + // ret double %0 |
| 88 | + // } |
| 89 | + // ``` |
| 90 | + unsafe { |
| 91 | + // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input |
| 92 | + // arguments. We do however need to declare them with their correct return type. |
| 93 | + // We already figured the correct return type out in our frontend, when generating the outer_fn, |
| 94 | + // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet. |
| 95 | + let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn); |
| 96 | + let ret_ty = llvm::LLVMGetReturnType(fn_ty); |
| 97 | + |
| 98 | + // LLVM can figure out the input types on it's own, so we take a shortcut here. |
| 99 | + let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); |
| 100 | + |
| 101 | + //FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and |
| 102 | + // think a bit more about what should go here. |
| 103 | + let cc = llvm::LLVMGetFunctionCallConv(outer_fn); |
| 104 | + let ad_fn = declare_raw_fn( |
| 105 | + cx, |
| 106 | + &ad_name, |
| 107 | + llvm::CallConv::try_from(cc).expect("invalid callconv"), |
| 108 | + llvm::UnnamedAddr::No, |
| 109 | + llvm::Visibility::Default, |
| 110 | + enzyme_ty, |
| 111 | + ); |
| 112 | + |
| 113 | + // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to |
| 114 | + // do it's work. |
| 115 | + let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); |
| 116 | + attributes::apply_to_llfn(ad_fn, Function, &[attr]); |
| 117 | + |
| 118 | + // first, remove all calls from fnc |
| 119 | + let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); |
| 120 | + let br = llvm::LLVMRustGetTerminator(entry); |
| 121 | + llvm::LLVMRustEraseInstFromParent(br); |
| 122 | + |
| 123 | + let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); |
| 124 | + let mut builder = Builder::build(cx, entry); |
| 125 | + |
| 126 | + let num_args = llvm::LLVMCountParams(&fn_to_diff); |
| 127 | + let mut args = Vec::with_capacity(num_args as usize + 1); |
| 128 | + args.push(fn_to_diff); |
| 129 | + |
| 130 | + let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); |
| 131 | + let enzyme_vector = cx.create_metadata("enzyme_vector".to_string()).unwrap(); |
| 132 | + let enzyme_buffer = cx.create_metadata("enzyme_buffer".to_string()).unwrap(); |
| 133 | + |
| 134 | + trace!("matching batch arguments"); |
| 135 | + // We now handle the issue that Rust level arguments not always match the llvm-ir level |
| 136 | + // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on |
| 137 | + // llvm-ir level. The number of activities matches the number of Rust level arguments, so we |
| 138 | + // need to match those. |
| 139 | + // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it |
| 140 | + // using iterators and peek()? |
| 141 | + let mut outer_pos: usize = 0; |
| 142 | + let mut activity_pos = 0; |
| 143 | + let outer_args: Vec<&llvm::Value> = get_params(outer_fn); |
| 144 | + while activity_pos < inputs.len() { |
| 145 | + let activity = inputs[activity_pos]; |
| 146 | + let (activity, vectorized): (&Metadata, bool) = match activity { |
| 147 | + BatchActivity::Const => (enzyme_const, false), |
| 148 | + BatchActivity::Vector => (enzyme_vector, true), |
| 149 | + BatchActivity::Leaf => (enzyme_buffer, false), |
| 150 | + BatchActivity::FakeActivitySize => (enzyme_const, false), |
| 151 | + }; |
| 152 | + let outer_arg = outer_args[outer_pos]; |
| 153 | + args.push(cx.get_metadata_value(activity)); |
| 154 | + args.push(outer_arg); |
| 155 | + if vectorized { |
| 156 | + // We know that vectorized args by construction have <width-1> following arguments, |
| 157 | + // so this can not be out of bounds. |
| 158 | + let next_outer_arg = outer_args[outer_pos + width - 1]; |
| 159 | + let next_outer_ty = cx.val_ty(next_outer_arg); |
| 160 | + // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since |
| 161 | + // vectors behind references (&Vec<T>) are already supported. Users can not pass a |
| 162 | + // Vec by value for reverse mode, so this would only help forward mode batch. |
| 163 | + let slice = { |
| 164 | + if activity_pos + 1 >= inputs.len() { |
| 165 | + // If there is no arg following our ptr, it also can't be a slice, |
| 166 | + // since that would lead to a ptr, int pair. |
| 167 | + false |
| 168 | + } else { |
| 169 | + let next_activity = inputs[activity_pos + 1]; |
| 170 | + // We analyze the MIR types and add this dummy activity if we visit a slice. |
| 171 | + next_activity == BatchActivity::FakeActivitySize |
| 172 | + } |
| 173 | + }; |
| 174 | + if slice { |
| 175 | + // A 4x batched slice will have the following two outer_fn arguments: |
| 176 | + // (..., ptr0, int0, ptr1, int1, ...). We add the following llvm-ir to our __enzyme call: |
| 177 | + // (..., metadata! enzyme_vector, ptr0, ptr1, ptr2, ptr3, int1, ...). |
| 178 | + // FIXME(ZuseZ4): We will upstream a safety check later which asserts that |
| 179 | + // int2 >= int1, which means the shadow args are equally large |
| 180 | + |
| 181 | + args.push(cx.get_metadata_value(enzyme_const)); |
| 182 | + // Now we verify that we have width pairs of (ptr/int) |
| 183 | + for i in 0..width { |
| 184 | + let next_outer_arg = outer_args[outer_pos + 2 * i]; |
| 185 | + let next_outer_ty = cx.val_ty(next_outer_arg); |
| 186 | + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); |
| 187 | + let next_outer_arg2 = outer_args[outer_pos + 2 * i + 1]; |
| 188 | + let next_outer_ty2 = cx.val_ty(next_outer_arg2); |
| 189 | + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Integer); |
| 190 | + args.push(next_outer_arg); |
| 191 | + args.push(next_outer_arg2); |
| 192 | + } |
| 193 | + args.push(cx.get_metadata_value(enzyme_const)); |
| 194 | + args.push(next_outer_arg); |
| 195 | + outer_pos += 4; |
| 196 | + activity_pos += 2; |
| 197 | + } else { |
| 198 | + // A vectorized pointer will have the following two outer_fn arguments: |
| 199 | + // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: |
| 200 | + // (..., metadata! enzyme_dup, ptr, ptr, ...). |
| 201 | + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); |
| 202 | + args.push(next_outer_arg); |
| 203 | + outer_pos += 2; |
| 204 | + activity_pos += 1; |
| 205 | + } |
| 206 | + } else { |
| 207 | + // We do not differentiate with resprect to this argument. |
| 208 | + // We already added the metadata and argument above, so just increase the counters. |
| 209 | + outer_pos += 1; |
| 210 | + activity_pos += 1; |
| 211 | + } |
| 212 | + } |
| 213 | + |
| 214 | + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); |
| 215 | + |
| 216 | + // This part is a bit iffy. LLVM requires that a call to an inlineable function has some |
| 217 | + // metadata attachted to it, but we just created this code oota. Given that the |
| 218 | + // differentiated function already has partly confusing metadata, and given that this |
| 219 | + // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the |
| 220 | + // dummy code which we inserted at a higher level. |
| 221 | + // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, |
| 222 | + // and how to best improve it for enzyme core and rust-enzyme. |
| 223 | + let md_ty = cx.get_md_kind_id("dbg"); |
| 224 | + if llvm::LLVMRustHasMetadata(last_inst, md_ty) { |
| 225 | + let md = llvm::LLVMRustDIGetInstMetadata(last_inst) |
| 226 | + .expect("failed to get instruction metadata"); |
| 227 | + let md_todiff = cx.get_metadata_value(md); |
| 228 | + llvm::LLVMSetMetadata(call, md_ty, md_todiff); |
| 229 | + } else { |
| 230 | + // We don't panic, since depending on whether we are in debug or release mode, we might |
| 231 | + // have no debug info to copy, which would then be ok. |
| 232 | + trace!("no dbg info"); |
| 233 | + } |
| 234 | + // Now that we copied the metadata, get rid of dummy code. |
| 235 | + llvm::LLVMRustEraseInstBefore(entry, last_inst); |
| 236 | + llvm::LLVMRustEraseInstFromParent(last_inst); |
| 237 | + |
| 238 | + if cx.val_ty(outer_fn) != cx.type_void() { |
| 239 | + builder.ret(call); |
| 240 | + } else { |
| 241 | + builder.ret_void(); |
| 242 | + } |
| 243 | + |
| 244 | + // Let's crash in case that we messed something up above and generated invalid IR. |
| 245 | + llvm::LLVMRustVerifyFunction( |
| 246 | + outer_fn, |
| 247 | + llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction, |
| 248 | + ); |
| 249 | + } |
| 250 | +} |
| 251 | + |
| 252 | +pub(crate) fn batch<'ll, 'tcx>( |
| 253 | + module: &'ll ModuleCodegen<ModuleLlvm>, |
| 254 | + cgcx: &CodegenContext<LlvmCodegenBackend>, |
| 255 | + tcx: TyCtxt<'tcx>, |
| 256 | + batch_items: Vec<BatchItem>, |
| 257 | + config: &ModuleConfig, |
| 258 | +) -> Result<(), FatalError> { |
| 259 | + for item in &batch_items { |
| 260 | + trace!("{}", item); |
| 261 | + } |
| 262 | + |
| 263 | + let diag_handler = cgcx.create_dcx(); |
| 264 | + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); |
| 265 | + let cx = context::CodegenCx::new(tcx, &cgus.first().unwrap(), &module.module_llvm); |
| 266 | + |
| 267 | + // Before dumping the module, we want all the TypeTrees to become part of the module. |
| 268 | + for item in batch_items.iter() { |
| 269 | + let name = item.source.clone(); |
| 270 | + let fn_def: Option<&llvm::Value> = cx.get_function(&name); |
| 271 | + let Some(fn_def) = fn_def else { |
| 272 | + return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareBatching { |
| 273 | + src: item.source.clone(), |
| 274 | + target: item.target.clone(), |
| 275 | + error: "could not find source function".to_owned(), |
| 276 | + })); |
| 277 | + }; |
| 278 | + debug!(?item.target); |
| 279 | + let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); |
| 280 | + let Some(fn_target) = fn_target else { |
| 281 | + return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareBatching { |
| 282 | + src: item.source.clone(), |
| 283 | + target: item.target.clone(), |
| 284 | + error: "could not find target function".to_owned(), |
| 285 | + })); |
| 286 | + }; |
| 287 | + |
| 288 | + generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); |
| 289 | + } |
| 290 | + |
| 291 | + // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts |
| 292 | + |
| 293 | + if let Some(opt_level) = config.opt_level { |
| 294 | + let opt_stage = match cgcx.lto { |
| 295 | + Lto::Fat => llvm::OptStage::PreLinkFatLTO, |
| 296 | + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, |
| 297 | + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, |
| 298 | + _ => llvm::OptStage::PreLinkNoLTO, |
| 299 | + }; |
| 300 | + // This is our second opt call, so now we run all opts, |
| 301 | + // to make sure we get the best performance. |
| 302 | + let skip_size_increasing_opts = false; |
| 303 | + trace!("running Module Optimization after differentiation"); |
| 304 | + unsafe { |
| 305 | + llvm_optimize( |
| 306 | + cgcx, |
| 307 | + diag_handler.handle(), |
| 308 | + module, |
| 309 | + config, |
| 310 | + opt_level, |
| 311 | + opt_stage, |
| 312 | + skip_size_increasing_opts, |
| 313 | + )? |
| 314 | + }; |
| 315 | + } |
| 316 | + trace!("done with differentiate()"); |
| 317 | + |
| 318 | + Ok(()) |
| 319 | +} |
0 commit comments