Skip to content

Commit b802668

Browse files
relocate upvars to Unresumed state and make coroutine prefix trivial
Co-authored-by: Dario Nieuwenhuis <[email protected]>
1 parent 51917ba commit b802668

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+954
-402
lines changed

compiler/rustc_borrowck/src/lib.rs

+21-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ extern crate tracing;
1919

2020
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
2121
use rustc_data_structures::graph::dominators::Dominators;
22+
use rustc_data_structures::unord::UnordMap;
2223
use rustc_errors::Diag;
2324
use rustc_hir as hir;
2425
use rustc_hir::def_id::LocalDefId;
@@ -296,6 +297,7 @@ fn do_mir_borrowck<'tcx>(
296297
regioncx: regioncx.clone(),
297298
used_mut: Default::default(),
298299
used_mut_upvars: SmallVec::new(),
300+
local_from_upvars: UnordMap::default(),
299301
borrow_set: Rc::clone(&borrow_set),
300302
upvars: &[],
301303
local_names: IndexVec::from_elem(None, &promoted_body.local_decls),
@@ -322,6 +324,12 @@ fn do_mir_borrowck<'tcx>(
322324
}
323325
}
324326

327+
let mut local_from_upvars = UnordMap::default();
328+
for (field, &local) in body.local_upvar_map.iter_enumerated() {
329+
let Some(local) = local else { continue };
330+
local_from_upvars.insert(local, field);
331+
}
332+
debug!(?local_from_upvars, "dxf");
325333
let mut mbcx = MirBorrowckCtxt {
326334
infcx: &infcx,
327335
param_env,
@@ -337,6 +345,7 @@ fn do_mir_borrowck<'tcx>(
337345
regioncx: Rc::clone(&regioncx),
338346
used_mut: Default::default(),
339347
used_mut_upvars: SmallVec::new(),
348+
local_from_upvars,
340349
borrow_set: Rc::clone(&borrow_set),
341350
upvars: tcx.closure_captures(def),
342351
local_names,
@@ -572,6 +581,9 @@ struct MirBorrowckCtxt<'a, 'mir, 'infcx, 'tcx> {
572581
/// If the function we're checking is a closure, then we'll need to report back the list of
573582
/// mutable upvars that have been used. This field keeps track of them.
574583
used_mut_upvars: SmallVec<[FieldIdx; 8]>,
584+
/// Since upvars are moved to real locals, we need to map mutations to the locals back to
585+
/// the upvars, so that used_mut_upvars is up-to-date.
586+
local_from_upvars: UnordMap<Local, FieldIdx>,
575587
/// Region inference context. This contains the results from region inference and lets us e.g.
576588
/// find out which CFG points are contained in each borrow region.
577589
regioncx: Rc<RegionInferenceContext<'tcx>>,
@@ -2227,16 +2239,19 @@ impl<'mir, 'tcx> MirBorrowckCtxt<'_, 'mir, '_, 'tcx> {
22272239
}
22282240

22292241
/// Adds the place into the used mutable variables set
2242+
#[instrument(level = "debug", skip(self, flow_state))]
22302243
fn add_used_mut(&mut self, root_place: RootPlace<'tcx>, flow_state: &Flows<'_, 'mir, 'tcx>) {
22312244
match root_place {
22322245
RootPlace { place_local: local, place_projection: [], is_local_mutation_allowed } => {
22332246
// If the local may have been initialized, and it is now currently being
22342247
// mutated, then it is justified to be annotated with the `mut`
22352248
// keyword, since the mutation may be a possible reassignment.
2236-
if is_local_mutation_allowed != LocalMutationIsAllowed::Yes
2237-
&& self.is_local_ever_initialized(local, flow_state).is_some()
2238-
{
2239-
self.used_mut.insert(local);
2249+
if !matches!(is_local_mutation_allowed, LocalMutationIsAllowed::Yes) {
2250+
if self.is_local_ever_initialized(local, flow_state).is_some() {
2251+
self.used_mut.insert(local);
2252+
} else if let Some(&field) = self.local_from_upvars.get(&local) {
2253+
self.used_mut_upvars.push(field);
2254+
}
22402255
}
22412256
}
22422257
RootPlace {
@@ -2254,6 +2269,8 @@ impl<'mir, 'tcx> MirBorrowckCtxt<'_, 'mir, '_, 'tcx> {
22542269
projection: place_projection,
22552270
}) {
22562271
self.used_mut_upvars.push(field);
2272+
} else if let Some(&field) = self.local_from_upvars.get(&place_local) {
2273+
self.used_mut_upvars.push(field);
22572274
}
22582275
}
22592276
}

compiler/rustc_borrowck/src/type_check/mod.rs

+12-12
Original file line numberDiff line numberDiff line change
@@ -816,15 +816,15 @@ impl<'a, 'b, 'tcx> TypeVerifier<'a, 'b, 'tcx> {
816816
}),
817817
};
818818
}
819-
ty::Coroutine(_, args) => {
819+
ty::Coroutine(_def_id, args) => {
820820
// Only prefix fields (upvars and current state) are
821821
// accessible without a variant index.
822-
return match args.as_coroutine().prefix_tys().get(field.index()) {
823-
Some(ty) => Ok(*ty),
824-
None => Err(FieldAccessError::OutOfRange {
825-
field_count: args.as_coroutine().prefix_tys().len(),
826-
}),
827-
};
822+
let upvar_tys = args.as_coroutine().upvar_tys();
823+
if let Some(ty) = upvar_tys.get(field.index()) {
824+
return Ok(*ty);
825+
} else {
826+
return Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() });
827+
}
828828
}
829829
ty::Tuple(tys) => {
830830
return match tys.get(field.index()) {
@@ -1828,11 +1828,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
18281828
// It doesn't make sense to look at a field beyond the prefix;
18291829
// these require a variant index, and are not initialized in
18301830
// aggregate rvalues.
1831-
match args.as_coroutine().prefix_tys().get(field_index.as_usize()) {
1832-
Some(ty) => Ok(*ty),
1833-
None => Err(FieldAccessError::OutOfRange {
1834-
field_count: args.as_coroutine().prefix_tys().len(),
1835-
}),
1831+
let upvar_tys = &args.as_coroutine().upvar_tys();
1832+
if let Some(ty) = upvar_tys.get(field_index.as_usize()) {
1833+
Ok(*ty)
1834+
} else {
1835+
Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() })
18361836
}
18371837
}
18381838
AggregateKind::CoroutineClosure(_, args) => {

compiler/rustc_codegen_cranelift/src/base.rs

+3
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,9 @@ fn codegen_stmt<'tcx>(
875875
let variant_dest = lval.downcast_variant(fx, variant_index);
876876
(variant_index, variant_dest, active_field_index)
877877
}
878+
mir::AggregateKind::Coroutine(_def_id, _args) => {
879+
(FIRST_VARIANT, lval.downcast_variant(fx, FIRST_VARIANT), None)
880+
}
878881
_ => (FIRST_VARIANT, lval, None),
879882
};
880883
if active_field_index.is_some() {

compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ use rustc_hir::def_id::{DefId, LOCAL_CRATE};
3131
use rustc_middle::bug;
3232
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
3333
use rustc_middle::ty::{
34-
self, AdtKind, CoroutineArgsExt, Instance, ParamEnv, PolyExistentialTraitRef, Ty, TyCtxt,
35-
Visibility,
34+
self, AdtKind, Instance, ParamEnv, PolyExistentialTraitRef, Ty, TyCtxt, Visibility,
3635
};
3736
use rustc_session::config::{self, DebugInfo, Lto};
3837
use rustc_span::symbol::Symbol;
@@ -1082,7 +1081,7 @@ fn build_upvar_field_di_nodes<'ll, 'tcx>(
10821081
closure_or_coroutine_di_node: &'ll DIType,
10831082
) -> SmallVec<&'ll DIType> {
10841083
let (&def_id, up_var_tys) = match closure_or_coroutine_ty.kind() {
1085-
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().prefix_tys()),
1084+
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().upvar_tys()),
10861085
ty::Closure(def_id, args) => (def_id, args.as_closure().upvar_tys()),
10871086
ty::CoroutineClosure(def_id, args) => (def_id, args.as_coroutine_closure().upvar_tys()),
10881087
_ => {

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/cpp_like.rs

-2
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
686686
let coroutine_layout =
687687
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();
688688

689-
let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
690689
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
691690
let variant_count = (variant_range.start.as_u32()..variant_range.end.as_u32()).len();
692691

@@ -721,7 +720,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
721720
coroutine_type_and_layout,
722721
coroutine_type_di_node,
723722
coroutine_layout,
724-
common_upvar_names,
725723
);
726724

