Skip to content

Commit f700fb2

Browse files
Rollup merge of #123349 - compiler-errors:async-closure-captures, r=oli-obk
Fix capture analysis for by-move closure bodies The check we were doing to figure out if a coroutine was borrowing from its parent coroutine-closure was flat-out wrong -- a misunderstanding of mine of the way that `tcx.closure_captures` represents its captures. Fixes #123251 (the miri/ui test I added should more than cover that issue) r? `@oli-obk` -- I recognize that this PR may be underdocumented, so please ask me what I should explain further.
2 parents abb0393 + ec74a30 commit f700fb2

File tree

5 files changed

+311
-31
lines changed

5 files changed

+311
-31
lines changed

compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

+118-31
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,66 @@
1-
//! A MIR pass which duplicates a coroutine's body and removes any derefs which
2-
//! would be present for upvars that are taken by-ref. The result of which will
3-
//! be a coroutine body that takes all of its upvars by-move, and which we stash
4-
//! into the `CoroutineInfo` for all coroutines returned by coroutine-closures.
1+
//! This pass constructs a second coroutine body sufficient for return from
2+
//! `FnOnce`/`AsyncFnOnce` implementations for coroutine-closures (e.g. async closures).
3+
//!
4+
//! Consider an async closure like:
5+
//! ```rust
6+
//! #![feature(async_closure)]
7+
//!
8+
//! let x = vec![1, 2, 3];
9+
//!
10+
//! let closure = async move || {
11+
//! println!("{x:#?}");
12+
//! };
13+
//! ```
14+
//!
15+
//! This desugars to something like:
16+
//! ```rust,ignore (invalid-borrowck)
17+
//! let x = vec![1, 2, 3];
18+
//!
19+
//! let closure = move || {
20+
//! async {
21+
//! println!("{x:#?}");
22+
//! }
23+
//! };
24+
//! ```
25+
//!
26+
//! Important to note here is that while the outer closure *moves* `x: Vec<i32>`
27+
//! into its upvars, the inner `async` coroutine simply captures a ref of `x`.
28+
//! This is the "magic" of async closures -- the futures that they return are
29+
//! allowed to borrow from their parent closure's upvars.
30+
//!
31+
//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`,
32+
//! since all async closures implement that too)? Well, recall the signature:
33+
//! ```
34+
//! use std::future::Future;
35+
//! pub trait AsyncFnOnce<Args>
36+
//! {
37+
//! type CallOnceFuture: Future<Output = Self::Output>;
38+
//! type Output;
39+
//! fn async_call_once(
40+
//! self,
41+
//! args: Args
42+
//! ) -> Self::CallOnceFuture;
43+
//! }
44+
//! ```
45+
//!
46+
//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`.
47+
//! How do we deal with the fact that the coroutine is supposed to take a reference
48+
//! to the captured `x` from the parent closure, when that parent closure has been
49+
//! destroyed?
50+
//!
51+
//! This is the second piece of magic of async closures. We can simply create a
52+
//! *second* `async` coroutine body where that `x` that was previously captured
53+
//! by reference is now captured by value. This means that we consume the outer
54+
//! closure and return a new coroutine that will hold onto all of these captures,
55+
//! and drop them when it is finished (i.e. after it has been `.await`ed).
56+
//!
57+
//! We do this with the analysis below, which detects the captures that come from
58+
//! borrowing from the outer closure, and we simply peel off a `deref` projection
59+
//! from them. This second body is stored alongside the first body, and optimized
60+
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
61+
//! we use this "by move" body instead.
62+
63+
use itertools::Itertools;
564

665
use rustc_data_structures::unord::UnordSet;
766
use rustc_hir as hir;
@@ -14,6 +73,8 @@ pub struct ByMoveBody;
1473

