Skip to content

Commit de2dbe5

Browse files
authored
[AutoDiff] Bump-pointer allocate pullback structs in loops. (#34886)
In derivatives of loops, no longer allocate boxes for indirect case payloads. Instead, use a custom pullback context in the runtime which contains a bump-pointer allocator. When a function contains a differentiated loop, the closure context is a `Builtin.NativeObject`, which contains a `swift::AutoDiffLinearMapContext` and a tail-allocated top-level linear map struct (which represents the linear map struct that was previously directly partial-applied into the pullback). In branching trace enums, the payloads of previously indirect cases will be allocated by `swift::AutoDiffLinearMapContext::allocate` and stored as a `Builtin.RawPointer`.
1 parent 9dc2a71 commit de2dbe5

34 files changed

+644
-100
lines changed

include/swift/AST/ASTContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,9 @@ class ASTContext final {
713713
/// Get the runtime availability of support for concurrency.
714714
AvailabilityContext getConcurrencyAvailability();
715715

716+
/// Get the runtime availability of support for differentiation.
717+
AvailabilityContext getDifferentiationAvailability();
718+
716719
/// Get the runtime availability of features introduced in the Swift 5.2
717720
/// compiler for the target platform.
718721
AvailabilityContext getSwift52Availability();

include/swift/AST/Builtins.def

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,15 @@ BUILTIN_MISC_OPERATION_WITH_SILGEN(CreateAsyncTaskFuture,
752752
/// is a pure value and therefore we can consider it as readnone).
753753
BUILTIN_MISC_OPERATION_WITH_SILGEN(GlobalStringTablePointer, "globalStringTablePointer", "n", Special)
754754

755+
// autoDiffCreateLinearMapContext: (Builtin.Word) -> Builtin.NativeObject
756+
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffCreateLinearMapContext, "autoDiffCreateLinearMapContext", "n", Special)
757+
758+
// autoDiffProjectTopLevelSubcontext: (Builtin.NativeObject) -> Builtin.RawPointer
759+
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffProjectTopLevelSubcontext, "autoDiffProjectTopLevelSubcontext", "n", Special)
760+
761+
// autoDiffAllocateSubcontext: (Builtin.NativeObject, Builtin.Word) -> Builtin.RawPointer
762+
BUILTIN_MISC_OPERATION_WITH_SILGEN(AutoDiffAllocateSubcontext, "autoDiffAllocateSubcontext", "", Special)
763+
755764
#undef BUILTIN_MISC_OPERATION_WITH_SILGEN
756765

757766
#undef BUILTIN_MISC_OPERATION

include/swift/Runtime/RuntimeFunctions.def

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,30 @@ FUNCTION(TaskCreateFutureFunc,
15181518
TaskContinuationFunctionPtrTy, SizeTy),
15191519
ATTRS(NoUnwind, ArgMemOnly))
15201520

1521+
// AutoDiffLinearMapContext *swift_autoDiffCreateLinearMapContext(size_t);
1522+
FUNCTION(AutoDiffCreateLinearMapContext,
1523+
swift_autoDiffCreateLinearMapContext, SwiftCC,
1524+
DifferentiationAvailability,
1525+
RETURNS(RefCountedPtrTy),
1526+
ARGS(SizeTy),
1527+
ATTRS(NoUnwind, ArgMemOnly))
1528+
1529+
// void *swift_autoDiffProjectTopLevelSubcontext(AutoDiffLinearMapContext *);
1530+
FUNCTION(AutoDiffProjectTopLevelSubcontext,
1531+
swift_autoDiffProjectTopLevelSubcontext, SwiftCC,
1532+
DifferentiationAvailability,
1533+
RETURNS(Int8PtrTy),
1534+
ARGS(RefCountedPtrTy),
1535+
ATTRS(NoUnwind, ArgMemOnly))
1536+
1537+
// void *swift_autoDiffAllocateSubcontext(AutoDiffLinearMapContext *, size_t);
1538+
FUNCTION(AutoDiffAllocateSubcontext,
1539+
swift_autoDiffAllocateSubcontext, SwiftCC,
1540+
DifferentiationAvailability,
1541+
RETURNS(Int8PtrTy),
1542+
ARGS(RefCountedPtrTy, SizeTy),
1543+
ATTRS(NoUnwind, ArgMemOnly))
1544+
15211545
#undef RETURNS
15221546
#undef ARGS
15231547
#undef ATTRS

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ void extractAllElements(SILValue value, SILBuilder &builder,
192192
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
193193
SILValue bufferAccess, SILLocation loc);
194194