727725
let span = coroutine_layout.variant_source_info[variant_index].span;

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs

+2-31
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use rustc_codegen_ssa::debuginfo::{
33
wants_c_like_enum_debuginfo,
44
};
55
use rustc_hir::def::CtorKind;
6-
use rustc_index::IndexSlice;
76
use rustc_middle::{
87
bug,
98
mir::CoroutineLayout,
@@ -13,7 +12,6 @@ use rustc_middle::{
1312
AdtDef, CoroutineArgs, CoroutineArgsExt, Ty, VariantDef,
1413
},
1514
};
16-
use rustc_span::Symbol;
1715
use rustc_target::abi::{
1816
FieldIdx, HasDataLayout, Integer, Primitive, TagEncoding, VariantIdx, Variants,
1917
};
@@ -323,7 +321,6 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
323321
coroutine_type_and_layout: TyAndLayout<'tcx>,
324322
coroutine_type_di_node: &'ll DIType,
325323
coroutine_layout: &CoroutineLayout<'tcx>,
326-
common_upvar_names: &IndexSlice<FieldIdx, Symbol>,
327324
) -> &'ll DIType {
328325
let variant_name = CoroutineArgs::variant_name(variant_index);
329326
let unique_type_id = UniqueTypeId::for_enum_variant_struct_type(
@@ -334,11 +331,6 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
334331

335332
let variant_layout = coroutine_type_and_layout.for_variant(cx, variant_index);
336333

337-
let coroutine_args = match coroutine_type_and_layout.ty.kind() {
338-
ty::Coroutine(_, args) => args.as_coroutine(),
339-
_ => unreachable!(),
340-
};
341-
342334
type_map::build_type_with_children(
343335
cx,
344336
type_map::stub(
@@ -352,7 +344,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
352344
),
353345
|cx, variant_struct_type_di_node| {
354346
// Fields that just belong to this variant/state
355-
let state_specific_fields: SmallVec<_> = (0..variant_layout.fields.count())
347+
(0..variant_layout.fields.count())
356348
.map(|field_index| {
357349
let coroutine_saved_local = coroutine_layout.variant_fields[variant_index]
358350
[FieldIdx::from_usize(field_index)];
@@ -374,28 +366,7 @@ pub fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
374366
type_di_node(cx, field_type),
375367
)
376368
})
377-
.collect();
378-
379-
// Fields that are common to all states
380-
let common_fields: SmallVec<_> = coroutine_args
381-
.prefix_tys()
382-
.iter()
383-
.zip(common_upvar_names)
384-
.enumerate()
385-
.map(|(index, (upvar_ty, upvar_name))| {
386-
build_field_di_node(
387-
cx,
388-
variant_struct_type_di_node,
389-
upvar_name.as_str(),
390-
cx.size_and_align_of(upvar_ty),
391-
coroutine_type_and_layout.fields.offset(index),
392-
DIFlags::FlagZero,
393-
type_di_node(cx, upvar_ty),
394-
)
395-
})
396-
.collect();
397-
398-
state_specific_fields.into_iter().chain(common_fields).collect()
369+
.collect()
399370
},
400371
|cx| build_generic_type_param_di_nodes(cx, coroutine_type_and_layout.ty),
401372
)

compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/native.rs

-4
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
172172
)
173173
};
174174

175-
let common_upvar_names =
176-
cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
177-
178175
// Build variant struct types
179176
let variant_struct_type_di_nodes: SmallVec<_> = variants
180177
.indices()
@@ -202,7 +199,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
202199
coroutine_type_and_layout,
203200
coroutine_type_di_node,
204201
coroutine_layout,
205-
common_upvar_names,
206202
),
207203
source_info,
208204
}

compiler/rustc_codegen_ssa/src/mir/operand.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use std::fmt;
1616

1717
use arrayvec::ArrayVec;
1818
use either::Either;
19-
use tracing::debug;
19+
use tracing::{debug, instrument};
2020

2121
/// The representation of a Rust value. The enum variant is in fact
2222
/// uniquely determined by the value's type, but is kept as a
@@ -552,13 +552,12 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {
552552
}
553553

554554
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
555+
#[instrument(level = "debug", skip(self, bx), ret)]
555556
fn maybe_codegen_consume_direct(
556557
&mut self,
557558
bx: &mut Bx,
558559
place_ref: mir::PlaceRef<'tcx>,
559560
) -> Option<OperandRef<'tcx, Bx::Value>> {
560-
debug!("maybe_codegen_consume_direct(place_ref={:?})", place_ref);
561-
562561
match self.locals[place_ref.local] {
563562
LocalRef::Operand(mut o) => {
564563
// Moves out of scalar and scalar pair fields are trivial.
@@ -601,13 +600,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
601600
}
602601
}
603602

603+
#[instrument(level = "debug", skip(self, bx), ret)]
604604
pub fn codegen_consume(
605605
&mut self,
606606
bx: &mut Bx,
607607
place_ref: mir::PlaceRef<'tcx>,
608608
) -> OperandRef<'tcx, Bx::Value> {
609-
debug!("codegen_consume(place_ref={:?})", place_ref);
610-
611609
let ty = self.monomorphized_place_ty(place_ref);
612610
let layout = bx.cx().layout_of(ty);
613611

