Skip to content

Commit d75e4c2

Browse files
committed
IRGen: get/await_async_continuation support.
rdar://71124933
1 parent 6722d71 commit d75e4c2

File tree

7 files changed

+441
-32
lines changed

7 files changed

+441
-32
lines changed

lib/IRGen/GenFunc.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,3 +2433,31 @@ IRGenFunction::createAsyncDispatchFn(const FunctionPointer &fnPtr,
24332433
Builder.CreateRetVoid();
24342434
return dispatch;
24352435
}
2436+
2437+
llvm::Function *IRGenFunction::getOrCreateAwaitAsyncSupendFn() {
2438+
auto name = "__swift_async_await_async_suspend";
2439+
return cast<llvm::Function>(IGM.getOrCreateHelperFunction(
2440+
name, IGM.VoidTy, {IGM.Int8PtrTy, IGM.Int8PtrTy, IGM.Int8PtrTy, IGM.Int8PtrTy},
2441+
[&](IRGenFunction &IGF) {
2442+
auto it = IGF.CurFn->arg_begin();
2443+
auto &Builder = IGF.Builder;
2444+
auto *fnTy = llvm::FunctionType::get(
2445+
IGM.VoidTy, {IGM.Int8PtrTy, IGM.Int8PtrTy, IGM.Int8PtrTy},
2446+
false /*vaargs*/);
2447+
llvm::Value *fn = &*(it++);
2448+
SmallVector<llvm::Value *, 8> callArgs;
2449+
for (auto end = IGF.CurFn->arg_end(); it != end; ++it)
2450+
callArgs.push_back(&*it);
2451+
auto signature =
2452+
Signature(fnTy, IGM.constructInitialAttributes(), IGM.SwiftCC);
2453+
auto fnPtr = FunctionPointer(
2454+
FunctionPointer::KindTy::Function,
2455+
Builder.CreateBitOrPointerCast(fn, fnTy->getPointerTo()),
2456+
PointerAuthInfo(), signature);
2457+
auto call = Builder.CreateCall(fnPtr, callArgs);
2458+
call->setTailCall();
2459+
call->setCallingConv(IGM.SwiftCC);
2460+
Builder.CreateRetVoid();
2461+
},
2462+
false /*isNoInline*/));
2463+
}

