Skip to content

Commit de2dbe5

Browse files
authored
[AutoDiff] Bump-pointer allocate pullback structs in loops. (#34886)
In derivatives of loops, no longer allocate boxes for indirect case payloads. Instead, use a custom pullback context in the runtime which contains a bump-pointer allocator. When a function contains a differentiated loop, the closure context is a `Builtin.NativeObject`, which contains a `swift::AutoDiffLinearMapContext` and a tail-allocated top-level linear map struct (which represents the linear map struct that was previously directly partial-applied into the pullback). In branching trace enums, the payloads of previously indirect cases will be allocated by `swift::AutoDiffLinearMapContext::allocate` and stored as a `Builtin.RawPointer`.
1 parent 9dc2a71 commit de2dbe5

34 files changed

+644
-100
lines changed

include/swift/AST/ASTContext.h

+3
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,9 @@ class ASTContext final {
713713
/// Get the runtime availability of support for concurrency.
714714
AvailabilityContext getConcurrencyAvailability();
715715

716+
/// Get the runtime availability of support for differentiation.
717+
AvailabilityContext getDifferentiationAvailability();
718+
716719
/// Get the runtime availability of features introduced in the Swift 5.2
717720
/// compiler for the target platform.
718721
AvailabilityContext getSwift52Availability();

include/swift/AST/Builtins.def

+9
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,15 @@ BUILTIN_MISC_OPERATION_WITH_SILGEN(CreateAsyncTaskFuture,
752752
/// is a pure value and therefore we can consider it as readnone).
753753
BUILTIN_MISC_OPERATION_WITH_SILGEN(GlobalStringTablePointer, "globalStringTablePointer", "n", Special)
754754

755+
// autoDiffCreateLinearMapContext: (Builtin.Word) -> Builtin.NativeObject
756+
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "n", Special)
757+
758+
// autoDiffProjectTopLevelSubcontext: (Builtin.NativeObject) -> Builtin.RawPointer
759+
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffProjectTopLevelSubcontext, "autoDiffProjectTopLevelSubcontext", "n", Special)
760+
761+
// autoDiffAllocateSubcontext: (Builtin.NativeObject, Builtin.Word) -> Builtin.RawPointer
762+
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffAllocateSubcontext, "autoDiffAllocateSubcontext", "", Special)
763+
755764
#undef BUILTIN_MISC_OPERATION_WITH_SILGEN
756765

757766
#undef BUILTIN_MISC_OPERATION

include/swift/Runtime/RuntimeFunctions.def

+24
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,30 @@ FUNCTION(TaskCreateFutureFunc,
15181518
TaskContinuationFunctionPtrTy, SizeTy),
15191519
ATTRS(NoUnwind, ArgMemOnly))
15201520

1521+
// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(size_t);
1522+
FUNCTION(AutoDiffCreateLinearMapContext,
1523+
swift_autoDiffCreateLinearMapContext, SwiftCC,
1524+
DifferentiationAvailability,
1525+
RETURNS(RefCountedPtrTy),
1526+
ARGS(SizeTy),
1527+
ATTRS(NoUnwind, ArgMemOnly))
1528+
1529+
// void *swift_autoDiffProjectTopLevelSubcontext(AutoDiffLinearMapContext *);
1530+
FUNCTION(AutoDiffProjectTopLevelSubcontext,
1531+
swift_autoDiffProjectTopLevelSubcontext, SwiftCC,
1532+
DifferentiationAvailability,
1533+
RETURNS(Int8PtrTy),
1534+
ARGS(RefCountedPtrTy),
1535+
ATTRS(NoUnwind, ArgMemOnly))
1536+
1537+
// void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t);
1538+
FUNCTION(AutoDiffAllocateSubcontext,
1539+
swift_autoDiffAllocateSubcontext, SwiftCC,
1540+
DifferentiationAvailability,
1541+
RETURNS(Int8PtrTy),
1542+
ARGS(RefCountedPtrTy, SizeTy),
1543+
ATTRS(NoUnwind, ArgMemOnly))
1544+
15211545
#undef RETURNS
15221546
#undef ARGS
15231547
#undef ATTRS

include/swift/SILOptimizer/Differentiation/Common.h

+10
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ void extractAllElements(SILValue value, SILBuilder &builder,
192192
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
193193
SILValue bufferAccess, SILLocation loc);
194194