195+
/// Emit a `Builtin.Word` value that represents the given type's memory layout
196+
/// size.
197+
SILValue emitMemoryLayoutSize(
198+
SILBuilder &builder, SILLocation loc, CanType type);
199+
200+
/// Emit a projection of the top-level subcontext from the context object.
201+
SILValue emitProjectTopLevelSubcontext(
202+
SILBuilder &builder, SILLocation loc, SILValue context,
203+
SILType subcontextType);
204+
195205
//===----------------------------------------------------------------------===//
196206
// Utilities for looking up derivatives of functions
197207
//===----------------------------------------------------------------------===//

include/swift/SILOptimizer/Differentiation/LinearMapInfo.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class LinearMapInfo {
6363
/// Activity info of the original function.
6464
const DifferentiableActivityInfo &activityInfo;
6565

66+
/// The original function's loop info.
67+
SILLoopInfo *loopInfo;
68+
6669
/// Differentiation indices of the function.
6770
const SILAutoDiffIndices indices;
6871

@@ -86,6 +89,9 @@ class LinearMapInfo {
8689
/// Mapping from linear map structs to their branching trace enum fields.
8790
llvm::DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields;
8891

92+
/// Blocks in a loop.
93+
llvm::SmallSetVector<SILBasicBlock *, 4> blocksInLoop;
94+
8995
/// A synthesized file unit.
9096
SynthesizedFileUnit &synthesizedFile;
9197

@@ -144,7 +150,8 @@ class LinearMapInfo {
144150
explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
145151
SILFunction *original, SILFunction *derivative,
146152
SILAutoDiffIndices indices,
147-
const DifferentiableActivityInfo &activityInfo);
153+
const DifferentiableActivityInfo &activityInfo,
154+
SILLoopInfo *loopInfo);
148155

149156
/// Returns the linear map struct associated with the given original block.
150157
StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
@@ -200,20 +207,28 @@ class LinearMapInfo {
200207

201208
/// Returns the branching trace enum field for the linear map struct of the
202209
/// given original block.
203-
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) {
210+
VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) const {
204211
auto *linearMapStruct = getLinearMapStruct(origBB);
205212
return linearMapStructEnumFields.lookup(linearMapStruct);
206213
}
207214

208215
/// Finds the linear map declaration in the pullback struct for the given
209216
/// `apply` instruction in the original function.
210-
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) {
217+
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) const {
211218
assert(ai->getFunction() == original);
212219
auto lookup = linearMapFieldMap.find(ai);
213220
assert(lookup != linearMapFieldMap.end() &&
214221
"No linear map field corresponding to the given `apply`");
215222
return lookup->getSecond();
216223
}
224+
225+
bool hasLoops() const {
226+
return !blocksInLoop.empty();
227+
}
228+
229+
ArrayRef<SILBasicBlock *> getBlocksInLoop() const {
230+
return blocksInLoop.getArrayRef();
231+
}
217232
};
218233

219234
} // end namespace autodiff

include/swift/SILOptimizer/Differentiation/VJPCloner.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
2222
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2323
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
24+
#include "swift/SIL/LoopInfo.h"
2425