lib/IRGen/IRGenFunction.cpp

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,249 @@ llvm::Value *IRGenFunction::alignUpToMaximumAlignment(llvm::Type *sizeTy, llvm::
513513
auto *invertedMask = Builder.CreateNot(alignMask);
514514
return Builder.CreateAnd(Builder.CreateAdd(val, alignMask), invertedMask);
515515
}
516+
517+
/// Returns the current task \p currTask as an UnsafeContinuation at +1.
518+
static llvm::Value *unsafeContinuationFromTask(IRGenFunction &IGF,
519+
SILType unsafeContinuationTy,
520+
llvm::Value *currTask) {
521+
auto &IGM = IGF.IGM;
522+
auto &Builder = IGF.Builder;
523+
524+
auto &rawPonterTI = IGM.getRawPointerTypeInfo();
525+
auto object =
526+
Builder.CreateBitOrPointerCast(currTask, rawPonterTI.getStorageType());
527+
528+
// Wrap the native object in the UnsafeContinuation struct.
529+
// struct UnsafeContinuation<T> {
530+
// let _continuation : Builtin.RawPointer
531+
// }
532+
auto &unsafeContinuationTI =
533+
cast<LoadableTypeInfo>(IGF.getTypeInfo(unsafeContinuationTy));
534+
auto unsafeContinuationStructTy =
535+
cast<llvm::StructType>(unsafeContinuationTI.getStorageType());
536+
auto fieldTy =
537+
cast<llvm::StructType>(unsafeContinuationStructTy->getElementType(0));
538+
auto reference =
539+
Builder.CreateBitOrPointerCast(object, fieldTy->getElementType(0));
540+
auto field =
541+
Builder.CreateInsertValue(llvm::UndefValue::get(fieldTy), reference, 0);
542+
auto unsafeContinuation = Builder.CreateInsertValue(
543+
llvm::UndefValue::get(unsafeContinuationStructTy), field, 0);
544+
545+
return unsafeContinuation;
546+
}
547+
548+
void IRGenFunction::emitGetAsyncContinuation(SILType unsafeContinuationTy,
549+
StackAddress resultAddr,
550+
Explosion &out) {
551+
// Create the continuation.
552+
// void current_sil_function(AsyncTask *currTask, Executor *currExecutor,
553+
// AsyncContext *currCtxt) {
554+
//
555+
// A continuation is the current AsyncTask 'currTask' with:
556+
// currTask->ResumeTask = @llvm.coro.async.resume();
557+
// currTask->ResumeContext = &continuation_context;
558+
//
559+
// Where:
560+
//
561+
// struct {
562+
// AsyncContext *resumeCtxt;
563+
// void *awaitSynchronization;
564+
// SwiftError *errResult;
565+
// union {
566+
// IndirectResult *result;
567+
// DirectResult *result;
568+
// };
569+
// } continuation_context; // local variable of current_sil_function
570+
//
571+
// continuation_context.resumeCtxt = currCtxt;
572+
// continuation_context.errResult = nulllptr;
573+
// continuation_context.result = ... // local alloca.
574+
575+
auto currTask = getAsyncTask();
576+
auto unsafeContinuation =
577+
unsafeContinuationFromTask(*this, unsafeContinuationTy, currTask);
578+
579+
// Create and setup the continuation context for UnsafeContinuation<T>.
580+
// continuation_context.resumeCtxt = currCtxt;
581+
// continuation_context.errResult = nulllptr;
582+
// continuation_context.result = ... // local alloca T
583+
auto pointerAlignment = IGM.getPointerAlignment();
584+
auto continuationContext =
585+
createAlloca(IGM.AsyncContinuationContextTy, pointerAlignment);
586+
AsyncCoroutineCurrentContinuationContext = continuationContext.getAddress();
587+
// TODO: add lifetime with matching lifetime in await_async_continuation
588+
auto contResumeAddr =
589+
Builder.CreateStructGEP(continuationContext.getAddress(), 0);
590+
Builder.CreateStore(getAsyncContext(),
591+
Address(contResumeAddr, pointerAlignment));
592+
auto contErrResultAddr =
593+
Builder.CreateStructGEP(continuationContext.getAddress(), 2);
594+
Builder.CreateStore(
595+
llvm::Constant::getNullValue(
596+
contErrResultAddr->getType()->getPointerElementType()),
597+
Address(contErrResultAddr, pointerAlignment));
598+
auto contResultAddr =
599+
Builder.CreateStructGEP(continuationContext.getAddress(), 3);
600+
if (!resultAddr.getAddress().isValid()) {
601+
assert(unsafeContinuationTy.getASTType()
602+
->castTo<BoundGenericType>()
603+
->getGenericArgs()
604+
.size() == 1 &&
605+
"expect UnsafeContinuation<T> to have one generic arg");
606+
auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType()
607+
->castTo<BoundGenericType>()
608+
->getGenericArgs()[0]
609+
->getCanonicalType());
610+
auto &resultTI = getTypeInfo(resultTy);
611+
auto resultAddr =
612+
resultTI.allocateStack(*this, resultTy, "async.continuation.result");
613+
Builder.CreateStore(Builder.CreateBitOrPointerCast(
614+
resultAddr.getAddress().getAddress(),
615+
contResultAddr->getType()->getPointerElementType()),
616+
Address(contResultAddr, pointerAlignment));
617+
} else {
618+
Builder.CreateStore(Builder.CreateBitOrPointerCast(
619+
resultAddr.getAddress().getAddress(),
620+
contResultAddr->getType()->getPointerElementType()),
621+
Address(contResultAddr, pointerAlignment));
622+
}
623+
624+
// Fill the current task (i.e the continuation) with the continuation
625+
// information.
626+
// currTask->ResumeTask = @llvm.coro.async.resume();
627+
assert(currTask->getType() == IGM.SwiftTaskPtrTy);
628+
auto currTaskResumeTaskAddr = Builder.CreateStructGEP(currTask,3);
629+
auto coroResume =
630+
Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_async_resume, {});
631+
632+
assert(AsyncCoroutineCurrentResume == nullptr &&
633+
"Don't support nested get_async_continuation");
634+
AsyncCoroutineCurrentResume = coroResume;
635+
Builder.CreateStore(
636+
Builder.CreateBitOrPointerCast(coroResume, IGM.FunctionPtrTy),
637+
Address(currTaskResumeTaskAddr, pointerAlignment));
638+
// currTask->ResumeContext = &continuation_context;
639+
auto currTaskResumeCtxtAddr = Builder.CreateStructGEP(currTask, 4);
640+
Builder.CreateStore(
641+
Builder.CreateBitOrPointerCast(continuationContext.getAddress(),
642+
IGM.SwiftContextPtrTy),
643+
Address(currTaskResumeCtxtAddr, pointerAlignment));
644+
645+
// Publish all the writes.
646+
// continuation_context.awaitSynchronization =(atomic release) nullptr;
647+
auto contAwaitSyncAddr =
648+
Builder.CreateStructGEP(continuationContext.getAddress(), 1);
649+
auto null = llvm::ConstantInt::get(
650+
contAwaitSyncAddr->getType()->getPointerElementType(), 0);
651+
auto atomicStore =
652+
Builder.CreateStore(null, Address(contAwaitSyncAddr, pointerAlignment));
653+
atomicStore->setAtomic(llvm::AtomicOrdering::Release,
654+
llvm::SyncScope::System);
655+
out.add(unsafeContinuation);
656+
}
657+
658+
void IRGenFunction::emitAwaitAsyncContinuation(
659+
SILType unsafeContinuationTy, bool isIndirectResult,
660+
Explosion &outDirectResult, llvm::BasicBlock *&normalBB,
661+
llvm::PHINode *&optionalErrorResult, llvm::BasicBlock *&optionalErrorBB) {
662+
assert(AsyncCoroutineCurrentContinuationContext && "no active continuation");
663+
auto pointerAlignment = IGM.getPointerAlignment();
664+
665+
// First check whether the await reached this point first. Meaning we still
666+
// have to wait for the continuation result. If the await reaches first we
667+
// abort the control flow here (resuming the continuation will execute the
668+
// remaining control flow).
669+
auto contAwaitSyncAddr =
670+
Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 1);
671+
auto null = llvm::ConstantInt::get(
672+
contAwaitSyncAddr->getType()->getPointerElementType(), 0);
673+
auto one = llvm::ConstantInt::get(
674+
contAwaitSyncAddr->getType()->getPointerElementType(), 1);
675+
auto results = Builder.CreateAtomicCmpXchg(
676+
contAwaitSyncAddr, null, one,
677+
llvm::AtomicOrdering::AcquireRelease /*success ordering*/,
678+
llvm::AtomicOrdering::Monotonic /* failure ordering */,
679+
llvm::SyncScope::System);
680+
auto firstAtAwait = Builder.CreateExtractValue(results, 1);
681+
auto contBB = createBasicBlock("await.async.maybe.resume");
682+
auto abortBB = createBasicBlock("await.async.abort");
683+
Builder.CreateCondBr(firstAtAwait, abortBB, contBB);
684+
Builder.emitBlock(abortBB);
685+
{
686+
// We are first to the sync point. Abort. The continuation's result is not
687+
// available yet.
688+
emitCoroutineOrAsyncExit();
689+
}
690+
691+
auto contBB2 = createBasicBlock("await.async.resume");
692+
Builder.emitBlock(contBB);
693+
{
694+
// Setup the suspend point.
695+
SmallVector<llvm::Value *, 8> arguments;
696+
arguments.push_back(AsyncCoroutineCurrentResume);
697+
auto resumeProjFn = getOrCreateResumePrjFn();
698+
arguments.push_back(
699+
Builder.CreateBitOrPointerCast(resumeProjFn, IGM.Int8PtrTy));
700+
arguments.push_back(Builder.CreateBitOrPointerCast(
701+
getOrCreateAwaitAsyncSupendFn(), IGM.Int8PtrTy));
702+
arguments.push_back(AsyncCoroutineCurrentResume);
703+
arguments.push_back(
704+
Builder.CreateBitOrPointerCast(getAsyncTask(), IGM.Int8PtrTy));
705+
arguments.push_back(
706+
Builder.CreateBitOrPointerCast(getAsyncExecutor(), IGM.Int8PtrTy));
707+
arguments.push_back(Builder.CreateBitOrPointerCast(
708+
AsyncCoroutineCurrentContinuationContext, IGM.Int8PtrTy));
709+
auto *id = Builder.CreateIntrinsicCall(llvm::Intrinsic::coro_suspend_async,
710+
711+
arguments);
712+
auto results = Builder.CreateAtomicCmpXchg(
713+
contAwaitSyncAddr, null, one,
714+
llvm::AtomicOrdering::AcquireRelease /*success ordering*/,
715+
llvm::AtomicOrdering::Monotonic /* failure ordering */,
716+
llvm::SyncScope::System);
717+
// Again, are we first at the wait (can only reach that state after
718+
// continuation.resume/abort is called)? If so abort to wait for the end of
719+
// the await point to be reached.
720+
auto firstAtAwait = Builder.CreateExtractValue(results, 1);
721+
Builder.CreateCondBr(firstAtAwait, abortBB, contBB2);
722+
}
723+
724+
Builder.emitBlock(contBB2);
725+
auto contBB3 = createBasicBlock("await.async.normal");
726+
if (optionalErrorBB) {
727+
auto contErrResultAddr = Address(
728+
Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 2),
729+
pointerAlignment);
730+
auto errorRes = Builder.CreateLoad(contErrResultAddr);
731+
auto nullError = llvm::Constant::getNullValue(errorRes->getType());
732+
auto hasError = Builder.CreateICmpNE(errorRes, nullError);
733+
optionalErrorResult->addIncoming(errorRes, Builder.GetInsertBlock());
734+
Builder.CreateCondBr(hasError, optionalErrorBB, contBB3);
735+
} else {
736+
Builder.CreateBr(contBB3);
737+
}
738+
739+
Builder.emitBlock(contBB3);
740+
if (!isIndirectResult) {
741+
auto contResultAddrAddr =
742+
Builder.CreateStructGEP(AsyncCoroutineCurrentContinuationContext, 3);
743+
auto resultAddrVal =
744+
Builder.CreateLoad(Address(contResultAddrAddr, pointerAlignment));
745+
// Take the result.
746+
auto resultTy = IGM.getLoweredType(unsafeContinuationTy.getASTType()
747+
->castTo<BoundGenericType>()
748+
->getGenericArgs()[0]
749+
->getCanonicalType());
750+
auto &resultTI = cast<LoadableTypeInfo>(getTypeInfo(resultTy));
751+
auto resultStorageTy = resultTI.getStorageType();
752+
auto resultAddr =
753+
Address(Builder.CreateBitOrPointerCast(resultAddrVal,
754+
resultStorageTy->getPointerTo()),
755+
resultTI.getFixedAlignment());
756+
resultTI.loadAsTake(*this, resultAddr, outDirectResult);
757+
}
758+
Builder.CreateBr(normalBB);
759+
AsyncCoroutineCurrentResume = nullptr;
760+
AsyncCoroutineCurrentContinuationContext = nullptr;
761+
}