@@ -626,13 +624,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
626624
bx.load_operand(place)
627625
}
628626

627+
#[instrument(level = "debug", skip(self, bx), ret)]
629628
pub fn codegen_operand(
630629
&mut self,
631630
bx: &mut Bx,
632631
operand: &mir::Operand<'tcx>,
633632
) -> OperandRef<'tcx, Bx::Value> {
634-
debug!("codegen_operand(operand={:?})", operand);
635-
636633
match *operand {
637634
mir::Operand::Copy(ref place) | mir::Operand::Move(ref place) => {
638635
self.codegen_consume(bx, place.as_ref())

compiler/rustc_codegen_ssa/src/mir/rvalue.rs

+3
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
131131
let variant_dest = dest.project_downcast(bx, variant_index);
132132
(variant_index, variant_dest, active_field_index)
133133
}
134+
mir::AggregateKind::Coroutine(_, _) => {
135+
(FIRST_VARIANT, dest.project_downcast(bx, FIRST_VARIANT), None)
136+
}
134137
_ => (FIRST_VARIANT, dest, None),
135138
};
136139
if active_field_index.is_some() {

compiler/rustc_const_eval/src/interpret/step.rs

+3
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
304304
let variant_dest = self.project_downcast(dest, variant_index)?;
305305
(variant_index, variant_dest, active_field_index)
306306
}
307+
mir::AggregateKind::Coroutine(_def_id, _args) => {
308+
(FIRST_VARIANT, self.project_downcast(dest, FIRST_VARIANT)?, None)
309+
}
307310
mir::AggregateKind::RawPtr(..) => {
308311
// Pointers don't have "fields" in the normal sense, so the
309312
// projection-based code below would either fail in projection

compiler/rustc_middle/src/mir/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,9 @@ pub struct Body<'tcx> {
443443
/// If `-Cinstrument-coverage` is not active, or if an individual function
444444
/// is not eligible for coverage, then this should always be `None`.
445445
pub function_coverage_info: Option<Box<coverage::FunctionCoverageInfo>>,
446+
447+
/// Coroutine local-upvar map
448+
pub local_upvar_map: IndexVec<FieldIdx, Option<Local>>,
446449
}
447450

448451
impl<'tcx> Body<'tcx> {
@@ -486,6 +489,7 @@ impl<'tcx> Body<'tcx> {
486489
tainted_by_errors,
487490
coverage_branch_info: None,
488491
function_coverage_info: None,
492+
local_upvar_map: IndexVec::new(),
489493
};
490494
body.is_polymorphic = body.has_non_region_param();
491495
body
@@ -517,6 +521,7 @@ impl<'tcx> Body<'tcx> {
517521
tainted_by_errors: None,
518522
coverage_branch_info: None,
519523
function_coverage_info: None,
524+
local_upvar_map: IndexVec::new(),
520525
};
521526
body.is_polymorphic = body.has_non_region_param();
522527
body

compiler/rustc_middle/src/mir/patch.rs

+4
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,14 @@ impl<'tcx> MirPatch<'tcx> {
155155
ty: Ty<'tcx>,
156156
span: Span,
157157
local_info: LocalInfo<'tcx>,
158+
immutable: bool,
158159
) -> Local {
159160
let index = self.next_local;
160161
self.next_local += 1;
161162
let mut new_decl = LocalDecl::new(ty, span);
163+
if immutable {
164+
new_decl = new_decl.immutable();
165+
}
162166
**new_decl.local_info.as_mut().assert_crate_local() = local_info;
163167
self.new_locals.push(new_decl);
164168
Local::new(index)

compiler/rustc_middle/src/ty/layout.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -906,9 +906,10 @@ where
906906
),
907907
Variants::Multiple { tag, tag_field, .. } => {
908908
if i == tag_field {
909-
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
909+
TyMaybeWithLayout::TyAndLayout(tag_layout(tag))
910+
} else {
911+
TyMaybeWithLayout::Ty(args.as_coroutine().upvar_tys()[i])
910912
}
911-
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
912913
}
913914
},
914915

compiler/rustc_middle/src/ty/sty.rs

-7
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,6 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
149149
})
150150
})
151151
}
152-
153-
/// This is the types of the fields of a coroutine which are not stored in a
154-
/// variant.
155-
#[inline]
156-
fn prefix_tys(self) -> &'tcx List<Ty<'tcx>> {
157-
self.upvar_tys()
158-
}
159152
}
160153

161154
#[derive(Debug, Copy, Clone, HashStable, TypeFoldable, TypeVisitable)]

0 commit comments

Comments
 (0)