Skip to content

Commit 3d6d186

Browse files
authored
Merge pull request #32767 from slavapestov/requestify-main-attr
Request-ify synthesis of the main function for the @main attribute
2 parents a405da4 + 7bc5088 commit 3d6d186

13 files changed

+190
-112
lines changed

include/swift/AST/ASTTypeIDZone.def

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ SWIFT_TYPEID_NAMED(ConstructorDecl *, ConstructorDecl)
4040
SWIFT_TYPEID_NAMED(CustomAttr *, CustomAttr)
4141
SWIFT_TYPEID_NAMED(Decl *, Decl)
4242
SWIFT_TYPEID_NAMED(EnumDecl *, EnumDecl)
43+
SWIFT_TYPEID_NAMED(FuncDecl *, FuncDecl)
4344
SWIFT_TYPEID_NAMED(GenericParamList *, GenericParamList)
4445
SWIFT_TYPEID_NAMED(GenericTypeParamType *, GenericTypeParamType)
4546
SWIFT_TYPEID_NAMED(InfixOperatorDecl *, InfixOperatorDecl)

include/swift/AST/ASTTypeIDs.h

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class ConstructorDecl;
3030
class CustomAttr;
3131
class Decl;
3232
class EnumDecl;
33+
class FuncDecl;
3334
enum class FunctionBuilderBodyPreCheck : uint8_t;
3435
class GenericParamList;
3536
class GenericSignature;

include/swift/AST/TypeCheckRequests.h

+17
Original file line numberDiff line numberDiff line change
@@ -2564,6 +2564,23 @@ class CustomAttrTypeRequest
25642564
void cacheResult(Type value) const;
25652565
};
25662566

2567+
class SynthesizeMainFunctionRequest
2568+
: public SimpleRequest<SynthesizeMainFunctionRequest,
2569+
FuncDecl *(Decl *),
2570+
RequestFlags::Cached> {
2571+
public:
2572+
using SimpleRequest::SimpleRequest;
2573+
2574+
private:
2575+
friend SimpleRequest;
2576+
2577+
// Evaluation.
2578+
FuncDecl *evaluate(Evaluator &evaluator, Decl *) const;
2579+
2580+
public:
2581+
bool isCached() const { return true; }
2582+
};
2583+
25672584
// Allow AnyValue to compare two Type values, even though Type doesn't
25682585
// support ==.
25692586
template<>

include/swift/AST/TypeCheckerTypeIDZone.def

+2
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,5 @@ SWIFT_REQUEST(TypeChecker, LookupAllConformancesInContextRequest,
274274
Uncached, NoLocationInfo)
275275
SWIFT_REQUEST(TypeChecker, SimpleDidSetRequest,
276276
bool(AccessorDecl *), Cached, NoLocationInfo)
277+
SWIFT_REQUEST(TypeChecker, SynthesizeMainFunctionRequest,
278+
FuncDecl *(Decl *), Cached, NoLocationInfo)

lib/Sema/TypeCheckAttr.cpp

+94-88
Original file line numberDiff line numberDiff line change
@@ -1782,7 +1782,76 @@ void AttributeChecker::visitUIApplicationMainAttr(UIApplicationMainAttr *attr) {
17821782
C.getIdentifier("UIApplicationMain"));
17831783
}
17841784

