Skip to content

[Macros] Implement support for function body macros on closures. #79980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 21, 2025
10 changes: 9 additions & 1 deletion include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,18 @@ class ASTMangler : public Mangler {
std::string mangleMacroExpansion(const FreestandingMacroExpansion *expansion);
std::string mangleAttachedMacroExpansion(
const Decl *decl, CustomAttr *attr, MacroRole role);
std::string mangleAttachedMacroExpansion(
ClosureExpr *attachedTo, CustomAttr *attr, MacroDecl *macro);

void appendMacroExpansion(const FreestandingMacroExpansion *expansion);
void appendMacroExpansionContext(SourceLoc loc, DeclContext *origDC,
const FreestandingMacroExpansion *expansion);
Identifier macroName,
unsigned discriminator);

void appendMacroExpansion(ClosureExpr *attachedTo,
CustomAttr *attr,
MacroDecl *macro);

void appendMacroExpansionOperator(
StringRef macroName, MacroRole role, unsigned discriminator);

Expand Down
48 changes: 48 additions & 0 deletions include/swift/AST/AnyFunctionRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,54 @@ class AnyFunctionRef {
llvm_unreachable("unexpected AnyFunctionRef representation");
}

DeclAttributes getDeclAttributes() const {
if (auto afd = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
return afd->getExpandedAttrs();
}

if (auto ace = TheFunction.dyn_cast<AbstractClosureExpr *>()) {
if (auto *ce = dyn_cast<ClosureExpr>(ace)) {
return ce->getAttrs();
}
}

return DeclAttributes();
}

MacroDecl *getResolvedMacro(CustomAttr *attr) const {
if (auto afd = TheFunction.dyn_cast<AbstractFunctionDecl *>()) {
return afd->getResolvedMacro(attr);
}

if (auto ace = TheFunction.dyn_cast<AbstractClosureExpr *>()) {
if (auto *ce = dyn_cast<ClosureExpr>(ace)) {
return ce->getResolvedMacro(attr);
}
}

return nullptr;
}

using MacroCallback = llvm::function_ref<void(CustomAttr *, MacroDecl *)>;

void
forEachAttachedMacro(MacroRole role,
MacroCallback macroCallback) const {
auto attrs = getDeclAttributes();
for (auto customAttrConst : attrs.getAttributes<CustomAttr>()) {
auto customAttr = const_cast<CustomAttr *>(customAttrConst);
auto *macroDecl = getResolvedMacro(customAttr);

if (!macroDecl)
continue;

if (!macroDecl->getMacroRoles().contains(role))
continue;

macroCallback(customAttr, macroDecl);
}
}

