@@ -6,14 +6,15 @@ use rustc_hir::def_id::DefId;
6
6
use rustc_hir:: lang_items:: LangItem ;
7
7
use rustc_index:: { Idx , IndexVec } ;
8
8
use rustc_middle:: mir:: patch:: MirPatch ;
9
+ use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext } ;
9
10
use rustc_middle:: mir:: * ;
10
11
use rustc_middle:: query:: Providers ;
11
12
use rustc_middle:: ty:: {
12
13
self , CoroutineArgs , CoroutineArgsExt , EarlyBinder , GenericArgs , Ty , TyCtxt ,
13
14
} ;
14
15
use rustc_middle:: { bug, span_bug} ;
15
16
use rustc_mir_dataflow:: elaborate_drops:: { self , DropElaborator , DropFlagMode , DropStyle } ;
16
- use rustc_span:: source_map:: Spanned ;
17
+ use rustc_span:: source_map:: { dummy_spanned , Spanned } ;
17
18
use rustc_span:: { Span , DUMMY_SP } ;
18
19
use rustc_target:: abi:: { FieldIdx , VariantIdx , FIRST_VARIANT } ;
19
20
use rustc_target:: spec:: abi:: Abi ;
@@ -24,10 +25,46 @@ use crate::{
24
25
instsimplify, mentioned_items, pass_manager as pm, remove_noop_landing_pads, simplify,
25
26
} ;
26
27
28
+ mod async_destructor_ctor;
29
+
27
30
pub fn provide ( providers : & mut Providers ) {
28
31
providers. mir_shims = make_shim;
29
32
}
30
33
34
+ // Replace Pin<&mut ImplCoroutine> accesses (_1.0) into Pin<&mut ProxyCoroutine> acceses
35
+ struct FixProxyFutureDropVisitor < ' tcx > {
36
+ tcx : TyCtxt < ' tcx > ,
37
+ replace_to : Local ,
38
+ }
39
+
40
+ impl < ' tcx > MutVisitor < ' tcx > for FixProxyFutureDropVisitor < ' tcx > {
41
+ fn tcx ( & self ) -> TyCtxt < ' tcx > {
42
+ self . tcx
43
+ }
44
+
45
+ fn visit_place (
46
+ & mut self ,
47
+ place : & mut Place < ' tcx > ,
48
+ _context : PlaceContext ,
49
+ _location : Location ,
50
+ ) {
51
+ if place. local == Local :: from_u32 ( 1 ) {
52
+ if place. projection . len ( ) == 1 {
53
+ assert ! ( matches!(
54
+ place. projection. first( ) ,
55
+ Some ( ProjectionElem :: Field ( FieldIdx :: ZERO , _) )
56
+ ) ) ;
57
+ * place = Place :: from ( self . replace_to ) ;
58
+ } else if place. projection . len ( ) == 2 {
59
+ assert ! ( matches!( place. projection[ 0 ] , ProjectionElem :: Field ( FieldIdx :: ZERO , _) ) ) ;
60
+ assert ! ( matches!( place. projection[ 1 ] , ProjectionElem :: Deref ) ) ;
61
+ * place =
62
+ Place :: from ( self . replace_to ) . project_deeper ( & [ ProjectionElem :: Deref ] , self . tcx ) ;
63
+ }
64
+ }
65
+ }
66
+ }
67
+
31
68
fn make_shim < ' tcx > ( tcx : TyCtxt < ' tcx > , instance : ty:: InstanceKind < ' tcx > ) -> Body < ' tcx > {
32
69
debug ! ( "make_shim({:?})" , instance) ;
33
70
@@ -127,14 +164,218 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceKind<'tcx>) -> Body<
127
164
ty:: InstanceKind :: ThreadLocalShim ( ..) => build_thread_local_shim ( tcx, instance) ,
128
165
ty:: InstanceKind :: CloneShim ( def_id, ty) => build_clone_shim ( tcx, def_id, ty) ,
129
166
ty:: InstanceKind :: FnPtrAddrShim ( def_id, ty) => build_fn_ptr_addr_shim ( tcx, def_id, ty) ,
130
- ty:: InstanceKind :: FutureDropPollShim ( _def_id, _proxy_ty, _impl_ty) => {
131
- todo ! ( )
167
+ ty:: InstanceKind :: FutureDropPollShim ( def_id, proxy_ty, impl_ty) => {
168
+ let ty:: Coroutine ( coroutine_def_id, impl_args) = impl_ty. kind ( ) else {
169
+ bug ! ( "FutureDropPollShim not for coroutine impl type: ({:?})" , instance) ;
170
+ } ;
171
+
172
+ let span = tcx. def_span ( def_id) ;
173
+ let source_info = SourceInfo :: outermost ( span) ;
174
+
175
+ let pin_proxy_layout_local = Local :: new ( 1 ) ;
176
+ let cor_ref = Ty :: new_mut_ref ( tcx, tcx. lifetimes . re_erased , impl_ty) ;
177
+ let proxy_ref = Ty :: new_mut_ref ( tcx, tcx. lifetimes . re_erased , proxy_ty) ;
178
+ // taking _1.0.0 (impl from Pin, impl from proxy)
179
+ let proxy_ref_place = Place :: from ( pin_proxy_layout_local)
180
+ . project_deeper ( & [ PlaceElem :: Field ( FieldIdx :: ZERO , proxy_ref) ] , tcx) ;
181
+ let impl_ref_place = |proxy_ref_local : Local | {
182
+ Place :: from ( proxy_ref_local) . project_deeper (
183
+ & [
184
+ PlaceElem :: Deref ,
185
+ PlaceElem :: Downcast ( None , VariantIdx :: ZERO ) ,
186
+ PlaceElem :: Field ( FieldIdx :: ZERO , cor_ref) ,
187
+ ] ,
188
+ tcx,
189
+ )
190
+ } ;
191
+
192
+ if tcx. is_templated_coroutine ( * coroutine_def_id) {
193
+ // ret_ty = `Poll<()>`
194
+ let poll_adt_ref = tcx. adt_def ( tcx. require_lang_item ( LangItem :: Poll , None ) ) ;
195
+ let ret_ty = Ty :: new_adt ( tcx, poll_adt_ref, tcx. mk_args ( & [ tcx. types . unit . into ( ) ] ) ) ;
196
+ // env_ty = `Pin<&mut proxy_ty>`
197
+ let pin_adt_ref = tcx. adt_def ( tcx. require_lang_item ( LangItem :: Pin , None ) ) ;
198
+ let env_ty = Ty :: new_adt ( tcx, pin_adt_ref, tcx. mk_args ( & [ proxy_ref. into ( ) ] ) ) ;
199
+ // sig = `fn (Pin<&mut proxy_ty>, &mut Context) -> Poll<()>`
200
+ let sig = tcx. mk_fn_sig (
201
+ [ env_ty, Ty :: new_task_context ( tcx) ] ,
202
+ ret_ty,
203
+ false ,
204
+ hir:: Safety :: Safe ,
205
+ rustc_target:: spec:: abi:: Abi :: Rust ,
206
+ ) ;
207
+ let mut locals = local_decls_for_sig ( & sig, span) ;
208
+ let mut blocks = IndexVec :: with_capacity ( 3 ) ;
209
+
210
+ let proxy_ref_local = locals. push ( LocalDecl :: new ( proxy_ref, span) ) ;
211
+ let cor_ref_local = locals. push ( LocalDecl :: new ( cor_ref, span) ) ;
212
+ let cor_ref_place = Place :: from ( cor_ref_local) ;
213
+
214
+ let call_bb = BasicBlock :: new ( 1 ) ;
215
+ let return_bb = BasicBlock :: new ( 2 ) ;
216
+
217
+ let assign1 = Statement {
218
+ source_info,
219
+ kind : StatementKind :: Assign ( Box :: new ( (
220
+ Place :: from ( proxy_ref_local) ,
221
+ Rvalue :: CopyForDeref ( proxy_ref_place) ,
222
+ ) ) ) ,
223
+ } ;
224
+ let assign2 = Statement {
225
+ source_info,
226
+ kind : StatementKind :: Assign ( Box :: new ( (
227
+ cor_ref_place,
228
+ Rvalue :: CopyForDeref ( impl_ref_place ( proxy_ref_local) ) ,
229
+ ) ) ) ,
230
+ } ;
231
+
232
+ // cor_pin_ty = `Pin<&mut cor_ref>`
233
+ let cor_pin_ty = Ty :: new_adt ( tcx, pin_adt_ref, tcx. mk_args ( & [ cor_ref. into ( ) ] ) ) ;
234
+ let cor_pin_place = Place :: from ( locals. push ( LocalDecl :: new ( cor_pin_ty, span) ) ) ;
235
+
236
+ let pin_fn = tcx. require_lang_item ( LangItem :: PinNewUnchecked , Some ( span) ) ;
237
+ // call Pin<FutTy>::new_unchecked(&mut impl_cor)
238
+ blocks. push ( BasicBlockData {
239
+ statements : vec ! [ assign1, assign2] ,
240
+ terminator : Some ( Terminator {
241
+ source_info,
242
+ kind : TerminatorKind :: Call {
243
+ func : Operand :: function_handle ( tcx, pin_fn, [ cor_ref. into ( ) ] , span) ,
244
+ args : [ dummy_spanned ( Operand :: Move ( cor_ref_place) ) ] . into ( ) ,
245
+ destination : cor_pin_place,
246
+ target : Some ( call_bb) ,
247
+ unwind : UnwindAction :: Continue ,
248
+ call_source : CallSource :: Misc ,
249
+ fn_span : span,
250
+ } ,
251
+ } ) ,
252
+ is_cleanup : false ,
253
+ } ) ;
254
+ // When dropping async drop coroutine, we continue its execution:
255
+ // we call impl::poll (impl_layout, ctx)
256
+ let poll_fn = tcx. require_lang_item ( LangItem :: FuturePoll , None ) ;
257
+ let resume_ctx = Place :: from ( Local :: new ( 2 ) ) ;
258
+ blocks. push ( BasicBlockData {
259
+ statements : vec ! [ ] ,
260
+ terminator : Some ( Terminator {
261
+ source_info,
262
+ kind : TerminatorKind :: Call {
263
+ func : Operand :: function_handle ( tcx, poll_fn, [ impl_ty. into ( ) ] , span) ,
264
+ args : [
265
+ dummy_spanned ( Operand :: Move ( cor_pin_place) ) ,
266
+ dummy_spanned ( Operand :: Move ( resume_ctx) ) ,
267
+ ]
268
+ . into ( ) ,
269
+ destination : Place :: return_place ( ) ,
270
+ target : Some ( return_bb) ,
271
+ unwind : UnwindAction :: Continue ,
272
+ call_source : CallSource :: Misc ,
273
+ fn_span : span,
274
+ } ,
275
+ } ) ,
276
+ is_cleanup : false ,
277
+ } ) ;
278
+ blocks. push ( BasicBlockData {
279
+ statements : vec ! [ ] ,
280
+ terminator : Some ( Terminator { source_info, kind : TerminatorKind :: Return } ) ,
281
+ is_cleanup : false ,
282
+ } ) ;
283
+
284
+ let source = MirSource :: from_instance ( instance) ;
285
+ let mut body = new_body ( source, blocks, locals, sig. inputs ( ) . len ( ) , span) ;
286
+ pm:: run_passes (
287
+ tcx,
288
+ & mut body,
289
+ & [
290
+ & mentioned_items:: MentionedItems ,
291
+ & abort_unwinding_calls:: AbortUnwindingCalls ,
292
+ & add_call_guards:: CriticalCallEdges ,
293
+ ] ,
294
+ Some ( MirPhase :: Runtime ( RuntimePhase :: Optimized ) ) ,
295
+ ) ;
296
+ return body;
297
+ }
298
+ // future drop poll for async drop must be resolved to standart poll (AsyncDropGlue)
299
+ assert ! ( !tcx. is_templated_coroutine( * coroutine_def_id) ) ;
300
+
301
+ // converting `(_1: Pin<&mut CorLayout>, _2: &mut Context<'_>) -> Poll<()>`
302
+ // into `(_1: Pin<&mut ProxyLayout>, _2: &mut Context<'_>) -> Poll<()>`
303
+ // let mut _x: &mut CorLayout = &*_1.0.0;
304
+ // Replace old _1.0 accesses into _x accesses;
305
+ let body = tcx. optimized_mir ( * coroutine_def_id) . future_drop_poll ( ) . unwrap ( ) ;
306
+ let mut body: Body < ' tcx > = EarlyBinder :: bind ( body. clone ( ) ) . instantiate ( tcx, impl_args) ;
307
+ body. source . instance = instance;
308
+ body. var_debug_info . clear ( ) ;
309
+ let pin_adt_ref = tcx. adt_def ( tcx. require_lang_item ( LangItem :: Pin , Some ( span) ) ) ;
310
+ let args = tcx. mk_args ( & [ proxy_ref. into ( ) ] ) ;
311
+ let pin_proxy_ref = Ty :: new_adt ( tcx, pin_adt_ref, args) ;
312
+
313
+ let proxy_ref_local = body. local_decls . push ( LocalDecl :: new ( proxy_ref, span) ) ;
314
+ let cor_ref_local = body. local_decls . push ( LocalDecl :: new ( cor_ref, span) ) ;
315
+ FixProxyFutureDropVisitor { tcx, replace_to : cor_ref_local } . visit_body ( & mut body) ;
316
+ // Now changing first arg from Pin<&mut ImplCoroutine> to Pin<&mut ProxyCoroutine>
317
+ body. local_decls [ pin_proxy_layout_local] = LocalDecl :: new ( pin_proxy_ref, span) ;
318
+
319
+ {
320
+ let bb: & mut BasicBlockData < ' tcx > = & mut body. basic_blocks_mut ( ) [ START_BLOCK ] ;
321
+ // _tmp = _1.0 : Pin<&ProxyLayout> ==> &ProxyLayout
322
+ bb. statements . insert (
323
+ 0 ,
324
+ Statement {
325
+ source_info,
326
+ kind : StatementKind :: Assign ( Box :: new ( (
327
+ Place :: from ( proxy_ref_local) ,
328
+ Rvalue :: CopyForDeref ( proxy_ref_place) ,
329
+ ) ) ) ,
330
+ } ,
331
+ ) ;
332
+ bb. statements . insert (
333
+ 1 ,
334
+ Statement {
335
+ source_info,
336
+ kind : StatementKind :: Assign ( Box :: new ( (
337
+ Place :: from ( cor_ref_local) ,
338
+ Rvalue :: CopyForDeref ( impl_ref_place ( proxy_ref_local) ) ,
339
+ ) ) ) ,
340
+ } ,
341
+ ) ;
342
+ }
343
+
344
+ pm:: run_passes (
345
+ tcx,
346
+ & mut body,
347
+ & [
348
+ & mentioned_items:: MentionedItems ,
349
+ & abort_unwinding_calls:: AbortUnwindingCalls ,
350
+ & add_call_guards:: CriticalCallEdges ,
351
+ ] ,
352
+ Some ( MirPhase :: Runtime ( RuntimePhase :: Optimized ) ) ,
353
+ ) ;
354
+ debug ! ( "make_shim({:?}) = {:?}" , instance, body) ;
355
+ return body;
132
356
}
133
- ty:: InstanceKind :: AsyncDropGlue ( _def_id, _ty) => {
134
- todo ! ( )
357
+ ty:: InstanceKind :: AsyncDropGlue ( def_id, ty) => {
358
+ let mut body = async_destructor_ctor:: build_async_drop_shim ( tcx, def_id, ty) ;
359
+
360
+ pm:: run_passes (
361
+ tcx,
362
+ & mut body,
363
+ & [
364
+ & mentioned_items:: MentionedItems ,
365
+ & simplify:: SimplifyCfg :: MakeShim ,
366
+ & crate :: reveal_all:: RevealAll ,
367
+ & crate :: coroutine:: StateTransform ,
368
+ ] ,
369
+ Some ( MirPhase :: Runtime ( RuntimePhase :: PostCleanup ) ) ,
370
+ ) ;
371
+ debug ! ( "make_shim({:?}) = {:?}" , instance, body) ;
372
+ return body;
135
373
}
136
- ty:: InstanceKind :: AsyncDropGlueCtorShim ( _def_id, _ty) => {
137
- bug ! ( "AsyncDropGlueCtorShim in re-working ({:?})" , instance)
374
+
375
+ ty:: InstanceKind :: AsyncDropGlueCtorShim ( def_id, ty) => {
376
+ let body = async_destructor_ctor:: build_async_destructor_ctor_shim ( tcx, def_id, ty) ;
377
+ debug ! ( "make_shim({:?}) = {:?}" , instance, body) ;
378
+ return body;
138
379
}
139
380
ty:: InstanceKind :: Virtual ( ..) => {
140
381
bug ! ( "InstanceKind::Virtual ({:?}) is for direct calls only" , instance)
0 commit comments