@@ -513,3 +513,249 @@ llvm::Value *IRGenFunction::alignUpToMaximumAlignment(llvm::Type *sizeTy, llvm::
513
513
auto *invertedMask = Builder.CreateNot (alignMask);
514
514
return Builder.CreateAnd (Builder.CreateAdd (val, alignMask), invertedMask);
515
515
}
516
+
517
+ // / Returns the current task \p currTask as an UnsafeContinuation at +1.
518
+ static llvm::Value *unsafeContinuationFromTask (IRGenFunction &IGF,
519
+ SILType unsafeContinuationTy,
520
+ llvm::Value *currTask) {
521
+ auto &IGM = IGF.IGM ;
522
+ auto &Builder = IGF.Builder ;
523
+
524
+ auto &rawPonterTI = IGM.getRawPointerTypeInfo ();
525
+ auto object =
526
+ Builder.CreateBitOrPointerCast (currTask, rawPonterTI.getStorageType ());
527
+
528
+ // Wrap the native object in the UnsafeContinuation struct.
529
+ // struct UnsafeContinuation<T> {
530
+ // let _continuation : Builtin.RawPointer
531
+ // }
532
+ auto &unsafeContinuationTI =
533
+ cast<LoadableTypeInfo>(IGF.getTypeInfo (unsafeContinuationTy));
534
+ auto unsafeContinuationStructTy =
535
+ cast<llvm::StructType>(unsafeContinuationTI.getStorageType ());
536
+ auto fieldTy =
537
+ cast<llvm::StructType>(unsafeContinuationStructTy->getElementType (0 ));
538
+ auto reference =
539
+ Builder.CreateBitOrPointerCast (object, fieldTy->getElementType (0 ));
540
+ auto field =
541
+ Builder.CreateInsertValue (llvm::UndefValue::get (fieldTy), reference, 0 );
542
+ auto unsafeContinuation = Builder.CreateInsertValue (
543
+ llvm::UndefValue::get (unsafeContinuationStructTy), field, 0 );
544
+
545
+ return unsafeContinuation;
546
+ }
547
+
548
+ void IRGenFunction::emitGetAsyncContinuation (SILType unsafeContinuationTy,
549
+ StackAddress resultAddr,
550
+ Explosion &out) {
551
+ // Create the continuation.
552
+ // void current_sil_function(AsyncTask *currTask, Executor *currExecutor,
553
+ // AsyncContext *currCtxt) {
554
+ //
555
+ // A continuation is the current AsyncTask 'currTask' with:
556
+ // currTask->ResumeTask = @llvm.coro.async.resume();
557
+ // currTask->ResumeContext = &continuation_context;
558
+ //
559
+ // Where:
560
+ //
561
+ // struct {
562
+ // AsyncContext *resumeCtxt;
563
+ // void *awaitSynchronization;
564
+ // SwiftError *errResult;
565
+ // union {
566
+ // IndirectResult *result;
567
+ // DirectResult *result;
568
+ // };
569
+ // } continuation_context; // local variable of current_sil_function
570
+ //
571
+ // continuation_context.resumeCtxt = currCtxt;
572
+ // continuation_context.errResult = nulllptr;
573
+ // continuation_context.result = ... // local alloca.
574
+
575
+ auto currTask = getAsyncTask ();
576
+ auto unsafeContinuation =
577
+ unsafeContinuationFromTask (*this , unsafeContinuationTy, currTask);
578
+
579
+ // Create and setup the continuation context for UnsafeContinuation<T>.
580
+ // continuation_context.resumeCtxt = currCtxt;
581
+ // continuation_context.errResult = nulllptr;
582
+ // continuation_context.result = ... // local alloca T
583
+ auto pointerAlignment = IGM.getPointerAlignment ();
584
+ auto continuationContext =
585
+ createAlloca (IGM.AsyncContinuationContextTy , pointerAlignment);
586
+ AsyncCoroutineCurrentContinuationContext = continuationContext.getAddress ();
587
+ // TODO: add lifetime with matching lifetime in await_async_continuation
588
+ auto contResumeAddr =
589
+ Builder.CreateStructGEP (continuationContext.getAddress (), 0 );
590
+ Builder.CreateStore (getAsyncContext (),
591
+ Address (contResumeAddr, pointerAlignment));
592
+ auto contErrResultAddr =
593
+ Builder.CreateStructGEP (continuationContext.getAddress (), 2 );
594
+ Builder.CreateStore (
595
+ llvm::Constant::getNullValue (
596
+ contErrResultAddr->getType ()->getPointerElementType ()),
597
+ Address (contErrResultAddr, pointerAlignment));
598
+ auto contResultAddr =
599
+ Builder.CreateStructGEP (continuationContext.getAddress (), 3 );
600
+ if (!resultAddr.getAddress ().isValid ()) {
601
+ assert (unsafeContinuationTy.getASTType ()
602
+ ->castTo <BoundGenericType>()
603
+ ->getGenericArgs ()
604
+ .size () == 1 &&
605
+ " expect UnsafeContinuation<T> to have one generic arg" );
606
+ auto resultTy = IGM.getLoweredType (unsafeContinuationTy.getASTType ()
607
+ ->castTo <BoundGenericType>()
608
+ ->getGenericArgs ()[0 ]
609
+ ->getCanonicalType ());
610
+ auto &resultTI = getTypeInfo (resultTy);
611
+ auto resultAddr =
612
+ resultTI.allocateStack (*this , resultTy, " async.continuation.result" );
613
+ Builder.CreateStore (Builder.CreateBitOrPointerCast (
614
+ resultAddr.getAddress ().getAddress (),
615
+ contResultAddr->getType ()->getPointerElementType ()),
616
+ Address (contResultAddr, pointerAlignment));
617
+ } else {
618
+ Builder.CreateStore (Builder.CreateBitOrPointerCast (
619
+ resultAddr.getAddress ().getAddress (),
620
+ contResultAddr->getType ()->getPointerElementType ()),
621
+ Address (contResultAddr, pointerAlignment));
622
+ }
623
+
624
+ // Fill the current task (i.e the continuation) with the continuation
625
+ // information.
626
+ // currTask->ResumeTask = @llvm.coro.async.resume();
627
+ assert (currTask->getType () == IGM.SwiftTaskPtrTy );
628
+ auto currTaskResumeTaskAddr = Builder.CreateStructGEP (currTask,3 );
629
+ auto coroResume =
630
+ Builder.CreateIntrinsicCall (llvm::Intrinsic::coro_async_resume, {});
631
+
632
+ assert (AsyncCoroutineCurrentResume == nullptr &&
633
+ " Don't support nested get_async_continuation" );
634
+ AsyncCoroutineCurrentResume = coroResume;
635
+ Builder.CreateStore (
636
+ Builder.CreateBitOrPointerCast (coroResume, IGM.FunctionPtrTy ),
637
+ Address (currTaskResumeTaskAddr, pointerAlignment));
638
+ // currTask->ResumeContext = &continuation_context;
639
+ auto currTaskResumeCtxtAddr = Builder.CreateStructGEP (currTask, 4 );
640
+ Builder.CreateStore (
641
+ Builder.CreateBitOrPointerCast (continuationContext.getAddress (),
642
+ IGM.SwiftContextPtrTy ),
643
+ Address (currTaskResumeCtxtAddr, pointerAlignment));
644
+
645
+ // Publish all the writes.
646
+ // continuation_context.awaitSynchronization =(atomic release) nullptr;
647
+ auto contAwaitSyncAddr =
648
+ Builder.CreateStructGEP (continuationContext.getAddress (), 1 );
649
+ auto null = llvm::ConstantInt::get (
650
+ contAwaitSyncAddr->getType ()->getPointerElementType (), 0 );
651
+ auto atomicStore =
652
+ Builder.CreateStore (null, Address (contAwaitSyncAddr, pointerAlignment));
653
+ atomicStore->setAtomic (llvm::AtomicOrdering::Release,
654
+ llvm::SyncScope::System);
655
+ out.add (unsafeContinuation);
656
+ }
657
+
658
+ void IRGenFunction::emitAwaitAsyncContinuation (
659
+ SILType unsafeContinuationTy, bool isIndirectResult,
660
+ Explosion &outDirectResult, llvm::BasicBlock *&normalBB,
661
+ llvm::PHINode *&optionalErrorResult, llvm::BasicBlock *&optionalErrorBB) {
662
+ assert (AsyncCoroutineCurrentContinuationContext && " no active continuation" );
663
+ auto pointerAlignment = IGM.getPointerAlignment ();
664
+
665
+ // First check whether the await reached this point first. Meaning we still
666
+ // have to wait for the continuation result. If the await reaches first we
667
+ // abort the control flow here (resuming the continuation will execute the
668
+ // remaining control flow).
669
+ auto contAwaitSyncAddr =
670
+ Builder.CreateStructGEP (AsyncCoroutineCurrentContinuationContext, 1 );
671
+ auto null = llvm::ConstantInt::get (
672
+ contAwaitSyncAddr->getType ()->getPointerElementType (), 0 );
673
+ auto one = llvm::ConstantInt::get (
674
+ contAwaitSyncAddr->getType ()->getPointerElementType (), 1 );
675
+ auto results = Builder.CreateAtomicCmpXchg (
676
+ contAwaitSyncAddr, null, one,
677
+ llvm::AtomicOrdering::AcquireRelease /* success ordering*/ ,
678
+ llvm::AtomicOrdering::Monotonic /* failure ordering */ ,
679
+ llvm::SyncScope::System);
680
+ auto firstAtAwait = Builder.CreateExtractValue (results, 1 );
681
+ auto contBB = createBasicBlock (" await.async.maybe.resume" );
682
+ auto abortBB = createBasicBlock (" await.async.abort" );
683
+ Builder.CreateCondBr (firstAtAwait, abortBB, contBB);
684
+ Builder.emitBlock (abortBB);
685
+ {
686
+ // We are first to the sync point. Abort. The continuation's result is not
687
+ // available yet.
688
+ emitCoroutineOrAsyncExit ();
689
+ }
690
+
691
+ auto contBB2 = createBasicBlock (" await.async.resume" );
692
+ Builder.emitBlock (contBB);
693
+ {
694
+ // Setup the suspend point.
695
+ SmallVector<llvm::Value *, 8 > arguments;
696
+ arguments.push_back (AsyncCoroutineCurrentResume);
697
+ auto resumeProjFn = getOrCreateResumePrjFn ();
698
+ arguments.push_back (
699
+ Builder.CreateBitOrPointerCast (resumeProjFn, IGM.Int8PtrTy ));
700
+ arguments.push_back (Builder.CreateBitOrPointerCast (
701
+ getOrCreateAwaitAsyncSupendFn (), IGM.Int8PtrTy ));
702
+ arguments.push_back (AsyncCoroutineCurrentResume);
703
+ arguments.push_back (
704
+ Builder.CreateBitOrPointerCast (getAsyncTask (), IGM.Int8PtrTy ));
705
+ arguments.push_back (
706
+ Builder.CreateBitOrPointerCast (getAsyncExecutor (), IGM.Int8PtrTy ));
707
+ arguments.push_back (Builder.CreateBitOrPointerCast (
708
+ AsyncCoroutineCurrentContinuationContext, IGM.Int8PtrTy ));
709
+ auto *id = Builder.CreateIntrinsicCall (llvm::Intrinsic::coro_suspend_async,
710
+
711
+ arguments);
712
+ auto results = Builder.CreateAtomicCmpXchg (
713
+ contAwaitSyncAddr, null, one,
714
+ llvm::AtomicOrdering::AcquireRelease /* success ordering*/ ,
715
+ llvm::AtomicOrdering::Monotonic /* failure ordering */ ,
716
+ llvm::SyncScope::System);
717
+ // Again, are we first at the wait (can only reach that state after
718
+ // continuation.resume/abort is called)? If so abort to wait for the end of
719
+ // the await point to be reached.
720
+ auto firstAtAwait = Builder.CreateExtractValue (results, 1 );
721
+ Builder.CreateCondBr (firstAtAwait, abortBB, contBB2);
722
+ }
723
+
724
+ Builder.emitBlock (contBB2);
725
+ auto contBB3 = createBasicBlock (" await.async.normal" );
726
+ if (optionalErrorBB) {
727
+ auto contErrResultAddr = Address (
728
+ Builder.CreateStructGEP (AsyncCoroutineCurrentContinuationContext, 2 ),
729
+ pointerAlignment);
730
+ auto errorRes = Builder.CreateLoad (contErrResultAddr);
731
+ auto nullError = llvm::Constant::getNullValue (errorRes->getType ());
732
+ auto hasError = Builder.CreateICmpNE (errorRes, nullError);
733
+ optionalErrorResult->addIncoming (errorRes, Builder.GetInsertBlock ());
734
+ Builder.CreateCondBr (hasError, optionalErrorBB, contBB3);
735
+ } else {
736
+ Builder.CreateBr (contBB3);
737
+ }
738
+
739
+ Builder.emitBlock (contBB3);
740
+ if (!isIndirectResult) {
741
+ auto contResultAddrAddr =
742
+ Builder.CreateStructGEP (AsyncCoroutineCurrentContinuationContext, 3 );
743
+ auto resultAddrVal =
744
+ Builder.CreateLoad (Address (contResultAddrAddr, pointerAlignment));
745
+ // Take the result.
746
+ auto resultTy = IGM.getLoweredType (unsafeContinuationTy.getASTType ()
747
+ ->castTo <BoundGenericType>()
748
+ ->getGenericArgs ()[0 ]
749
+ ->getCanonicalType ());
750
+ auto &resultTI = cast<LoadableTypeInfo>(getTypeInfo (resultTy));
751
+ auto resultStorageTy = resultTI.getStorageType ();
752
+ auto resultAddr =
753
+ Address (Builder.CreateBitOrPointerCast (resultAddrVal,
754
+ resultStorageTy->getPointerTo ()),
755
+ resultTI.getFixedAlignment ());
756
+ resultTI.loadAsTake (*this , resultAddr, outDirectResult);
757
+ }
758
+ Builder.CreateBr (normalBB);
759
+ AsyncCoroutineCurrentResume = nullptr ;
760
+ AsyncCoroutineCurrentContinuationContext = nullptr ;
761
+ }
0 commit comments