1574
impl<'tcx> MirPass<'tcx> for ByMoveBody {
1675
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
76+
// We only need to generate by-move coroutine bodies for coroutines that come
77+
// from coroutine-closures.
1778
let Some(coroutine_def_id) = body.source.def_id().as_local() else {
1879
return;
1980
};
@@ -22,44 +83,70 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
2283
else {
2384
return;
2485
};
86+
87+
// Also, let's skip processing any bodies with errors, since there's no guarantee
88+
// the MIR body will be constructed well.
2589
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
2690
if coroutine_ty.references_error() {
2791
return;
2892
}
29-
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
3093

31-
let coroutine_kind = args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
94+
let ty::Coroutine(_, coroutine_args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
95+
// We don't need to generate a by-move coroutine if the kind of the coroutine is
96+
// already `FnOnce` -- that means that any upvars that the closure consumes have
97+
// already been taken by-value.
98+
let coroutine_kind = coroutine_args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap();
3299
if coroutine_kind == ty::ClosureKind::FnOnce {
33100
return;
34101
}
35102

103+
let parent_def_id = tcx.local_parent(coroutine_def_id);
104+
let ty::CoroutineClosure(_, parent_args) =
105+
*tcx.type_of(parent_def_id).instantiate_identity().kind()
106+
else {
107+
bug!();
108+
};
109+
let parent_closure_args = parent_args.as_coroutine_closure();
110+
let num_args = parent_closure_args
111+
.coroutine_closure_sig()
112+
.skip_binder()
113+
.tupled_inputs_ty
114+
.tuple_fields()
115+
.len();
116+
36117
let mut by_ref_fields = UnordSet::default();
37-
let by_move_upvars = Ty::new_tup_from_iter(
38-
tcx,
39-
tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
40-
if capture.is_by_ref() {
41-
by_ref_fields.insert(FieldIdx::from_usize(idx));
42-
}
43-
capture.place.ty()
44-
}),
45-
);
46-
let by_move_coroutine_ty = Ty::new_coroutine(
47-
tcx,
48-
coroutine_def_id.to_def_id(),
49-
ty::CoroutineArgs::new(
118+
for (idx, (coroutine_capture, parent_capture)) in tcx
119+
.closure_captures(coroutine_def_id)
120+
.iter()
121+
// By construction we capture all the args first.
122+
.skip(num_args)
123+
.zip_eq(tcx.closure_captures(parent_def_id))
124+
.enumerate()
125+
{
126+
// This upvar is captured by-move from the parent closure, but by-ref
127+
// from the inner async block. That means that it's being borrowed from
128+
// the outer closure body -- we need to change the coroutine to take the
129+
// upvar by value.
130+
if coroutine_capture.is_by_ref() && !parent_capture.is_by_ref() {
131+
by_ref_fields.insert(FieldIdx::from_usize(num_args + idx));
132+
}
133+
134+
// Make sure we're actually talking about the same capture.
135+
// FIXME(async_closures): We could look at the `hir::Upvar` instead?
136+
assert_eq!(coroutine_capture.place.ty(), parent_capture.place.ty());
137+
}
138+
139+
let by_move_coroutine_ty = tcx
140+
.instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig())
141+
.to_coroutine_given_kind_and_upvars(
50142
tcx,
51-
ty::CoroutineArgsParts {
52-
parent_args: args.as_coroutine().parent_args(),
53-
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
54-
resume_ty: args.as_coroutine().resume_ty(),
55-
yield_ty: args.as_coroutine().yield_ty(),
56-
return_ty: args.as_coroutine().return_ty(),
57-
witness: args.as_coroutine().witness(),
58-
tupled_upvars_ty: by_move_upvars,
59-
},
60-
)
61-
.args,
62-
);
143+
parent_closure_args.parent_args(),
144+
coroutine_def_id.to_def_id(),
145+
ty::ClosureKind::FnOnce,
146+
tcx.lifetimes.re_erased,
147+
parent_closure_args.tupled_upvars_ty(),
148+
parent_closure_args.coroutine_captures_by_ref_ty(),
149+
);
63150