1785-
void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) {
1785+
namespace {
1786+
struct MainTypeAttrParams {
1787+
FuncDecl *mainFunction;
1788+
MainTypeAttr *attr;
1789+
};
1790+
1791+
}
1792+
static std::pair<BraceStmt *, bool>
1793+
synthesizeMainBody(AbstractFunctionDecl *fn, void *arg) {
1794+
ASTContext &context = fn->getASTContext();
1795+
MainTypeAttrParams *params = (MainTypeAttrParams *) arg;
1796+
1797+
FuncDecl *mainFunction = params->mainFunction;
1798+
auto location = params->attr->getLocation();
1799+
NominalTypeDecl *nominal = fn->getDeclContext()->getSelfNominalTypeDecl();
1800+
1801+
auto *typeExpr = TypeExpr::createImplicit(nominal->getDeclaredType(), context);
1802+
1803+
SubstitutionMap substitutionMap;
1804+
if (auto *environment = mainFunction->getGenericEnvironment()) {
1805+
substitutionMap = SubstitutionMap::get(
1806+
environment->getGenericSignature(),
1807+
[&](SubstitutableType *type) { return nominal->getDeclaredType(); },
1808+
LookUpConformanceInModule(nominal->getModuleContext()));
1809+
} else {
1810+
substitutionMap = SubstitutionMap();
1811+
}
1812+
1813+
auto funcDeclRef = ConcreteDeclRef(mainFunction, substitutionMap);
1814+
1815+
auto *memberRefExpr = new (context) MemberRefExpr(
1816+
typeExpr, SourceLoc(), funcDeclRef, DeclNameLoc(location),
1817+
/*Implicit*/ true);
1818+
memberRefExpr->setImplicit(true);
1819+
1820+
auto *callExpr = CallExpr::createImplicit(context, memberRefExpr, {}, {});
1821+
callExpr->setImplicit(true);
1822+
callExpr->setThrows(mainFunction->hasThrows());
1823+
callExpr->setType(context.TheEmptyTupleType);
1824+
1825+
Expr *returnedExpr;
1826+
1827+
if (mainFunction->hasThrows()) {
1828+
auto *tryExpr = new (context) TryExpr(
1829+
SourceLoc(), callExpr, context.TheEmptyTupleType, /*implicit=*/true);
1830+
returnedExpr = tryExpr;
1831+
} else {
1832+
returnedExpr = callExpr;
1833+
}
1834+
1835+
auto *returnStmt =
1836+
new (context) ReturnStmt(SourceLoc(), callExpr, /*Implicit=*/true);
1837+
1838+
SmallVector<ASTNode, 1> stmts;
1839+
stmts.push_back(returnStmt);
1840+
auto *body = BraceStmt::create(context, SourceLoc(), stmts,
1841+
SourceLoc(), /*Implicit*/true);
1842+
1843+
return std::make_pair(body, /*typechecked=*/false);
1844+
}
1845+
1846+
FuncDecl *
1847+
SynthesizeMainFunctionRequest::evaluate(Evaluator &evaluator,
1848+
Decl *D) const {
1849+
auto &context = D->getASTContext();
1850+
1851+
MainTypeAttr *attr = D->getAttrs().getAttribute<MainTypeAttr>();
1852+
if (attr == nullptr)
1853+
return nullptr;
1854+
17861855
auto *extension = dyn_cast<ExtensionDecl>(D);
17871856

17881857
IterableDeclContext *iterableDeclContext;
@@ -1802,25 +1871,19 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) {
18021871
braces = nominal->getBraces();
18031872
}
18041873

1805-
if (!nominal) {
1806-
assert(false && "Should have already recognized that the MainType decl "
1874+
assert(nominal && "Should have already recognized that the MainType decl "
18071875
"isn't applicable to decls other than NominalTypeDecls");
1808-
return;
1809-
}
18101876
assert(iterableDeclContext);
18111877
assert(declContext);
18121878

18131879
// The type cannot be generic.
18141880
if (nominal->isGenericContext()) {
1815-
diagnose(attr->getLocation(),
1816-
diag::attr_generic_ApplicationMain_not_supported, 2);
1881+
context.Diags.diagnose(attr->getLocation(),
1882+
diag::attr_generic_ApplicationMain_not_supported, 2);
18171883
attr->setInvalid();
1818-
return;
1884+
return nullptr;
18191885
}
18201886

1821-
SourceFile *file = cast<SourceFile>(declContext->getModuleScopeContext());
1822-
assert(file);
1823-
18241887
// Create a function
18251888
//
18261889
// func $main() {
@@ -1832,8 +1895,6 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) {
18321895
// usual type-checking. The alternative would be to directly call
18331896
// mainType.main() from the entry point, and that would require fully
18341897
// type-checking the call to mainType.main().
1835-
auto &context = D->getASTContext();
1836-
auto location = attr->getLocation();
18371898

18381899
auto resolution = resolveValueMember(
18391900
*declContext, nominal->getInterfaceType(), context.Id_main);
@@ -1861,107 +1922,52 @@ void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) {
18611922
}
18621923

18631924
if (viableCandidates.size() != 1) {
1864-
diagnose(attr->getLocation(), diag::attr_MainType_without_main,
1865-
nominal->getBaseName());
1925+
context.Diags.diagnose(attr->getLocation(),
1926+
diag::attr_MainType_without_main,
1927+
nominal->getBaseName());
18661928
attr->setInvalid();
1867-
return;
1929+
return nullptr;
18681930
}
18691931
mainFunction = viableCandidates[0];
18701932
}
18711933