2526
namespace swift {
2627
namespace autodiff {
@@ -52,6 +53,7 @@ class VJPCloner final {
5253
const SILAutoDiffIndices getIndices() const;
5354
DifferentiationInvoker getInvoker() const;
5455
LinearMapInfo &getPullbackInfo() const;
56+
SILLoopInfo *getLoopInfo() const;
5557
const DifferentiableActivityInfo &getActivityInfo() const;
5658

5759
/// Performs VJP generation on the empty VJP function. Returns true if any

lib/AST/ASTMangler.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -785,8 +785,8 @@ static StringRef getPrivateDiscriminatorIfNecessary(const ValueDecl *decl) {
785785

786786
// Mangle non-local private declarations with a textual discriminator
787787
// based on their enclosing file.
788-
auto topLevelContext = decl->getDeclContext()->getModuleScopeContext();
789-
auto fileUnit = cast<FileUnit>(topLevelContext);
788+
auto topLevelSubcontext = decl->getDeclContext()->getModuleScopeContext();
789+
auto fileUnit = cast<FileUnit>(topLevelSubcontext);
790790

791791
Identifier discriminator =
792792
fileUnit->getDiscriminatorForPrivateValue(decl);
@@ -2900,17 +2900,17 @@ void ASTMangler::appendEntity(const ValueDecl *decl) {
29002900
void
29012901
ASTMangler::appendProtocolConformance(const ProtocolConformance *conformance) {
29022902
GenericSignature contextSig;
2903-
auto topLevelContext =
2903+
auto topLevelSubcontext =
29042904
conformance->getDeclContext()->getModuleScopeContext();
2905-
Mod = topLevelContext->getParentModule();
2905+
Mod = topLevelSubcontext->getParentModule();
29062906

29072907
auto conformingType = conformance->getType();
29082908
appendType(conformingType->getCanonicalType());
29092909

29102910
appendProtocolName(conformance->getProtocol());
29112911

29122912
bool needsModule = true;
2913-
if (auto *file = dyn_cast<FileUnit>(topLevelContext)) {
2913+
if (auto *file = dyn_cast<FileUnit>(topLevelSubcontext)) {
29142914
if (file->getKind() == FileUnitKind::ClangModule ||
29152915
file->getKind() == FileUnitKind::DWARFModule) {
29162916
if (conformance->getProtocol()->hasClangNode())

lib/AST/ASTVerifier.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ class Verifier : public ASTWalker {
229229
typedef llvm::PointerIntPair<DeclContext *, 1, bool> ClosureDiscriminatorKey;
230230
llvm::DenseMap<ClosureDiscriminatorKey, SmallBitVector>
231231
ClosureDiscriminators;
232-
DeclContext *CanonicalTopLevelContext = nullptr;
232+
DeclContext *CanonicalTopLevelSubcontext = nullptr;
233233

234234
Verifier(PointerUnion<ModuleDecl *, SourceFile *> M, DeclContext *DC)
235235
: M(M),
@@ -898,9 +898,9 @@ class Verifier : public ASTWalker {
898898
DeclContext *getCanonicalDeclContext(DeclContext *DC) {
899899
// All we really need to do is use a single TopLevelCodeDecl.
900900
if (auto topLevel = dyn_cast<TopLevelCodeDecl>(DC)) {
901-
if (!CanonicalTopLevelContext)
902-
CanonicalTopLevelContext = topLevel;
903-
return CanonicalTopLevelContext;
901+
if (!CanonicalTopLevelSubcontext)
902+
CanonicalTopLevelSubcontext = topLevel;
903+
return CanonicalTopLevelSubcontext;
904904
}
905905

906906
// TODO: check for uniqueness of initializer contexts?

lib/AST/Availability.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ AvailabilityContext ASTContext::getConcurrencyAvailability() {
327327
return getSwiftFutureAvailability();
328328
}
329329

330+
AvailabilityContext ASTContext::getDifferentiationAvailability() {
331+
return getSwiftFutureAvailability();
332+
}
333+
330334
AvailabilityContext ASTContext::getSwift52Availability() {
331335
auto target = LangOpts.Target;
332336

lib/AST/Builtins.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,25 @@ static ValueDecl *getCreateAsyncTaskFuture(ASTContext &ctx, Identifier id) {
13831383
return builder.build(id);
13841384
}
13851385

1386+
static ValueDecl *getAutoDiffCreateLinearMapContext(ASTContext &ctx,
1387+
Identifier id) {
1388+
return getBuiltinFunction(
1389+
id, {BuiltinIntegerType::getWordType(ctx)}, ctx.TheNativeObjectType);
1390+
}
1391+
1392+
static ValueDecl *getAutoDiffProjectTopLevelSubcontext(ASTContext &ctx,
1393+
Identifier id) {
1394+
return getBuiltinFunction(
1395+
id, {ctx.TheNativeObjectType}, ctx.TheRawPointerType);
1396+
}
1397+
1398+
static ValueDecl *getAutoDiffAllocateSubcontext(ASTContext &ctx,
1399+
Identifier id) {
1400+
return getBuiltinFunction(
1401+
id, {ctx.TheNativeObjectType, BuiltinIntegerType::getWordType(ctx)},
1402+
ctx.TheRawPointerType);
1403+
}
1404+
13861405
static ValueDecl *getPoundAssert(ASTContext &Context, Identifier Id) {
13871406
auto int1Type = BuiltinIntegerType::get(1, Context);
13881407
auto optionalRawPointerType = BoundGenericEnumType::get(
@@ -2549,6 +2568,15 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
25492568

25502569
case BuiltinValueKind::TriggerFallbackDiagnostic:
25512570
return getTriggerFallbackDiagnosticOperation(Context, Id);
2571+
2572+
case BuiltinValueKind::AutoDiffCreateLinearMapContext:
2573+
return getAutoDiffCreateLinearMapContext(Context, Id);
2574+
2575+
case BuiltinValueKind::AutoDiffProjectTopLevelSubcontext:
2576+
return getAutoDiffProjectTopLevelSubcontext(Context, Id);
2577+
2578+
case BuiltinValueKind::AutoDiffAllocateSubcontext:
2579+
return getAutoDiffAllocateSubcontext(Context, Id);
25522580
}
25532581

25542582
llvm_unreachable("bad builtin value!");

0 commit comments

Comments
 (0)