friend bool operator==(AnyFunctionRef lhs, AnyFunctionRef rhs) {
return lhs.TheFunction == rhs.TheFunction;
}
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -7833,6 +7833,9 @@ ERROR(conformance_macro,none,
ERROR(experimental_macro,none,
"macro %0 is experimental",
(DeclName))
ERROR(experimental_closure_body_macro,none,
"function body macros on closures is experimental",
())

ERROR(macro_resolve_circular_reference, none,
"circular reference resolving %select{freestanding|attached}0 macro %1",
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -4305,6 +4305,8 @@ class ClosureExpr : public AbstractClosureExpr {
BraceStmt *getBody() const { return Body; }
void setBody(BraceStmt *S) { Body = S; }

BraceStmt *getExpandedBody();

DeclAttributes &getAttrs() { return Attributes; }
const DeclAttributes &getAttrs() const { return Attributes; }

Expand Down Expand Up @@ -4422,6 +4424,10 @@ class ClosureExpr : public AbstractClosureExpr {
return ExplicitResultTypeAndBodyState.getPointer()->getTypeRepr();
}

/// Returns the resolved macro for the given custom attribute
/// attached to this closure expression.
MacroDecl *getResolvedMacro(CustomAttr *customAttr);

/// Determine whether the closure has a single expression for its
/// body.
///
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -4708,7 +4708,7 @@ class ExpandPreambleMacroRequest

class ExpandBodyMacroRequest
: public SimpleRequest<ExpandBodyMacroRequest,
std::optional<unsigned>(AbstractFunctionDecl *),
std::optional<unsigned>(AnyFunctionRef),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;
Expand All @@ -4717,7 +4717,7 @@ class ExpandBodyMacroRequest
friend SimpleRequest;

std::optional<unsigned> evaluate(Evaluator &evaluator,
AbstractFunctionDecl *fn) const;
AnyFunctionRef fn) const;

public:
bool isCached() const { return true; }
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ SWIFT_REQUEST(TypeChecker, ExpandPreambleMacroRequest,
ArrayRef<unsigned>(AbstractFunctionDecl *),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ExpandBodyMacroRequest,
std::optional<unsigned>(AbstractFunctionDecl *),
std::optional<unsigned>(AnyFunctionRef),
Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, LocalDiscriminatorsRequest,
unsigned(DeclContext *),
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Basic/Features.def
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,9 @@ EXPERIMENTAL_FEATURE(ConcurrencySyntaxSugar, true)
/// Allow declaration of compile-time values
EXPERIMENTAL_FEATURE(CompileTimeValues, true)

/// Allow function body macros applied to closures.
EXPERIMENTAL_FEATURE(ClosureBodyMacro, true)

#undef EXPERIMENTAL_FEATURE_EXCLUDED_FROM_MODULE_INTERFACE
#undef EXPERIMENTAL_FEATURE
#undef UPCOMING_FEATURE
Expand Down
66 changes: 56 additions & 10 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4759,7 +4759,8 @@ static Identifier encodeLocalPrecheckedDiscriminator(

void ASTMangler::appendMacroExpansionContext(
SourceLoc loc, DeclContext *origDC,
const FreestandingMacroExpansion *expansion
Identifier macroName,
unsigned macroDiscriminator
) {
origDC = MacroDiscriminatorContext::getInnermostMacroContext(origDC);
BaseEntitySignature nullBase(nullptr);
Expand All @@ -4768,9 +4769,9 @@ void ASTMangler::appendMacroExpansionContext(
if (auto outermostLocalDC = getOutermostLocalContext(origDC)) {
auto innermostNonlocalDC = outermostLocalDC->getParent();
appendContext(innermostNonlocalDC, nullBase, StringRef());
Identifier name = expansion->getMacroName().getBaseIdentifier();
Identifier name = macroName;
ASTContext &ctx = origDC->getASTContext();
unsigned discriminator = expansion->getDiscriminator();
unsigned discriminator = macroDiscriminator;
name = encodeLocalPrecheckedDiscriminator(ctx, name, discriminator);
appendIdentifier(name.str());
} else {
Expand Down Expand Up @@ -4875,7 +4876,10 @@ void ASTMangler::appendMacroExpansionContext(
return appendMacroExpansionLoc();

// Append our own context and discriminator.
appendMacroExpansionContext(outerExpansionLoc, origDC, expansion);
appendMacroExpansionContext(
outerExpansionLoc, origDC,
macroName,
macroDiscriminator);
appendMacroExpansionOperator(
baseName.userFacingName(), role, discriminator);
}
Expand All @@ -4902,16 +4906,14 @@ void ASTMangler::appendMacroExpansionOperator(
}

static StringRef getPrivateDiscriminatorIfNecessary(
const MacroExpansionExpr *expansion) {
auto dc = MacroDiscriminatorContext::getInnermostMacroContext(
expansion->getDeclContext());
auto decl = dc->getAsDecl();
const DeclContext *macroDC) {
auto decl = macroDC->getAsDecl();
if (decl && !decl->isOutermostPrivateOrFilePrivateScope())
return StringRef();

// Mangle non-local private declarations with a textual discriminator
// based on their enclosing file.
auto topLevelSubcontext = dc->getModuleScopeContext();
auto topLevelSubcontext = macroDC->getModuleScopeContext();
SourceFile *sf = dyn_cast<SourceFile>(topLevelSubcontext);
if (!sf)
return StringRef();
Expand All @@ -4927,6 +4929,13 @@ static StringRef getPrivateDiscriminatorIfNecessary(
return discriminator.str();
}

static StringRef getPrivateDiscriminatorIfNecessary(
const MacroExpansionExpr *expansion) {
auto dc = MacroDiscriminatorContext::getInnermostMacroContext(
expansion->getDeclContext());
return getPrivateDiscriminatorIfNecessary(dc);
}

static StringRef getPrivateDiscriminatorIfNecessary(
const FreestandingMacroExpansion *expansion) {
switch (expansion->getFreestandingMacroKind()) {
Expand All @@ -4943,7 +4952,8 @@ void
ASTMangler::appendMacroExpansion(const FreestandingMacroExpansion *expansion) {
appendMacroExpansionContext(expansion->getPoundLoc(),
expansion->getDeclContext(),
expansion);
expansion->getMacroName().getBaseIdentifier(),
expansion->getDiscriminator());
auto privateDiscriminator = getPrivateDiscriminatorIfNecessary(expansion);
if (!privateDiscriminator.empty()) {
appendIdentifier(privateDiscriminator);
Expand All @@ -4955,6 +4965,42 @@ ASTMangler::appendMacroExpansion(const FreestandingMacroExpansion *expansion) {
expansion->getDiscriminator());
}

void ASTMangler::appendMacroExpansion(ClosureExpr *attachedTo,
CustomAttr *attr,
MacroDecl *macro) {
auto &ctx = attachedTo->getASTContext();
auto discriminator =
ctx.getNextMacroDiscriminator(attachedTo,
macro->getBaseName());

appendMacroExpansionContext(
attr->getLocation(),
attachedTo,
macro->getBaseName().getIdentifier(),
discriminator);

auto privateDiscriminator =
getPrivateDiscriminatorIfNecessary(attachedTo);
if (!privateDiscriminator.empty()) {
appendIdentifier(privateDiscriminator);
appendOperator("Ll");
}

appendMacroExpansionOperator(
macro->getBaseName().userFacingName(),
MacroRole::Body,
discriminator);
}

std::string
ASTMangler::mangleAttachedMacroExpansion(ClosureExpr *attachedTo,
CustomAttr *attr,
MacroDecl *macro) {
beginMangling();
appendMacroExpansion(attachedTo, attr, macro);
return finalize();
}

std::string
ASTMangler::mangleMacroExpansion(const FreestandingMacroExpansion *expansion) {
beginMangling();
Expand Down
21 changes: 15 additions & 6 deletions lib/AST/ASTScopeCreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,22 @@ ASTSourceFileScope::ASTSourceFileScope(SourceFile *SF,
break;
}
case MacroRole::Body: {
// Use the end location of the function decl itself as the parentLoc
// for the new function body scope. This is different from the end
// location of the original source range, which is after the end of the
// function decl.
auto expansion = SF->getMacroExpansion();
parentLoc = expansion.getEndLoc();
bodyForDecl = cast<AbstractFunctionDecl>(expansion.get<Decl *>());
if (expansion.is<Decl *>()) {
// Use the end location of the function decl itself as the parentLoc
// for the new function body scope. This is different from the end
// location of the original source range, which is after the end of the
// function decl.
bodyForDecl = cast<AbstractFunctionDecl>(expansion.get<Decl *>());
parentLoc = expansion.getEndLoc();
break;
}

// Otherwise, we have a closure body macro.
auto insertionRange = SF->getMacroInsertionRange();
parentLoc = insertionRange.End;
if (insertionRange.Start != insertionRange.End)
parentLoc = parentLoc.getAdvancedLoc(-1);
break;
}
}
Expand Down
35 changes: 35 additions & 0 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "swift/Basic/Assertions.h"
#include "swift/Basic/Statistic.h"
#include "swift/Basic/Unicode.h"
#include "swift/Basic/SourceManager.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/ASTVisitor.h"
#include "swift/AST/Decl.h" // FIXME: Bad dependency
Expand Down Expand Up @@ -2016,6 +2017,29 @@ BraceStmt * AbstractClosureExpr::getBody() const {
llvm_unreachable("Unknown closure expression");
}

BraceStmt *ClosureExpr::getExpandedBody() {
auto &ctx = getASTContext();

// Expand a body macro, if there is one.
BraceStmt *macroExpandedBody = nullptr;
if (auto bufferID = evaluateOrDefault(
ctx.evaluator,
ExpandBodyMacroRequest{this},
std::nullopt)) {
CharSourceRange bufferRange = ctx.SourceMgr.getRangeForBuffer(*bufferID);
auto bufferStart = bufferRange.getStart();
auto module = getParentModule();
auto macroSourceFile = module->getSourceFileContainingLocation(bufferStart);

if (macroSourceFile->getTopLevelItems().size() == 1) {
auto stmt = macroSourceFile->getTopLevelItems()[0].dyn_cast<Stmt *>();
macroExpandedBody = dyn_cast<BraceStmt>(stmt);
}
}

return macroExpandedBody;
}

bool AbstractClosureExpr::bodyHasExplicitReturnStmt() const {
return AnyFunctionRef(const_cast<AbstractClosureExpr *>(this))
.bodyHasExplicitReturnStmt();
Expand Down Expand Up @@ -2175,6 +2199,17 @@ void ClosureExpr::setExplicitResultType(Type ty) {
->setType(MetatypeType::get(ty));
}

MacroDecl *
ClosureExpr::getResolvedMacro(CustomAttr *customAttr) {
auto &ctx = getASTContext();
auto declRef = evaluateOrDefault(
ctx.evaluator,
ResolveMacroRequest{customAttr, this},
ConcreteDeclRef());

return dyn_cast_or_null<MacroDecl>(declRef.getDecl());
}

FORWARD_SOURCE_LOCS_TO(AutoClosureExpr, Body)

void AutoClosureExpr::setBody(Expr *E) {
Expand Down
4 changes: 4 additions & 0 deletions lib/AST/FeatureSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ static bool usesFeatureCompileTimeValues(Decl *decl) {
return decl->getAttrs().hasAttribute<ConstValAttr>();
}

static bool usesFeatureClosureBodyMacro(Decl *decl) {
return false;
}

static bool usesFeatureMemorySafetyAttributes(Decl *decl) {
if (decl->getAttrs().hasAttribute<SafeAttr>() ||
decl->getAttrs().hasAttribute<UnsafeAttr>())
Expand Down
18 changes: 8 additions & 10 deletions lib/AST/TypeCheckRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2681,19 +2681,17 @@ void ExpandPreambleMacroRequest::noteCycleStep(DiagnosticEngine &diags) const {
//----------------------------------------------------------------------------//

void ExpandBodyMacroRequest::diagnoseCycle(DiagnosticEngine &diags) const {
auto decl = std::get<0>(getStorage());
diags.diagnose(decl->getLoc(),
diag::macro_expand_circular_reference_entity,
"body",
decl->getName());
auto fn = std::get<0>(getStorage());
diags.diagnose(fn.getLoc(),
diag::macro_expand_circular_reference_unnamed,
"body");
}

void ExpandBodyMacroRequest::noteCycleStep(DiagnosticEngine &diags) const {
auto decl = std::get<0>(getStorage());
diags.diagnose(decl->getLoc(),
diag::macro_expand_circular_reference_entity_through,
"body",
decl->getName());
auto fn = std::get<0>(getStorage());
diags.diagnose(fn.getLoc(),
diag::macro_expand_circular_reference_unnamed_through,
"body");
}

//----------------------------------------------------------------------------//
Expand Down
Loading