64151
let mut by_move_body = body.clone();
65152
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Same as rustc's `tests/ui/async-await/async-closures/captures.rs`, keep in sync
2+
3+
#![feature(async_closure, noop_waker)]
4+
5+
use std::future::Future;
6+
use std::pin::pin;
7+
use std::task::*;
8+
9+
pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
10+
let mut fut = pin!(fut);
11+
let ctx = &mut Context::from_waker(Waker::noop());
12+
13+
loop {
14+
match fut.as_mut().poll(ctx) {
15+
Poll::Pending => {}
16+
Poll::Ready(t) => break t,
17+
}
18+
}
19+
}
20+
21+
fn main() {
22+
block_on(async_main());
23+
}
24+
25+
async fn call<T>(f: &impl async Fn() -> T) -> T {
26+
f().await
27+
}
28+
29+
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
30+
f().await
31+
}
32+
33+
#[derive(Debug)]
34+
#[allow(unused)]
35+
struct Hello(i32);
36+
37+
async fn async_main() {
38+
// Capture something by-ref
39+
{
40+
let x = Hello(0);
41+
let c = async || {
42+
println!("{x:?}");
43+
};
44+
call(&c).await;
45+
call_once(c).await;
46+
47+
let x = &Hello(1);
48+
let c = async || {
49+
println!("{x:?}");
50+
};
51+
call(&c).await;
52+
call_once(c).await;
53+
}
54+
55+
// Capture something and consume it (force to `AsyncFnOnce`)
56+
{
57+
let x = Hello(2);
58+
let c = async || {
59+
println!("{x:?}");
60+
drop(x);
61+
};
62+
call_once(c).await;
63+
}
64+
65+
// Capture something with `move`, don't consume it
66+
{
67+
let x = Hello(3);
68+
let c = async move || {
69+
println!("{x:?}");
70+
};
71+
call(&c).await;
72+
call_once(c).await;
73+
74+
let x = &Hello(4);
75+
let c = async move || {
76+
println!("{x:?}");
77+
};
78+
call(&c).await;
79+
call_once(c).await;
80+
}
81+
82+
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
83+
{
84+
let x = Hello(5);
85+
let c = async move || {
86+
println!("{x:?}");
87+
drop(x);
88+
};
89+
call_once(c).await;
90+
}
91+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Hello(0)
2+
Hello(0)
3+
Hello(1)
4+
Hello(1)
5+
Hello(2)
6+
Hello(3)
7+
Hello(3)
8+
Hello(4)
9+
Hello(4)
10+
Hello(5)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//@ aux-build:block-on.rs
2+
//@ edition:2021
3+
//@ run-pass
4+
//@ check-run-results
5+
6+
// Same as miri's `tests/pass/async-closure-captures.rs`, keep in sync
7+
8+
#![feature(async_closure)]
9+
10+
extern crate block_on;
11+
12+
fn main() {
13+
block_on::block_on(async_main());
14+
}
15+
16+
async fn call<T>(f: &impl async Fn() -> T) -> T {
17+
f().await
18+
}
19+
20+
async fn call_once<T>(f: impl async FnOnce() -> T) -> T {
21+
f().await
22+
}
23+
24+
#[derive(Debug)]
25+
#[allow(unused)]
26+
struct Hello(i32);
27+
28+
async fn async_main() {
29+
// Capture something by-ref
30+
{
31+
let x = Hello(0);
32+
let c = async || {
33+
println!("{x:?}");
34+
};
35+
call(&c).await;
36+
call_once(c).await;
37+
38+
let x = &Hello(1);
39+
let c = async || {
40+
println!("{x:?}");
41+
};
42+
call(&c).await;
43+
call_once(c).await;
44+
}
45+
46+
// Capture something and consume it (force to `AsyncFnOnce`)
47+
{
48+
let x = Hello(2);
49+
let c = async || {
50+
println!("{x:?}");
51+
drop(x);
52+
};
53+
call_once(c).await;
54+
}
55+
56+
// Capture something with `move`, don't consume it
57+
{
58+
let x = Hello(3);
59+
let c = async move || {
60+
println!("{x:?}");
61+
};
62+
call(&c).await;
63+
call_once(c).await;
64+
65+
let x = &Hello(4);
66+
let c = async move || {
67+
println!("{x:?}");
68+
};
69+
call(&c).await;
70+
call_once(c).await;
71+
}
72+
73+
// Capture something with `move`, also consume it (so `AsyncFnOnce`)
74+
{
75+
let x = Hello(5);
76+
let c = async move || {
77+
println!("{x:?}");
78+
drop(x);
79+
};
80+
call_once(c).await;
81+
}
82+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Hello(0)
2+
Hello(0)
3+
Hello(1)
4+
Hello(1)
5+
Hello(2)
6+
Hello(3)
7+
Hello(3)
8+
Hello(4)
9+
Hello(4)
10+
Hello(5)

0 commit comments

Comments
 (0)