195+
/// Emit a `Builtin.Word` value that represents the given type's memory layout
196+
/// size.
197+
SILValue emitMemoryLayoutSize(
198+
SILBuilder &builder, SILLocation loc, CanType type);
199+
200+
/// Emit a projection of the top-level subcontext from the context object.
201+
SILValue emitProjectTopLevelSubcontext(
202+
SILBuilder &builder, SILLocation loc, SILValue context,
203+
SILType subcontextType);
204+
195205
//===----------------------------------------------------------------------===//
196206
// Utilities for looking up derivatives of functions
197207
//===----------------------------------------------------------------------===//

include/swift/SILOptimizer/Differentiation/LinearMapInfo.h

+18-3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class LinearMapInfo {
6363
/// Activity info of the original function.
6464
const DifferentiableActivityInfo &activityInfo;
6565

66+
/// The original function's loop info.
67+
SILLoopInfo *loopInfo;
68+
6669
/// Differentiation indices of the function.
6770
const SILAutoDiffIndices indices;
6871

@@ -86,6 +89,9 @@ class LinearMapInfo {
8689
/// Mapping from linear map structs to their branching trace enum fields.
8790
llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields;
8891

92+
/// Blocks in a loop.
93+
llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop;
94+
8995
/// A synthesized file unit.
9096
SynthesizedFileUnit &synthesizedFile;
9197

@@ -144,7 +150,8 @@ class LinearMapInfo {
144150
explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
145151
SILFunction *original, SILFunction *derivative,
146152
SILAutoDiffIndices indices,
147-
const DifferentiableActivityInfo &activityInfo);
153+
const DifferentiableActivityInfo &activityInfo,
154+
SILLoopInfo *loopInfo);
148155

149156
/// Returns the linear map struct associated with the given original block.
150157
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
@@ -200,20 +207,28 @@ class LinearMapInfo {
200207

201208
/// Returns the branching trace enum field for the linear map struct of the
202209
/// given original block.
203-
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) {
210+
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) const {
204211
auto *linearMapStruct = getLinearMapStruct(origBB);
205212
return linearMapStructEnumFields.lookup(linearMapStruct);
206213
}
207214

208215
/// Finds the linear map declaration in the pullback struct for the given
209216
/// `apply` instruction in the original function.
210-
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) {
217+
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) const {
211218
assert(ai->getFunction() == original);
212219
auto lookup = linearMapFieldMap.find(ai);
213220
assert(lookup != linearMapFieldMap.end() &&
214221
"No linear map field corresponding to the given `apply`");
215222
return lookup->getSecond();
216223
}
224+
225+
bool hasLoops() const {
226+
return !blocksInLoop.empty();
227+
}
228+
229+
ArrayRef<SILBasicBlock *> getBlocksInLoop() const {
230+
return blocksInLoop.getArrayRef();
231+
}
217232
};
218233

219234
} // end namespace autodiff

include/swift/SILOptimizer/Differentiation/VJPCloner.h

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2222
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2323
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
24+
#include "swift/SIL/LoopInfo.h"
2425