lib/IRGen/IRGenFunction.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,17 @@ class IRGenFunction {
135135
llvm::Function *getOrCreateResumePrjFn();
136136
llvm::Function *createAsyncDispatchFn(const FunctionPointer &fnPtr,
137137
ArrayRef<llvm::Value *> args);
138+
llvm::Function *getOrCreateAwaitAsyncSupendFn();
139+
140+
void emitGetAsyncContinuation(SILType silTy, StackAddress optionalResultAddr,
141+
Explosion &out);
142+
143+
void emitAwaitAsyncContinuation(SILType unsafeContinuationTy,
144+
bool isIndirectResult,
145+
Explosion &outDirectResult,
146+
llvm::BasicBlock *&normalBB,
147+
llvm::PHINode *&optionalErrorPhi,
148+
llvm::BasicBlock *&optionalErrorBB);
138149

139150
private:
140151
void emitPrologue();
@@ -145,8 +156,16 @@ class IRGenFunction {
145156
llvm::Value *CalleeErrorResultSlot = nullptr;
146157
llvm::Value *CallerErrorResultSlot = nullptr;
147158
llvm::Value *CoroutineHandle = nullptr;
159+
llvm::Value *AsyncCoroutineCurrentResume = nullptr;
160+
llvm::Value *AsyncCoroutineCurrentContinuationContext = nullptr;
148161
bool IsAsync = false;
149162

163+
/// The unique block that calls @llvm.coro.end.
164+
llvm::BasicBlock *CoroutineExitBlock = nullptr;
165+
166+
public:
167+
void emitCoroutineOrAsyncExit();
168+
150169
//--- Helper methods -----------------------------------------------------------
151170
public:
152171

lib/IRGen/IRGenModule.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,14 @@ IRGenModule::IRGenModule(IRGenerator &irgen,
593593
AsyncFunctionPointerTy = createStructType(*this, "swift.async_func_pointer",
594594
{RelativeAddressTy, Int32Ty}, true);
595595
SwiftContextTy = createStructType(*this, "swift.context", {});
596-
SwiftTaskTy = createStructType(*this, "swift.task", {});
596+
auto *ContextPtrTy = llvm::PointerType::getUnqual(SwiftContextTy);
597+
SwiftTaskTy = createStructType(*this, "swift.task", {
598+
Int8PtrTy, Int8PtrTy, // Job.SchedulerPrivate
599+
Int64Ty, // Job.Flags
600+
FunctionPtrTy, // Job.RunJob/Job.ResumeTask
601+
ContextPtrTy, // Task.ResumeContext
602+
Int64Ty // Task.Status
603+
});
597604
SwiftExecutorTy = createStructType(*this, "swift.executor", {});
598605
AsyncFunctionPointerPtrTy = AsyncFunctionPointerTy->getPointerTo(DefaultAS);
599606
SwiftContextPtrTy = SwiftContextTy->getPointerTo(DefaultAS);
@@ -612,6 +619,11 @@ IRGenModule::IRGenModule(IRGenerator &irgen,
612619
*this, "swift.async_task_and_context",
613620
{ SwiftTaskPtrTy, SwiftContextPtrTy });
614621

622+
AsyncContinuationContextTy =
623+
createStructType(*this, "swift.async_continuation_context",
624+
{SwiftContextPtrTy, SizeTy, ErrorPtrTy, OpaquePtrTy});
625+
AsyncContinuationContextPtrTy = AsyncContinuationContextTy->getPointerTo();
626+
615627
DifferentiabilityWitnessTy = createStructType(
616628
*this, "swift.differentiability_witness", {Int8PtrTy, Int8PtrTy});
617629
}

lib/IRGen/IRGenModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,8 @@ class IRGenModule {
734734
llvm::FunctionType *TaskContinuationFunctionTy;
735735
llvm::PointerType *TaskContinuationFunctionPtrTy;
736736
llvm::StructType *AsyncTaskAndContextTy;
737+
llvm::StructType *AsyncContinuationContextTy;
738+
llvm::PointerType *AsyncContinuationContextPtrTy;
737739
llvm::StructType *DifferentiabilityWitnessTy; // { i8*, i8* }
738740

739741
llvm::GlobalVariable *TheTrivialPropertyDescriptor = nullptr;

0 commit comments

Comments
 (0)