use std::cmp::Reverse;
use std::collections::BTreeSet;
use std::fmt::Debug;
use std::ops::Bound;

use rustc_abi::{
    BackendRepr, FieldIdx, FieldsShape, Integer, Layout, LayoutData, Primitive, Scalar, Size,
    TagEncoding, TyAndLayout, Variants, WrappingRange,
};
use rustc_index::IndexVec;
use rustc_index::bit_set::BitSet;
use rustc_middle::mir::CoroutineSavedLocal;
use rustc_middle::ty::layout::{HasTyCtxt, IntegerExt, LayoutCx, LayoutError, LayoutOf};
use rustc_middle::ty::{EarlyBinder, GenericArgsRef, Ty};
use tracing::{debug, instrument};

use super::error;

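/// Compute the layout of a coroutine, reusing storage slots across suspension
/// points when the saved locals provably do not conflict. Locals are placed
/// greedily, first-fit, in a heuristic order (alignment, then size, then degree
/// of conflict), and each variant of the resulting layout refers back to the
/// slots assigned to its fields.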
#[instrument(level = "debug", skip(cx))]
pub(super) fn coroutine_layout<'tcx>(
    cx: &LayoutCx<'tcx>,
    ty: Ty<'tcx>,
    def_id: rustc_hir::def_id::DefId,
    args: GenericArgsRef<'tcx>,
) -> Result<Layout<'tcx>, &'tcx LayoutError<'tcx>> {
    let tcx = cx.tcx();
    let Some(info) = tcx.coroutine_layout(def_id, args.as_coroutine().kind_ty()) else {
        return Err(error(cx, LayoutError::Unknown(ty)));
    };

    let instantiate_field = |ty: Ty<'tcx>| EarlyBinder::bind(ty).instantiate(tcx, args);
    let field_layouts: IndexVec<CoroutineSavedLocal, _> = info
        .field_tys
        .iter_enumerated()
        .map(|(saved_local, ty)| {
            let ty = instantiate_field(ty.ty);
            cx.spanned_layout_of(ty, info.field_tys[saved_local].source_info.span)
        })
        .try_collect()?;
    let layouts: IndexVec<CoroutineSavedLocal, _> =
        field_layouts.iter().map(|data| data.layout.clone()).collect();

    let field_sort_keys: IndexVec<CoroutineSavedLocal, _>;
    let mut saved_locals: Vec<_>;
    // ## The heuristic for which saved locals get their allocation first ##
    // 1. Alignment
    //    Intuitively, data with a larger alignment asks for a larger block of contiguous
    //    memory. Large blocks are easy to find at the beginning, but become harder to
    //    recover as fragmentation creeps in once data with smaller alignment occupies
    //    the large chunks.
    // 2. Size
    //    Size also constrains the layout, though not as strongly as alignment.
    // 3. Degree of conflict
    //    This metric is the number of saved locals that conflict with a given saved local.
    //    Allocating highly conflicting data before less conflicting, more transient data
    //    keeps fragmentation contained within neighbourhoods of the layout.
    (saved_locals, field_sort_keys) = field_layouts
        .iter_enumerated()
        .map(|(saved_local, ty)| {
            (
                saved_local,
                (
                    Reverse(ty.align.abi),
                    Reverse(ty.size),
                    Reverse(info.storage_conflicts.count(saved_local)),
                ),
            )
        })
        .unzip();
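    // As a hypothetical example: a local with alignment 8, size 16, and three
    // conflicts is placed before one with alignment 4, size 4, and one conflict,
    // because every key is wrapped in `Reverse` and compared lexicographically.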
    // Saved locals that are uninhabited or zero-sized need no storage;
    // collect them so the allocation loop can skip them.
    let mut uninhabited_or_zst = BitSet::new_empty(field_layouts.len());
    for (saved_local, ty) in field_layouts.iter_enumerated() {
        if ty.layout.backend_repr.is_uninhabited() || ty.layout.is_zst() {
            uninhabited_or_zst.insert(saved_local);
        }
    }
    saved_locals.sort_by_key(|&idx| &field_sort_keys[idx]);

    // Lay out the discriminant, also known as the tag, as an unsigned integer just
    // wide enough to enumerate all of the coroutine states.
    let max_discr = (info.variant_fields.len() - 1) as u128;
    let discr_int = Integer::fit_unsigned(max_discr);
    let tag = Scalar::Initialized {
        value: Primitive::Int(discr_int, false),
        valid_range: WrappingRange { start: 0, end: max_discr },
    };
    let tag_layout = TyAndLayout {
        ty: discr_int.to_ty(tcx, false),
        layout: tcx.mk_layout(LayoutData::scalar(cx, tag)),
    };
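    // For instance, a coroutine with at most 256 states (`max_discr == 255`) gets
    // a `u8` tag, while 257 states would push the tag to a `u16`.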
    // The first entry of `saved_locals`, if any, has the maximal alignment thanks
    // to the sort above; combined with the tag alignment, this will be *the*
    // alignment of the entire coroutine.
    let align = saved_locals
        .first()
        .map(|&idx| layouts[idx].align.max(tag_layout.layout.align))
        .unwrap_or(tag_layout.layout.align);

    // ## The blocked map, or the reservation map ##
    // This map from saved locals to memory ranges records the reservation status
    // of the coroutine state memory, down to byte granularity. `Slot`s are
    // inserted to mark ranges of memory in which a particular saved local must
    // not be allocated, because conflicting saved locals are live there.
    // To make a reservation for a saved local, we therefore inspect the gaps
    // before, between, and after those blocked-out memory ranges.
    let mut blocked: IndexVec<CoroutineSavedLocal, BTreeSet<Slot>> =
        IndexVec::from_elem_n(BTreeSet::new(), saved_locals.len());
    let mut tag_blocked = BTreeSet::new();
    let mut assignment = IndexVec::from_elem_n(Slot { start: 0, end: 0 }, saved_locals.len());
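    // As a hypothetical example: if a local needs 8 bytes at alignment 8 and its
    // blocked set is {[0, 8), [12, 16)}, the 4-byte gap at [8, 12) cannot hold it,
    // so the first-fit scan below settles on [16, 24).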
    for (idx, &current_local) in saved_locals.iter().enumerate() {
        if uninhabited_or_zst.contains(current_local) {
            // Do not bother to compute a slot for uninhabited or zero-sized data;
            // it will not get an allocation after all.
            // By default, a ZST occupies the beginning of the coroutine state.
            continue;
        }
        let layout_data = &field_layouts[current_local];
        let size = layout_data.size.bytes();

        // First-fit scan: walk the blocked ranges in ascending order and bump the
        // candidate past each one it collides with, keeping it suitably aligned.
        let mut candidate = Slot { start: 0, end: size };
        for slot in blocked[current_local].iter() {
            if slot.end <= candidate.start {
                // This reservation lies entirely before the candidate, which can
                // happen after an aligning bump; skip it rather than stop scanning.
                continue;
            }
            if slot.overlap_with(&candidate) {
                let start = Size::from_bytes(slot.end).align_to(layout_data.align.abi).bytes();
                candidate = Slot { start, end: start + size };
            } else {
                // All remaining reservations start at or beyond the candidate's end.
                break;
            }
        }
        // The tag conflicts with every saved local, so every chosen slot blocks it.
        merge_slot_in(&mut tag_blocked, candidate);
        // Block this slot out for every conflicting local that is yet to be placed.
        for &other_local in &saved_locals[idx + 1..] {
            if info.storage_conflicts.contains(current_local, other_local) {
                merge_slot_in(&mut blocked[other_local], candidate);
            }
        }
        // Record the chosen slot; the field is later shifted to this offset.
        assignment[current_local] = candidate;
    }
    debug!(assignment = ?assignment.debug_map_view());

    // Find a slot for the discriminant, also known as the tag.
    let tag_size = tag_layout.layout.size.bytes();
    let mut tag_candidate = Slot { start: 0, end: tag_size };
    // The discriminant conflicts with all the saved locals.
    for slot in tag_blocked {
        if slot.end <= tag_candidate.start {
            // This reservation lies entirely before the candidate; skip it.
            continue;
        }
        if slot.overlap_with(&tag_candidate) {
            let start = Size::from_bytes(slot.end).align_to(tag_layout.layout.align.abi).bytes();
            tag_candidate = Slot { start, end: start + tag_size };
        } else {
            break;
        }
    }
    debug!(tag = ?tag_candidate);
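    // `tag_candidate` is now the lowest suitably aligned range that avoids the
    // slot of every allocated saved local.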

    // Assemble the layout for each coroutine state.
    let variants: IndexVec<_, LayoutData<_, _>> = info
        .variant_fields
        .iter_enumerated()
        .map(|(index, fields)| {
            // A variant is uninhabited if any of its saved locals is; note that a
            // ZST field alone does not make the variant uninhabited.
            if fields.iter().any(|&saved_local| {
                field_layouts[saved_local].layout.backend_repr.is_uninhabited()
            }) {
                LayoutData {
                    fields: FieldsShape::Arbitrary { offsets: [].into(), memory_index: [].into() },
                    variants: Variants::Single { index },
                    backend_repr: BackendRepr::Uninhabited,
                    largest_niche: None,
                    align,
                    size: Size::ZERO,
                    max_repr_align: None,
                    unadjusted_abi_align: align.abi,
                }
            } else {
                // The variant must be large enough to cover every slot assigned to
                // its fields, as well as the tag, rounded up to the coroutine
                // alignment.
                let size = Size::from_bytes(
                    fields
                        .iter()
                        .map(|&saved_local| assignment[saved_local].end)
                        .max()
                        .unwrap_or(0)
                        .max(tag_candidate.end),
                )
                .align_to(align.abi);
                let offsets: IndexVec<_, _> = fields
                    .iter()
                    .map(|&saved_local| Size::from_bytes(assignment[saved_local].start))
                    .collect();
                let memory_index =
                    IndexVec::from_fn_n(|n: FieldIdx| n.index() as u32, offsets.len());
                LayoutData {
                    // We are aware of specialised layouts such as scalar pairs, but
                    // this scheme is still in development. Let us hold off on
                    // further optimisation until more information is available.
                    fields: FieldsShape::Arbitrary { offsets, memory_index },
                    variants: Variants::Single { index },
                    backend_repr: BackendRepr::Memory { sized: true },
                    largest_niche: None,
                    align,
                    size,
                    max_repr_align: None,
                    unadjusted_abi_align: align.abi,
                }
            }
        })
        .collect();
    let size = variants
        .iter()
        .map(|data| data.size)
        .max()
        .unwrap_or(Size::ZERO)
        .max(Size::from_bytes(tag_candidate.end))
        .align_to(align.abi);
    let layout = tcx.mk_layout(LayoutData {
        fields: FieldsShape::Arbitrary {
            offsets: [Size::from_bytes(tag_candidate.start)].into(),
            memory_index: [0].into(),
        },
        variants: Variants::Multiple {
            tag,
            tag_encoding: TagEncoding::Direct,
            tag_field: 0,
            variants,
        },
        backend_repr: BackendRepr::Memory { sized: true },
        // Suppress niches inside coroutines. If the niche is inside a field that is aliased (due to
        // self-referentiality), getting the discriminant can cause aliasing violations.
        // `UnsafeCell` blocks niches for the same reason, but we don't yet have `UnsafePinned` that
        // would do the same for us here.
        // See <https://github.com/rust-lang/rust/issues/63818>, <https://github.com/rust-lang/miri/issues/3780>.
        // FIXME(#125735): Remove when <https://github.com/rust-lang/rust/issues/125735>
        // is implemented and aliased coroutine fields are wrapped in `UnsafePinned`.
        // NOTE(@dingxiangfei2009): I believe there is still a niche here, namely the
        // tag, but I am not sure how much benefit there is for us to grab it.
        largest_niche: None,
        align,
        size,
        max_repr_align: None,
        unadjusted_abi_align: align.abi,
    });
    debug!("coroutine layout ({:?}): {:#?}", ty, layout);
    Ok(layout)
}

/// An occupied slot in the coroutine memory at some yield point.
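/// A slot is a half-open byte range `start..end`. The derived ordering is
/// lexicographic on `(start, end)`, which `merge_slot_in` relies on for its
/// sentinel-based searches in the `BTreeSet`.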
#[derive(PartialOrd, PartialEq, Eq, Ord, Copy, Clone)]
struct Slot {
    /// Beginning of the memory slot, inclusive
    start: u64,
    /// End of the memory slot, exclusive, i.e. one byte past the end of the data
    end: u64,
}

impl Debug for Slot {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("Slot").field(&self.start).field(&self.end).finish()
    }
}

impl Slot {
    fn overlap_with(&self, other: &Self) -> bool {
        // Empty slots never overlap anything.
        if self.start == self.end || other.start == other.end {
            return false;
        }
        self.contains_point(other.start) || other.contains_point(self.start)
    }
    fn contains_point(&self, point: u64) -> bool {
        self.start <= point && point < self.end
    }
}

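/// Insert `slot` into a set of disjoint, non-adjacent reserved ranges, coalescing
/// it with any existing ranges that it overlaps or directly abuts.
///
/// For example, merging [4, 8) into the set {[0, 4), [8, 12)} collapses all three
/// ranges into the single slot [0, 12): [0, 4) ends exactly where the new slot
/// starts, and [8, 12) begins exactly where it ends.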
fn merge_slot_in(slots: &mut BTreeSet<Slot>, slot: Slot) {
    // Empty sentinel slots used purely as search keys.
    let start = Slot { start: slot.start, end: slot.start };
    let end = Slot { start: slot.end, end: slot.end };
    let one_past_end = Slot { start: slot.end + 1, end: slot.end + 1 };
    // If an existing slot overlaps, or ends exactly at, the new slot's start, the
    // merged slot extends back to that slot's start.
    let (range_start, replace_start) = if let Some(prev) = slots.range(..start).next_back()
        && (prev.end == slot.start || prev.contains_point(slot.start))
    {
        (Bound::Included(prev), prev.start)
    } else {
        (Bound::Included(&start), slot.start)
    };
    // Symmetrically, the merged slot extends forward to the end of any slot that
    // starts exactly at, or straddles, the new slot's end.
    let (range_end, replace_end) = if let Some(next) = slots.range(..one_past_end).next_back()
        && next.start == slot.end
    {
        (Bound::Included(next), next.end)
    } else if let Some(prev) = slots.range(..end).next_back()
        && prev.contains_point(slot.end)
    {
        (Bound::Included(prev), prev.end)
    } else {
        (Bound::Included(&end), slot.end)
    };
    // Drop every slot swallowed by the merged range, then insert the replacement.
    let to_remove: Vec<_> = slots.range((range_start, range_end)).copied().collect();
    for slot in to_remove {
        slots.remove(&slot);
    }
    slots.insert(Slot { start: replace_start, end: replace_end });
}