2526
namespace swift {
2627
namespace autodiff {
@@ -52,6 +53,7 @@ class VJPCloner final {
5253
const SILAutoDiffIndices getIndices() const;
5354
DifferentiationInvoker getInvoker() const;
5455
LinearMapInfo &getPullbackInfo() const;
56+
SILLoopInfo *getLoopInfo() const;
5557
const DifferentiableActivityInfo &getActivityInfo() const;
5658

5759
/// Performs VJP generation on the empty VJP function. Returns true if any

lib/AST/ASTMangler.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,8 @@ static StringRef getPrivateDiscriminatorIfNecessary(const ValueDecl *decl) {
785785

786786
// Mangle non-local private declarations with a textual discriminator
787787
// based on their enclosing file.
788-
auto topLevelContext = decl->getDeclContext()->getModuleScopeContext();
789-
auto fileUnit = cast<FileUnit>(topLevelContext);
788+
auto topLevelSubcontext = decl->getDeclContext()->getModuleScopeContext();
789+
auto fileUnit = cast<FileUnit>(topLevelSubcontext);
790790

791791
Identifier discriminator =
792792
fileUnit->getDiscriminatorForPrivateValue(decl);
@@ -2900,17 +2900,17 @@ void ASTMangler::appendEntity(const ValueDecl *decl) {
29002900
void
29012901
ASTMangler::appendProtocolConformance(const ProtocolConformance *conformance) {
29022902
GenericSignature contextSig;
2903-
auto topLevelContext =
2903+
auto topLevelSubcontext =
29042904
conformance->getDeclContext()->getModuleScopeContext();
2905-
Mod = topLevelContext->getParentModule();
2905+
Mod = topLevelSubcontext->getParentModule();
29062906

29072907
auto conformingType = conformance->getType();
29082908
appendType(conformingType->getCanonicalType());
29092909

29102910
appendProtocolName(conformance->getProtocol());
29112911

29122912
bool needsModule = true;
2913-
if (auto *file = dyn_cast<FileUnit>(topLevelContext)) {
2913+
if (auto *file = dyn_cast<FileUnit>(topLevelSubcontext)) {
29142914
if (file->getKind() == FileUnitKind::ClangModule ||
29152915
file->getKind() == FileUnitKind::DWARFModule) {
29162916
if (conformance->getProtocol()->hasClangNode())

lib/AST/ASTVerifier.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ class Verifier : public ASTWalker {
229229
typedef llvm::PointerIntPair<DeclContext *, 1, bool> ClosureDiscriminatorKey;
230230
llvm::DenseMap<ClosureDiscriminatorKey, SmallBitVector>
231231
ClosureDiscriminators;
232-
DeclContext *CanonicalTopLevelContext = nullptr;
232+
DeclContext *CanonicalTopLevelSubcontext = nullptr;
233233

234234
Verifier(PointerUnion<ModuleDecl *, SourceFile *> M, DeclContext *DC)
235235
: M(M),
@@ -898,9 +898,9 @@ class Verifier : public ASTWalker {
898898
DeclContext *getCanonicalDeclContext(DeclContext *DC) {
899899
// All we really need to do is use a single TopLevelCodeDecl.
900900
if (auto topLevel = dyn_cast<TopLevelCodeDecl>(DC)) {
901-
if (!CanonicalTopLevelContext)
902-
CanonicalTopLevelContext = topLevel;
903-
return CanonicalTopLevelContext;
901+
if (!CanonicalTopLevelSubcontext)
902+
CanonicalTopLevelSubcontext = topLevel;
903+
return CanonicalTopLevelSubcontext;
904904
}
905905

906906
// TODO: check for uniqueness of initializer contexts?

lib/AST/Availability.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ AvailabilityContext ASTContext::getConcurrencyAvailability() {
327327
return getSwiftFutureAvailability();
328328
}
329329

330+
AvailabilityContext ASTContext::getDifferentiationAvailability() {
331+
return getSwiftFutureAvailability();
332+
}
333+
330334
AvailabilityContext ASTContext::getSwift52Availability() {
331335
auto target = LangOpts.Target;
332336

lib/AST/Builtins.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,25 @@ static ValueDecl *getCreateAsyncTaskFuture(ASTContext &ctx, Identifier id) {
13831383
return builder.build(id);
13841384
}
13851385

1386+
static ValueDecl *getAutoDiffCreateLinearMapContext(ASTContext &ctx,
1387+
Identifier id) {
1388+
return getBuiltinFunction(
1389+
id, {BuiltinIntegerType::getWordType(ctx)}, ctx.TheNativeObjectType);
1390+
}
1391+
1392+
static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx,
1393+
Identifier id) {
1394+
return getBuiltinFunction(
1395+
id, {ctx.TheNativeObjectType}, ctx.TheRawPointerType);
1396+
}
1397+
1398+
static ValueDecl *getAutoDiffAllocateSubcontext(ASTContext &ctx,
1399+
Identifier id) {
1400+
return getBuiltinFunction(
1401+
id, {ctx.TheNativeObjectType, BuiltinIntegerType::getWordType(ctx)},
1402+
ctx.TheRawPointerType);
1403+
}
1404+
13861405
static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) {
13871406
auto int1Type = BuiltinIntegerType::get(1, Context);
13881407
auto optionalRawPointerType = BoundGenericEnumType::get(
@@ -2549,6 +2568,15 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
25492568

25502569
case BuiltinValueKind::TriggerFallbackDiagnostic:
25512570
return getTriggerFallbackDiagnosticOperation(Context, Id);
2571+
2572+
case BuiltinValueKind::AutoDiffCreateLinearMapContext:
2573+
return getAutoDiffCreateLinearMapContext(Context, Id);
2574+
2575+
case BuiltinValueKind::AutoDiffProjectTopLevelSubcontext:
2576+
return getAutoDiffProjectTopLevelSubcontext(Context, Id);
2577+
2578+
case BuiltinValueKind::AutoDiffAllocateSubcontext:
2579+
return getAutoDiffAllocateSubcontext(Context, Id);
25522580
}
25532581

25542582
llvm_unreachable("bad builtin value!");

lib/IDE/CodeCompletion.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1677,7 +1677,7 @@ class CodeCompletionCallbacksImpl : public CodeCompletionCallbacks {
16771677
} // end anonymous namespace
16781678

16791679
namespace {
1680-
static bool isTopLevelContext(const DeclContext *DC) {
1680+
static bool isTopLevelSubcontext(const DeclContext *DC) {
16811681
for (; DC && DC->isLocalContext(); DC = DC->getParent()) {
16821682
switch (DC->getContextKind()) {
16831683
case DeclContextKind::TopLevelCodeDecl:
@@ -2139,7 +2139,7 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
21392139
if (CurrDeclContext && D->getModuleContext() == CurrModule) {
21402140
// Treat global variables from the same source file as local when
21412141
// completing at top-level.
2142-
if (isa<VarDecl>(D) && isTopLevelContext(CurrDeclContext) &&
2142+
if (isa<VarDecl>(D) && isTopLevelSubcontext(CurrDeclContext) &&
21432143
D->getDeclContext()->getParentSourceFile() ==
21442144
CurrDeclContext->getParentSourceFile()) {
21452145
return SemanticContextKind::Local;

lib/IRGen/GenBuiltin.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -1115,5 +1115,27 @@ if (Builtin.ID == BuiltinValueKind::id) { \
11151115
return;
11161116
}
11171117

1118+
if (Builtin.ID == BuiltinValueKind::AutoDiffCreateLinearMapContext) {
1119+
auto topLevelSubcontextSize = args.claimNext();
1120+
out.add(emitAutoDiffCreateLinearMapContext(IGF, topLevelSubcontextSize)
1121+
.getAddress());
1122+
return;
1123+
}
1124+
1125+
if (Builtin.ID == BuiltinValueKind::AutoDiffProjectTopLevelSubcontext) {
1126+
Address allocatorAddr(args.claimNext(), IGF.IGM.getPointerAlignment());
1127+
out.add(
1128+
emitAutoDiffProjectTopLevelSubcontext(IGF, allocatorAddr).getAddress());
1129+
return;
1130+
}
1131+
1132+
if (Builtin.ID == BuiltinValueKind::AutoDiffAllocateSubcontext) {
1133+
Address allocatorAddr(args.claimNext(), IGF.IGM.getPointerAlignment());
1134+
auto size = args.claimNext();
1135+
out.add(
1136+
emitAutoDiffAllocateSubcontext(IGF, allocatorAddr, size).getAddress());
1137+
return;
1138+
}
1139+
11181140
llvm_unreachable("IRGen unimplemented for this builtin!");
11191141
}

lib/IRGen/GenCall.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -4579,3 +4579,32 @@ IRGenFunction::getFunctionPointerForResumeIntrinsic(llvm::Value *resume) {
45794579
PointerAuthInfo(), signature);
45804580
return fnPtr;
45814581
}
4582+
4583+
Address irgen::emitAutoDiffCreateLinearMapContext(
4584+
IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize) {
4585+
auto *call = IGF.Builder.CreateCall(
4586+
IGF.IGM.getAutoDiffCreateLinearMapContextFn(), {topLevelSubcontextSize});
4587+
call->setDoesNotThrow();
4588+
call->setCallingConv(IGF.IGM.SwiftCC);
4589+
return Address(call, IGF.IGM.getPointerAlignment());
4590+
}
4591+
4592+
Address irgen::emitAutoDiffProjectTopLevelSubcontext(
4593+
IRGenFunction &IGF, Address context) {
4594+
auto *call = IGF.Builder.CreateCall(
4595+
IGF.IGM.getAutoDiffProjectTopLevelSubcontextFn(),
4596+
{context.getAddress()});
4597+
call->setDoesNotThrow();
4598+
call->setCallingConv(IGF.IGM.SwiftCC);
4599+
return Address(call, IGF.IGM.getPointerAlignment());
4600+
}
4601+
4602+
Address irgen::emitAutoDiffAllocateSubcontext(
4603+
IRGenFunction &IGF, Address context, llvm::Value *size) {
4604+
auto *call = IGF.Builder.CreateCall(
4605+
IGF.IGM.getAutoDiffAllocateSubcontextFn(),
4606+
{context.getAddress(), size});
4607+
call->setDoesNotThrow();
4608+
call->setCallingConv(IGF.IGM.SwiftCC);
4609+
return Address(call, IGF.IGM.getPointerAlignment());
4610+
}

lib/IRGen/GenCall.h

+7
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,13 @@ namespace irgen {
432432

433433
void emitAsyncReturn(IRGenFunction &IGF, AsyncContextLayout &layout,
434434
CanSILFunctionType fnType);
435+
436+
Address emitAutoDiffCreateLinearMapContext(
437+
IRGenFunction &IGF, llvm::Value *topLevelSubcontextSize);
438+
Address emitAutoDiffProjectTopLevelSubcontext(
439+
IRGenFunction &IGF, Address context);
440+
Address emitAutoDiffAllocateSubcontext(
441+
IRGenFunction &IGF, Address context, llvm::Value *size);
435442
} // end namespace irgen
436443
} // end namespace swift
437444

lib/IRGen/IRGenModule.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,14 @@ namespace RuntimeConstants {
735735
}
736736
return RuntimeAvailability::AlwaysAvailable;
737737
}
738+
739+
RuntimeAvailability DifferentiationAvailability(ASTContext &context) {
740+
auto featureAvailability = context.getDifferentiationAvailability();
741+
if (!isDeploymentAvailabilityContainedIn(context, featureAvailability)) {
742+
return RuntimeAvailability::ConditionallyAvailable;
743+
}
744+
return RuntimeAvailability::AlwaysAvailable;
745+
}
738746
} // namespace RuntimeConstants
739747

740748
// We don't use enough attributes to justify generalizing the

lib/SIL/IR/OperandOwnership.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,9 @@ CONSTANT_OWNERSHIP_BUILTIN(Owned, LifetimeEnding, UnsafeGuaranteed)
880880
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CancelAsyncTask)
881881
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CreateAsyncTask)
882882
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, CreateAsyncTaskFuture)
883+
CONSTANT_OWNERSHIP_BUILTIN(None, NonLifetimeEnding, AutoDiffCreateLinearMapContext)
884+
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, AutoDiffAllocateSubcontext)
885+
CONSTANT_OWNERSHIP_BUILTIN(Guaranteed, NonLifetimeEnding, AutoDiffProjectTopLevelSubcontext)
883886

884887
#undef CONSTANT_OWNERSHIP_BUILTIN
885888

lib/SIL/IR/ValueOwnership.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,9 @@ CONSTANT_OWNERSHIP_BUILTIN(None, GetCurrentAsyncTask)
545545
CONSTANT_OWNERSHIP_BUILTIN(None, CancelAsyncTask)
546546
CONSTANT_OWNERSHIP_BUILTIN(Owned, CreateAsyncTask)
547547
CONSTANT_OWNERSHIP_BUILTIN(Owned, CreateAsyncTaskFuture)
548+
CONSTANT_OWNERSHIP_BUILTIN(Owned, AutoDiffCreateLinearMapContext)
549+
CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffProjectTopLevelSubcontext)
550+
CONSTANT_OWNERSHIP_BUILTIN(None, AutoDiffAllocateSubcontext)
548551

549552
#undef CONSTANT_OWNERSHIP_BUILTIN
550553

lib/SIL/Utils/MemAccessUtils.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1805,6 +1805,8 @@ static void visitBuiltinAddress(BuiltinInst *builtin,
18051805
case BuiltinValueKind::CancelAsyncTask:
18061806
case BuiltinValueKind::CreateAsyncTask:
18071807
case BuiltinValueKind::CreateAsyncTaskFuture:
1808+
case BuiltinValueKind::AutoDiffCreateLinearMapContext:
1809+
case BuiltinValueKind::AutoDiffAllocateSubcontext:
18081810
return;
18091811

18101812
// General memory access to a pointer in first operand position.

0 commit comments

Comments
 (0)