1872-
bool mainFunctionThrows = mainFunction->hasThrows();
1873-
1874-
auto voidToVoidFunctionType =
1875-
FunctionType::get({}, context.TheEmptyTupleType,
1876-
FunctionType::ExtInfo().withThrows(mainFunctionThrows));
1877-
auto nominalToVoidToVoidFunctionType = FunctionType::get({AnyFunctionType::Param(nominal->getInterfaceType())}, voidToVoidFunctionType);
18781934
auto *func = FuncDecl::create(
18791935
context, /*StaticLoc*/ SourceLoc(), StaticSpellingKind::KeywordStatic,
18801936
/*FuncLoc*/ SourceLoc(),
18811937
DeclName(context, DeclBaseName(context.Id_MainEntryPoint),
18821938
ParameterList::createEmpty(context)),
1883-
/*NameLoc*/ SourceLoc(), /*Throws=*/mainFunctionThrows,
1939+
/*NameLoc*/ SourceLoc(), /*Throws=*/mainFunction->hasThrows(),
18841940
/*ThrowsLoc=*/SourceLoc(),
18851941
/*GenericParams=*/nullptr, ParameterList::createEmpty(context),
18861942
/*FnRetType=*/TypeLoc::withoutLoc(TupleType::getEmpty(context)),
18871943
declContext);
18881944
func->setImplicit(true);
18891945
func->setSynthesized(true);
18901946

1891-
auto *typeExpr = TypeExpr::createImplicit(nominal->getDeclaredType(), context);
1892-
1893-
SubstitutionMap substitutionMap;
1894-
if (auto *environment = mainFunction->getGenericEnvironment()) {
1895-
substitutionMap = SubstitutionMap::get(
1896-
environment->getGenericSignature(),
1897-
[&](SubstitutableType *type) { return nominal->getDeclaredType(); },
1898-
LookUpConformanceInModule(nominal->getModuleContext()));
1899-
} else {
1900-
substitutionMap = SubstitutionMap();
1901-
}
1902-
1903-
auto funcDeclRef = ConcreteDeclRef(mainFunction, substitutionMap);
1904-
1905-
auto *memberRefExpr = new (context) MemberRefExpr(
1906-
typeExpr, SourceLoc(), funcDeclRef, DeclNameLoc(location),
1907-
/*Implicit*/ true);
1908-
memberRefExpr->setImplicit(true);
1909-
1910-
auto *callExpr = CallExpr::createImplicit(context, memberRefExpr, {}, {});
1911-
callExpr->setImplicit(true);
1912-
callExpr->setThrows(mainFunctionThrows);
1913-
callExpr->setType(context.TheEmptyTupleType);
1914-
1915-
Expr *returnedExpr;
1947+
auto *params = context.Allocate<MainTypeAttrParams>();
1948+
params->mainFunction = mainFunction;
1949+
params->attr = attr;
1950+
func->setBodySynthesizer(synthesizeMainBody, params);
19161951

1917-
if (mainFunctionThrows) {
1918-
auto *tryExpr = new (context) TryExpr(
1919-
SourceLoc(), callExpr, context.TheEmptyTupleType, /*implicit=*/true);
1920-
returnedExpr = tryExpr;
1921-
} else {
1922-
returnedExpr = callExpr;
1923-
}
1952+
iterableDeclContext->addMember(func);
19241953

1925-
auto *returnStmt =
1926-
new (context) ReturnStmt(SourceLoc(), callExpr, /*Implicit=*/true);
1954+
return func;
1955+
}
19271956

1928-
SmallVector<ASTNode, 1> stmts;
1929-
stmts.push_back(returnStmt);
1930-
auto *body = BraceStmt::create(context, SourceLoc(), stmts,
1931-
SourceLoc(), /*Implicit*/true);
1932-
func->setBodyParsed(body);
1933-
func->setInterfaceType(nominalToVoidToVoidFunctionType);
1957+
void AttributeChecker::visitMainTypeAttr(MainTypeAttr *attr) {
1958+
auto &context = D->getASTContext();
19341959

1935-
iterableDeclContext->addMember(func);
1960+
SourceFile *file = D->getDeclContext()->getParentSourceFile();
1961+
assert(file);
19361962

1937-
// This function must be type-checked. Why? Consider the following scenario:
1938-
//
1939-
// protocol AlmostMainable {}
1940-
// protocol ReallyMainable {}
1941-
// extension AlmostMainable where Self : ReallyMainable {
1942-
// static func main() {}
1943-
// }
1944-
// @main struct Main : AlmostMainable {}
1945-
//
1946-
// Note in particular that Main does not conform to ReallyMainable.
1947-
//
1948-
// In this case, resolveValueMember will find the function main in the
1949-
// extension, and so, since there is one candidate, the function $main will
1950-
// accordingly be formed as usual:
1951-
//
1952-
// func $main() {
1953-
// return Main.main()
1954-
// }
1955-
//
1956-
// Of course, this function's body does not type-check.
1957-
file->DelayedFunctions.push_back(func);
1963+
auto *func = evaluateOrDefault(context.evaluator,
1964+
SynthesizeMainFunctionRequest{D},
1965+
nullptr);
19581966

19591967
// Register the func as the main decl in the module. If there are multiples
19601968
// they will be diagnosed.
1961-
if (file->registerMainDecl(func, attr->getLocation())) {
1969+
if (file->registerMainDecl(func, attr->getLocation()))
19621970
attr->setInvalid();
1963-
return;
1964-
}
19651971
}
19661972

