Skip to content

Move coroutine upvars into locals for better memory economy #135527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
4 changes: 2 additions & 2 deletions compiler/rustc_abi/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,15 @@ impl<Cx: HasDataLayout> LayoutCalculator<Cx> {
>(
&self,
local_layouts: &IndexSlice<LocalIdx, F>,
prefix_layouts: IndexVec<FieldIdx, F>,
relocated_upvars: &IndexSlice<LocalIdx, Option<LocalIdx>>,
variant_fields: &IndexSlice<VariantIdx, IndexVec<FieldIdx, LocalIdx>>,
storage_conflicts: &BitMatrix<LocalIdx, LocalIdx>,
tag_to_layout: impl Fn(Scalar) -> F,
) -> LayoutCalculatorResult<FieldIdx, VariantIdx, F> {
coroutine::layout(
self,
local_layouts,
prefix_layouts,
relocated_upvars,
variant_fields,
storage_conflicts,
tag_to_layout,
Expand Down
83 changes: 51 additions & 32 deletions compiler/rustc_abi/src/layout/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub(super) fn layout<
>(
calc: &super::LayoutCalculator<impl HasDataLayout>,
local_layouts: &IndexSlice<LocalIdx, F>,
mut prefix_layouts: IndexVec<FieldIdx, F>,
relocated_upvars: &IndexSlice<LocalIdx, Option<LocalIdx>>,
variant_fields: &IndexSlice<VariantIdx, IndexVec<FieldIdx, LocalIdx>>,
storage_conflicts: &BitMatrix<LocalIdx, LocalIdx>,
tag_to_layout: impl Fn(Scalar) -> F,
Expand All @@ -155,10 +155,8 @@ pub(super) fn layout<
let (ineligible_locals, assignments) =
coroutine_saved_local_eligibility(local_layouts.len(), variant_fields, storage_conflicts);

// Build a prefix layout, including "promoting" all ineligible
// locals as part of the prefix. We compute the layout of all of
// these fields at once to get optimal packing.
let tag_index = prefix_layouts.len();
// Build a prefix layout, consisting of only the state tag.
let tag_index = 0;

// `variant_fields` already accounts for the reserved variants, so no need to add them.
let max_discr = (variant_fields.len() - 1) as u128;
Expand All @@ -169,17 +167,17 @@ pub(super) fn layout<
};

let promoted_layouts = ineligible_locals.iter().map(|local| local_layouts[local]);
prefix_layouts.push(tag_to_layout(tag));
prefix_layouts.extend(promoted_layouts);
let prefix_layouts: IndexVec<_, _> =
[tag_to_layout(tag)].into_iter().chain(promoted_layouts).collect();
let prefix =
calc.univariant(&prefix_layouts, &ReprOptions::default(), StructKind::AlwaysSized)?;

let (prefix_size, prefix_align) = (prefix.size, prefix.align);

// Split the prefix layout into the "outer" fields (upvars and
// discriminant) and the "promoted" fields. Promoted fields will
// get included in each variant that requested them in
// CoroutineLayout.
// Split the prefix layout into the discriminant and
// the "promoted" fields.
// Promoted fields will get included in each variant
// that requested them in CoroutineLayout.
debug!("prefix = {:#?}", prefix);
let (outer_fields, promoted_offsets, promoted_memory_index) = match prefix.fields {
FieldsShape::Arbitrary { mut offsets, memory_index } => {
Expand Down Expand Up @@ -218,19 +216,45 @@ pub(super) fn layout<
let variants = variant_fields
.iter_enumerated()
.map(|(index, variant_fields)| {
let is_unresumed = index == VariantIdx::new(0);
let mut is_ineligible = IndexVec::from_elem_n(None, variant_fields.len());
for (field, &local) in variant_fields.iter_enumerated() {
if is_unresumed {
// NOTE(@dingxiangfei2009): rewrite this when let-chain #53667
// is stabilized
if let Some(inner_local) = relocated_upvars[local] {
if let Ineligible(Some(promoted_field)) = assignments[inner_local] {
is_ineligible.insert(field, promoted_field);
continue;
}
}
}
match assignments[local] {
Assigned(v) if v == index => {}
Ineligible(Some(promoted_field)) => {
is_ineligible.insert(field, promoted_field);
}
Ineligible(None) => {
panic!("an ineligible local should have been promoted into the prefix")
}
Assigned(_) => {
panic!("an eligible local should have been assigned to exactly one variant")
}
Unassigned => {
panic!("each saved local should have been inspected at least once")
}
}
}
// Only include overlap-eligible fields when we compute our variant layout.
let variant_only_tys = variant_fields
.iter()
.filter(|local| match assignments[**local] {
Unassigned => unreachable!(),
Assigned(v) if v == index => true,
Assigned(_) => unreachable!("assignment does not match variant"),
Ineligible(_) => false,
let fields: IndexVec<_, _> = variant_fields
.iter_enumerated()
.filter_map(|(field, &local)| {
if is_ineligible.contains(field) { None } else { Some(local_layouts[local]) }
})
.map(|local| local_layouts[*local]);
.collect();

let mut variant = calc.univariant(
&variant_only_tys.collect::<IndexVec<_, _>>(),
&fields,
&ReprOptions::default(),
StructKind::Prefixed(prefix_size, prefix_align.abi),
)?;
Expand All @@ -254,19 +278,14 @@ pub(super) fn layout<
IndexVec::from_elem_n(FieldIdx::new(invalid_field_idx), invalid_field_idx);

let mut offsets_and_memory_index = iter::zip(offsets, memory_index);
let combined_offsets = variant_fields
let combined_offsets = is_ineligible
.iter_enumerated()
.map(|(i, local)| {
let (offset, memory_index) = match assignments[*local] {
Unassigned => unreachable!(),
Assigned(_) => {
let (offset, memory_index) = offsets_and_memory_index.next().unwrap();
(offset, promoted_memory_index.len() as u32 + memory_index)
}
Ineligible(field_idx) => {
let field_idx = field_idx.unwrap();
(promoted_offsets[field_idx], promoted_memory_index[field_idx])
}
.map(|(i, &is_ineligible)| {
let (offset, memory_index) = if let Some(field_idx) = is_ineligible {
(promoted_offsets[field_idx], promoted_memory_index[field_idx])
} else {
let (offset, memory_index) = offsets_and_memory_index.next().unwrap();
(offset, promoted_memory_index.len() as u32 + memory_index)
};
combined_inverse_memory_index[memory_index] = i;
offset
Expand Down
105 changes: 63 additions & 42 deletions compiler/rustc_borrowck/src/diagnostics/mutability_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,49 +393,18 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
Place::ty_from(local, proj_base, self.body, self.infcx.tcx).ty
));

let captured_place = self.upvars[upvar_index.index()];

err.span_label(span, format!("cannot {act}"));

let upvar_hir_id = captured_place.get_root_variable();

if let Node::Pat(pat) = self.infcx.tcx.hir_node(upvar_hir_id)
&& let hir::PatKind::Binding(hir::BindingMode::NONE, _, upvar_ident, _) =
pat.kind
{
if upvar_ident.name == kw::SelfLower {
for (_, node) in self.infcx.tcx.hir_parent_iter(upvar_hir_id) {
if let Some(fn_decl) = node.fn_decl() {
if !matches!(
fn_decl.implicit_self,
hir::ImplicitSelfKind::RefImm | hir::ImplicitSelfKind::RefMut
) {
err.span_suggestion_verbose(
upvar_ident.span.shrink_to_lo(),
"consider changing this to be mutable",
"mut ",
Applicability::MachineApplicable,
);
break;
}
}
}
} else {
err.span_suggestion_verbose(
upvar_ident.span.shrink_to_lo(),
"consider changing this to be mutable",
"mut ",
Applicability::MachineApplicable,
);
}
}
self.suggest_mutable_upvar(*upvar_index, the_place_err, &mut err, span, act);
}

let tcx = self.infcx.tcx;
if let ty::Ref(_, ty, Mutability::Mut) = the_place_err.ty(self.body, tcx).ty.kind()
&& let ty::Closure(id, _) = *ty.kind()
{
self.show_mutating_upvar(tcx, id.expect_local(), the_place_err, &mut err);
}
PlaceRef { local, projection: [] }
if let Some(upvar_index) = self
.body
.local_upvar_map
.iter_enumerated()
.filter_map(|(field, &local)| local.map(|local| (field, local)))
.find_map(|(field, relocated)| (relocated == local).then_some(field)) =>
{
self.suggest_mutable_upvar(upvar_index, the_place_err, &mut err, span, act);
}

// complete hack to approximate old AST-borrowck
Expand Down Expand Up @@ -542,6 +511,58 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
}
}

fn suggest_mutable_upvar(
&self,
upvar_index: FieldIdx,
the_place_err: PlaceRef<'tcx>,
err: &mut Diag<'infcx>,
span: Span,
act: &str,
) {
let captured_place = self.upvars[upvar_index.index()];

err.span_label(span, format!("cannot {act}"));

let upvar_hir_id = captured_place.get_root_variable();

if let Node::Pat(pat) = self.infcx.tcx.hir_node(upvar_hir_id)
&& let hir::PatKind::Binding(hir::BindingMode::NONE, _, upvar_ident, _) = pat.kind
{
if upvar_ident.name == kw::SelfLower {
for (_, node) in self.infcx.tcx.hir_parent_iter(upvar_hir_id) {
if let Some(fn_decl) = node.fn_decl() {
if !matches!(
fn_decl.implicit_self,
hir::ImplicitSelfKind::RefImm | hir::ImplicitSelfKind::RefMut
) {
err.span_suggestion_verbose(
upvar_ident.span.shrink_to_lo(),
"consider changing this to be mutable",
"mut ",
Applicability::MachineApplicable,
);
break;
}
}
}
} else {
err.span_suggestion_verbose(
upvar_ident.span.shrink_to_lo(),
"consider changing this to be mutable",
"mut ",
Applicability::MachineApplicable,
);
}
}

let tcx = self.infcx.tcx;
if let ty::Ref(_, ty, Mutability::Mut) = the_place_err.ty(self.body, tcx).ty.kind()
&& let ty::Closure(id, _) = *ty.kind()
{
self.show_mutating_upvar(tcx, id.expect_local(), the_place_err, err);
}
}

/// Suggest `map[k] = v` => `map.insert(k, v)` and the like.
fn suggest_map_index_mut_alternatives(&self, ty: Ty<'tcx>, err: &mut Diag<'infcx>, span: Span) {
let Some(adt) = ty.ty_adt_def() else { return };
Expand Down
27 changes: 22 additions & 5 deletions compiler/rustc_borrowck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::ops::{ControlFlow, Deref};
use rustc_abi::FieldIdx;
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_data_structures::graph::dominators::Dominators;
use rustc_data_structures::unord::UnordMap;
use rustc_errors::LintDiagnostic;
use rustc_hir as hir;
use rustc_hir::CRATE_HIR_ID;
Expand Down Expand Up @@ -261,6 +262,7 @@ fn do_mir_borrowck<'tcx>(
regioncx: &regioncx,
used_mut: Default::default(),
used_mut_upvars: SmallVec::new(),
local_from_upvars: UnordMap::default(),
borrow_set: &borrow_set,
upvars: &[],
local_names: IndexVec::from_elem(None, &promoted_body.local_decls),
Expand All @@ -286,6 +288,11 @@ fn do_mir_borrowck<'tcx>(
promoted_mbcx.report_move_errors();
}

let mut local_from_upvars = UnordMap::default();
for (field, &local) in body.local_upvar_map.iter_enumerated() {
let Some(local) = local else { continue };
local_from_upvars.insert(local, field);
}
let mut mbcx = MirBorrowckCtxt {
infcx: &infcx,
body,
Expand All @@ -300,6 +307,7 @@ fn do_mir_borrowck<'tcx>(
regioncx: &regioncx,
used_mut: Default::default(),
used_mut_upvars: SmallVec::new(),
local_from_upvars,
borrow_set: &borrow_set,
upvars: tcx.closure_captures(def),
local_names,
Expand Down Expand Up @@ -555,6 +563,9 @@ struct MirBorrowckCtxt<'a, 'infcx, 'tcx> {
/// If the function we're checking is a closure, then we'll need to report back the list of
/// mutable upvars that have been used. This field keeps track of them.
used_mut_upvars: SmallVec<[FieldIdx; 8]>,
/// Since upvars are moved to real locals, we need to map mutations to the locals back to
/// the upvars, so that used_mut_upvars is up-to-date.
local_from_upvars: UnordMap<Local, FieldIdx>,
/// Region inference context. This contains the results from region inference and lets us e.g.
/// find out which CFG points are contained in each borrow region.
regioncx: &'a RegionInferenceContext<'tcx>,
Expand Down Expand Up @@ -2265,7 +2276,9 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, '_, 'tcx> {

// at this point, we have set up the error reporting state.
if let Some(init_index) = previously_initialized {
if let (AccessKind::Mutate, Some(_)) = (error_access, place.as_local()) {
if let (AccessKind::Mutate, Some(local)) = (error_access, place.as_local())
&& self.body.local_upvar_map.iter().flatten().all(|upvar| upvar != &local)
{
// If this is a mutate access to an immutable local variable with no projections
// report the error as an illegal reassignment
let init = &self.move_data.inits[init_index];
Expand Down Expand Up @@ -2293,10 +2306,12 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, '_, 'tcx> {
// If the local may have been initialized, and it is now currently being
// mutated, then it is justified to be annotated with the `mut`
// keyword, since the mutation may be a possible reassignment.
if is_local_mutation_allowed != LocalMutationIsAllowed::Yes
&& self.is_local_ever_initialized(local, state).is_some()
{
self.used_mut.insert(local);
if !matches!(is_local_mutation_allowed, LocalMutationIsAllowed::Yes) {
if self.is_local_ever_initialized(local, state).is_some() {
self.used_mut.insert(local);
} else if let Some(&field) = self.local_from_upvars.get(&local) {
self.used_mut_upvars.push(field);
}
}
}
RootPlace {
Expand All @@ -2314,6 +2329,8 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, '_, 'tcx> {
projection: place_projection,
}) {
self.used_mut_upvars.push(field);
} else if let Some(&field) = self.local_from_upvars.get(&place_local) {
self.used_mut_upvars.push(field);
}
}
}
Expand Down
23 changes: 12 additions & 11 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ use rustc_middle::traits::query::NoSolution;
use rustc_middle::ty::adjustment::PointerCoercion;
use rustc_middle::ty::cast::CastTy;
use rustc_middle::ty::{
self, Binder, CanonicalUserTypeAnnotation, CanonicalUserTypeAnnotations, CoroutineArgsExt,
Dynamic, GenericArgsRef, OpaqueHiddenType, OpaqueTypeKey, RegionVid, Ty, TyCtxt,
TypeVisitableExt, UserArgs, UserTypeAnnotationIndex, fold_regions,
self, Binder, CanonicalUserTypeAnnotation, CanonicalUserTypeAnnotations, Dynamic,
GenericArgsRef, OpaqueHiddenType, OpaqueTypeKey, RegionVid, Ty, TyCtxt, TypeVisitableExt,
UserArgs, UserTypeAnnotationIndex, fold_regions,
};
use rustc_middle::{bug, span_bug};
use rustc_mir_dataflow::ResultsCursor;
Expand Down Expand Up @@ -2188,14 +2188,15 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
AggregateKind::Coroutine(_, args) => {
// It doesn't make sense to look at a field beyond the prefix;
// these require a variant index, and are not initialized in
// aggregate rvalues.
match args.as_coroutine().prefix_tys().get(field_index.as_usize()) {
Some(ty) => Ok(*ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine().prefix_tys().len(),
}),
// It doesn't make sense to look at a field beyond the captured
// upvars.
// Otherwise it require a variant index, and are not initialized
// in aggregate rvalues.
let upvar_tys = &args.as_coroutine().upvar_tys();
if let Some(ty) = upvar_tys.get(field_index.as_usize()) {
Ok(*ty)
} else {
Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() })
}
}
AggregateKind::CoroutineClosure(_, args) => {
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,9 @@ fn codegen_stmt<'tcx>(
let variant_dest = lval.downcast_variant(fx, variant_index);
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_def_id, _args) => {
(FIRST_VARIANT, lval.downcast_variant(fx, FIRST_VARIANT), None)
}
_ => (FIRST_VARIANT, lval, None),
};
if active_field_index.is_some() {
Expand Down
Loading
Loading