Skip to content

Commit 56ad5ce

Browse files
committed
Generate try/await thunks programmatically instead of with string parsing
1 parent 888e063 commit 56ad5ce

File tree

6 files changed

+155
-68
lines changed

6 files changed

+155
-68
lines changed

Sources/TestingMacros/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ target_sources(TestingMacros PRIVATE
101101
Support/CRC32.swift
102102
Support/DiagnosticMessage.swift
103103
Support/DiagnosticMessage+Diagnosing.swift
104+
Support/EffectfulExpressionHandling.swift
104105
Support/SourceLocationGeneration.swift
105106
TagMacro.swift
106107
TestDeclarationMacro.swift

Sources/TestingMacros/ConditionMacro.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ extension ExitTestConditionMacro {
456456
decls.append(
457457
"""
458458
@Sendable func \(bodyThunkName)() async throws -> Void {
459-
return try await Testing.__requiringTry(Testing.__requiringAwait(\(bodyArgumentExpr.trimmed)))()
459+
return \(applyEffectfulKeywords([.try, .await], to: bodyArgumentExpr))()
460460
}
461461
"""
462462
)

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

+5-59
Original file line numberDiff line numberDiff line change
@@ -573,15 +573,12 @@ extension ConditionMacro {
573573
var expandedExpr = contextInserter.rewrite(node).cast(ExprSyntax.self)
574574
let rewrittenNodes = contextInserter.rewrittenNodes
575575

576-
// Insert additional effect keywords as needed. Use the helper functions so
577-
// we don't need to worry about the precise structure of the expression
578-
// being rewritten.
579-
if effectKeywordsToApply.contains(.await) {
580-
expandedExpr = "await Testing.__requiringAwait(\(expandedExpr))"
581-
}
582-
if isThrowing || effectKeywordsToApply.contains(.try) {
583-
expandedExpr = "try Testing.__requiringTry(\(expandedExpr))"
576+
// Insert additional effect keywords/thunks as needed.
577+
var effectKeywordsToApply = effectKeywordsToApply
578+
if isThrowing {
579+
effectKeywordsToApply.insert(.try)
584580
}
581+
expandedExpr = applyEffectfulKeywords(effectKeywordsToApply, to: expandedExpr)
585582

586583
// Construct the body of the closure that we'll pass to the expanded
587584
// function.
@@ -675,57 +672,6 @@ private final class _OptionalChainFinder: SyntaxVisitor {
675672
}
676673
}
677674

678-
// MARK: - Finding effect keywords
679-
680-
/// A syntax visitor class that looks for effectful keywords in a given
681-
/// expression.
682-
private final class _EffectFinder: SyntaxAnyVisitor {
683-
/// The effect keywords discovered so far.
684-
var effectKeywords: Set<Keyword> = []
685-
686-
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
687-
switch node.kind {
688-
case .tryExpr:
689-
effectKeywords.insert(.try)
690-
case .awaitExpr:
691-
effectKeywords.insert(.await)
692-
case .consumeExpr:
693-
effectKeywords.insert(.consume)
694-
case .closureExpr, .functionDecl:
695-
// Do not delve into closures or function declarations.
696-
return .skipChildren
697-
case .variableDecl:
698-
// Delve into variable declarations.
699-
return .visitChildren
700-
default:
701-
// Do not delve into declarations other than variables.
702-
if node.isProtocol((any DeclSyntaxProtocol).self) {
703-
return .skipChildren
704-
}
705-
}
706-
707-
// Recurse into everything else.
708-
return .visitChildren
709-
}
710-
}
711-
712-
/// Find effectful keywords in a syntax node.
713-
///
714-
/// - Parameters:
715-
/// - node: The node to inspect.
716-
///
717-
/// - Returns: A set of effectful keywords such as `await` that are present in
718-
/// `node`.
719-
///
720-
/// This function does not descend into function declarations or closure
721-
/// expressions because they represent distinct lexical contexts and their
722-
/// effects are uninteresting in the context of `node` unless they are called.
723-
func findEffectKeywords(in node: some SyntaxProtocol) -> Set<Keyword> {
724-
let effectFinder = _EffectFinder(viewMode: .sourceAccurate)
725-
effectFinder.walk(node)
726-
return effectFinder.effectKeywords
727-
}
728-
729675
// MARK: - Replacing dollar identifiers
730676

731677
/// Rewrite a dollar identifier as a normal (non-dollar) identifier.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
//
2+
// This source file is part of the Swift.org open source project
3+
//
4+
// Copyright (c) 2024 Apple Inc. and the Swift project authors
5+
// Licensed under Apache License v2.0 with Runtime Library Exception
6+
//
7+
// See https://swift.org/LICENSE.txt for license information
8+
// See https://swift.org/CONTRIBUTORS.txt for Swift project authors
9+
//
10+
11+
import SwiftSyntax
12+
13+
// MARK: - Finding effect keywords
14+
15+
/// A syntax visitor class that looks for effectful keywords in a given
16+
/// expression.
17+
private final class _EffectFinder: SyntaxAnyVisitor {
18+
/// The effect keywords discovered so far.
19+
var effectKeywords: Set<Keyword> = []
20+
21+
override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
22+
switch node.kind {
23+
case .tryExpr:
24+
effectKeywords.insert(.try)
25+
case .awaitExpr:
26+
effectKeywords.insert(.await)
27+
case .consumeExpr:
28+
effectKeywords.insert(.consume)
29+
case .closureExpr, .functionDecl:
30+
// Do not delve into closures or function declarations.
31+
return .skipChildren
32+
case .variableDecl:
33+
// Delve into variable declarations.
34+
return .visitChildren
35+
default:
36+
// Do not delve into declarations other than variables.
37+
if node.isProtocol((any DeclSyntaxProtocol).self) {
38+
return .skipChildren
39+
}
40+
}
41+
42+
// Recurse into everything else.
43+
return .visitChildren
44+
}
45+
}
46+
47+
/// Find effectful keywords in a syntax node.
48+
///
49+
/// - Parameters:
50+
/// - node: The node to inspect.
51+
///
52+
/// - Returns: A set of effectful keywords such as `await` that are present in
53+
/// `node`.
54+
///
55+
/// This function does not descend into function declarations or closure
56+
/// expressions because they represent distinct lexical contexts and their
57+
/// effects are uninteresting in the context of `node` unless they are called.
58+
func findEffectKeywords(in node: some SyntaxProtocol) -> Set<Keyword> {
59+
let effectFinder = _EffectFinder(viewMode: .sourceAccurate)
60+
effectFinder.walk(node)
61+
return effectFinder.effectKeywords
62+
}
63+
64+
// MARK: - Inserting effect keywords/thunks
65+
66+
/// Make a function call expression to an effectful thunk function provided by
67+
/// the testing library.
68+
///
69+
/// - Parameters:
70+
/// - thunkName: The unqualified name of the thunk function to call. This
71+
/// token must be the name of a function in the `Testing` module.
72+
/// - expr: The expression to thunk.
73+
///
74+
/// - Returns: An expression representing a call to the function named
75+
/// `thunkName`, passing `expr`.
76+
private func _makeCallToEffectfulThunk(_ thunkName: TokenSyntax, passing expr: some ExprSyntaxProtocol) -> ExprSyntax {
77+
var result = FunctionCallExprSyntax(
78+
calledExpression: MemberAccessExprSyntax(
79+
base: DeclReferenceExprSyntax(baseName: .identifier("Testing")),
80+
declName: DeclReferenceExprSyntax(baseName: thunkName)
81+
)
82+
) {
83+
LabeledExprSyntax(expression: expr.trimmed)
84+
}
85+
86+
result.leftParen = .leftParenToken()
87+
result.rightParen = .rightParenToken()
88+
89+
return ExprSyntax(result)
90+
}
91+
92+
/// Apply the given effectful keywords (i.e. `try` and `await`) to an expression
93+
/// using thunk functions provided by the testing library.
94+
///
95+
/// - Parameters:
96+
/// - effectfulKeywords: The effectful keywords to apply.
97+
/// - expr: The expression to apply the keywords and thunk functions to.
98+
///
99+
/// - Returns: A copy of `expr` if no changes are needed, or an expression that
100+
/// adds the keywords in `effectfulKeywords` to `expr`.
101+
func applyEffectfulKeywords(_ effectfulKeywords: Set<Keyword>, to expr: some ExprSyntaxProtocol) -> ExprSyntax {
102+
let originalExpr = expr
103+
var expr = ExprSyntax(expr)
104+
105+
let needAwait = effectfulKeywords.contains(.await) && !expr.is(AwaitExprSyntax.self)
106+
let needTry = effectfulKeywords.contains(.try) && !expr.is(TryExprSyntax.self)
107+
108+
// First, add thunk function calls.
109+
if needAwait {
110+
expr = _makeCallToEffectfulThunk(.identifier("__requiringAwait"), passing: expr)
111+
}
112+
if needTry {
113+
expr = _makeCallToEffectfulThunk(.identifier("__requiringTry"), passing: expr)
114+
}
115+
116+
// Then add keyword expressions. (We do this separately so we end up writing
117+
// `try await __r(__r(self))` instead of `try __r(await __r(self))` which
118+
// is less accepted by the compiler.
119+
if needAwait {
120+
expr = ExprSyntax(
121+
AwaitExprSyntax(
122+
awaitKeyword: .keyword(.await).with(\.trailingTrivia, .space),
123+
expression: expr
124+
)
125+
)
126+
}
127+
if needTry {
128+
expr = ExprSyntax(
129+
TryExprSyntax(
130+
tryKeyword: .keyword(.try).with(\.trailingTrivia, .space),
131+
expression: expr
132+
)
133+
)
134+
}
135+
136+
expr.leadingTrivia = originalExpr.leadingTrivia
137+
expr.trailingTrivia = originalExpr.trailingTrivia
138+
139+
return expr
140+
}

Sources/TestingMacros/TestDeclarationMacro.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -235,17 +235,17 @@ public struct TestDeclarationMacro: PeerMacro, Sendable {
235235
// detecting isolation to other global actors.
236236
lazy var isMainActorIsolated = !functionDecl.attributes(named: "MainActor", inModuleNamed: "_Concurrency").isEmpty
237237
var forwardCall: (ExprSyntax) -> ExprSyntax = {
238-
"try await Testing.__requiringTry(Testing.__requiringAwait(\($0)))"
238+
applyEffectfulKeywords([.try, .await], to: $0)
239239
}
240240
let forwardInit = forwardCall
241241
if functionDecl.noasyncAttribute != nil {
242242
if isMainActorIsolated {
243243
forwardCall = {
244-
"try await MainActor.run { try Testing.__requiringTry(\($0)) }"
244+
"try await MainActor.run { \(applyEffectfulKeywords([.try], to: $0)) }"
245245
}
246246
} else {
247247
forwardCall = {
248-
"try { try Testing.__requiringTry(\($0)) }()"
248+
"try { \(applyEffectfulKeywords([.try], to: $0)) }()"
249249
}
250250
}
251251
}

Tests/TestingMacrosTests/ConditionMacroTests.swift

+5-5
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct ConditionMacroTests {
4141
##"#expect((true && false))"##:
4242
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in __ec((__ec(__ec(true, "3c") && __ec(false, "21c"), "1c")), "") }, sourceCode: ["": "(true && false)", "1c": "true && false", "3c": "true", "21c": "false"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
4343
##"#expect(try x())"##:
44-
##"try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(try __ec(x(), "4")) }, sourceCode: ["4": "x()"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
44+
##"try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try __ec(x(), "4") }, sourceCode: ["4": "x()"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
4545
##"#expect(1 is Int)"##:
4646
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in __ec.__is(1, "", (Int).self, "10") }, sourceCode: ["": "1 is Int", "10": "Int"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
4747
##"#expect("123") { 1 == 2 } then: { foo() }"##:
@@ -122,7 +122,7 @@ struct ConditionMacroTests {
122122
##"#require((true && false))"##:
123123
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(__ec((__ec(__ec(true, "3c") && __ec(false, "21c"), "1c")), "")) }, sourceCode: ["": "(true && false)", "1c": "true && false", "3c": "true", "21c": "false"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
124124
##"#require(try x())"##:
125-
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(try __ec(x(), "4")) }, sourceCode: ["4": "x()"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
125+
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try __ec(x(), "4") }, sourceCode: ["4": "x()"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
126126
##"#require(1 is Int)"##:
127127
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(__ec.__is(1, "", (Int).self, "10")) }, sourceCode: ["": "1 is Int", "10": "Int"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
128128
##"#require("123") { 1 == 2 } then: { foo() }"##:
@@ -215,7 +215,7 @@ struct ConditionMacroTests {
215215
"""
216216
// Source comment
217217
/** Doc comment */
218-
try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(try __ec(x(), "4")) }, sourceCode: ["4": "x()"], comments: [.__line("// Source comment"),.__documentationBlock("/** Doc comment */"),"Argument comment"], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()
218+
try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try __ec(x(), "4") }, sourceCode: ["4": "x()"], comments: [.__line("// Source comment"),.__documentationBlock("/** Doc comment */"),"Argument comment"], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()
219219
""",
220220
221221
"""
@@ -228,7 +228,7 @@ struct ConditionMacroTests {
228228
// Ignore me
229229
230230
// Capture me
231-
try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(try __ec(x(), "4")) }, sourceCode: ["4": "x()"], comments: [.__line("// Capture me")], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()
231+
try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try __ec(x(), "4") }, sourceCode: ["4": "x()"], comments: [.__line("// Capture me")], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()
232232
""",
233233
234234
"""
@@ -241,7 +241,7 @@ struct ConditionMacroTests {
241241
// Ignore me
242242
\t
243243
// Capture me
244-
try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(try __ec(x(), "4")) }, sourceCode: ["4": "x()"], comments: [.__line("// Capture me")], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()
244+
try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try __ec(x(), "4") }, sourceCode: ["4": "x()"], comments: [.__line("// Capture me")], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()
245245
""",
246246
]
247247
)

0 commit comments

Comments
 (0)