19671973
/// Determine whether the given context is an extension to an Objective-C class

lib/Sema/TypeCheckDecl.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -2523,6 +2523,12 @@ EmittedMembersRequest::evaluate(Evaluator &evaluator,
25232523
forceConformance(Context.getProtocol(KnownProtocolKind::Hashable));
25242524
forceConformance(Context.getProtocol(KnownProtocolKind::Differentiable));
25252525

2526+
// If the class has a @main attribute, we need to force synthesis of the
2527+
// $main function.
2528+
(void) evaluateOrDefault(Context.evaluator,
2529+
SynthesizeMainFunctionRequest{CD},
2530+
nullptr);
2531+
25262532
for (auto *member : CD->getMembers()) {
25272533
if (auto *var = dyn_cast<VarDecl>(member)) {
25282534
// The projected storage wrapper ($foo) might have dynamically-dispatched

lib/Sema/TypeCheckDeclPrimary.cpp

+11-10
Original file line numberDiff line numberDiff line change
@@ -1797,13 +1797,13 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
17971797
checkGenericParams(ED);
17981798

17991799
// Check for circular inheritance of the raw type.
1800-
(void)ED->hasCircularRawValue();
1800+
(void) ED->hasCircularRawValue();
1801+
1802+
TypeChecker::checkDeclAttributes(ED);
18011803

18021804
for (Decl *member : ED->getMembers())
18031805
visit(member);
18041806

1805-
TypeChecker::checkDeclAttributes(ED);
1806-
18071807
checkInheritanceClause(ED);
18081808

18091809
checkAccessControl(ED);
@@ -1845,13 +1845,13 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
18451845

18461846
installCodingKeysIfNecessary(SD);
18471847

1848+
TypeChecker::checkDeclAttributes(SD);
1849+
18481850
for (Decl *Member : SD->getMembers())
18491851
visit(Member);
18501852

18511853
TypeChecker::checkPatternBindingCaptures(SD);
18521854

1853-
TypeChecker::checkDeclAttributes(SD);
1854-
18551855
checkInheritanceClause(SD);
18561856

18571857
checkAccessControl(SD);
@@ -1974,6 +1974,8 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
19741974
// Force creation of an implicit destructor, if any.
19751975
(void) CD->getDestructor();
19761976

1977+
TypeChecker::checkDeclAttributes(CD);
1978+
19771979
for (Decl *Member : CD->getEmittedMembers())
19781980
visit(Member);
19791981

@@ -2084,8 +2086,6 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
20842086
}
20852087
}
20862088

2087-
TypeChecker::checkDeclAttributes(CD);
2088-
20892089
checkInheritanceClause(CD);
20902090

20912091
checkAccessControl(CD);
@@ -2105,12 +2105,12 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
21052105
// Check for circular inheritance within the protocol.
21062106
(void)PD->hasCircularInheritedProtocols();
21072107

2108+
TypeChecker::checkDeclAttributes(PD);
2109+
21082110
// Check the members.
21092111
for (auto Member : PD->getMembers())
21102112
visit(Member);
21112113

2112-
TypeChecker::checkDeclAttributes(PD);
2113-
21142114
checkAccessControl(PD);
21152115

21162116
checkInheritanceClause(PD);
@@ -2428,14 +2428,15 @@ class DeclChecker : public DeclVisitor<DeclChecker> {
24282428

24292429
checkGenericParams(ED);
24302430

2431+
TypeChecker::checkDeclAttributes(ED);
2432+
24312433
for (Decl *Member : ED->getMembers())
24322434
visit(Member);
24332435

24342436
TypeChecker::checkPatternBindingCaptures(ED);
24352437

24362438
TypeChecker::checkConformancesInContext(ED);
24372439

2438-
TypeChecker::checkDeclAttributes(ED);
24392440
checkAccessControl(ED);
24402441

24412442
checkExplicitAvailability(ED);

0 commit comments

Comments
 (0)