Skip to content

Commit bde7848

Browse files
committed
Capture x instead of &x as source for inout args, ensure dollar identifiers are rewritten as late as possible
1 parent 662ef78 commit bde7848

File tree

4 files changed

+97
-50
lines changed

4 files changed

+97
-50
lines changed

Sources/TestingMacros/ConditionMacro.swift

+61-20
Original file line numberDiff line numberDiff line change
@@ -189,30 +189,71 @@ extension ConditionMacro {
189189
argumentExpr = "try Testing.__requiringTry(\(argumentExpr))"
190190
}
191191

192-
// Replace any dollar identifiers we find.
193-
let closureArguments = rewriteClosureArguments(in: argumentExpr)
194-
if let closureArguments {
195-
argumentExpr = closureArguments.rewrittenNode.cast(ExprSyntax.self)
196-
}
197-
198-
// If we're inserting any additional code into the closure before the
199-
// rewritten argument, we can't elide the return keyword for brevity.
200-
var returnKeyword: TokenSyntax?
201-
if !prefixCodeBlockItems.isEmpty {
202-
returnKeyword = .keyword(.return)
203-
.with(\.leadingTrivia, argumentExpr.leadingTrivia)
204-
argumentExpr.leadingTrivia = .space
192+
// Construct the body of the closure that we'll pass to the expanded
193+
// function.
194+
var codeBlockItems = CodeBlockItemListSyntax {
195+
if prefixCodeBlockItems.isEmpty {
196+
CodeBlockItemSyntax(item: .expr(argumentExpr))
197+
.with(\.trailingTrivia, .newline)
198+
} else {
199+
prefixCodeBlockItems
200+
201+
// If we're inserting any additional code into the closure before
202+
// the rewritten argument, we can't elide the return keyword.
203+
CodeBlockItemSyntax(
204+
item: .stmt(
205+
StmtSyntax(
206+
ReturnStmtSyntax(
207+
expression: argumentExpr
208+
.with(\.leadingTrivia, .space)
209+
)
210+
)
211+
)
212+
).with(\.trailingTrivia, .newline)
213+
}
205214
}
206215

207-
// Enclose the expression in a closure into which we pass our local
208-
// context object.
209-
argumentExpr = """
210-
{ \(closureArguments?.captureList) (\(expressionContextName): inout Testing.__ExpectationContext) in
211-
\(prefixCodeBlockItems)\(returnKeyword)\(argumentExpr)
216+
// Replace any dollar identifiers we find.
217+
let closureArguments = rewriteClosureArguments(in: codeBlockItems)
218+
if let closureArguments {
219+
codeBlockItems = closureArguments.rewrittenNode.cast(CodeBlockItemListSyntax.self)
212220
}
213-
"""
214221

215-
checkArguments.append(Argument(expression: argumentExpr))
222+
// Enclose the code block in the final closure.
223+
let closureExpr = ClosureExprSyntax(
224+
signature: ClosureSignatureSyntax(
225+
capture: closureArguments?.captureList,
226+
parameterClause: .parameterClause(
227+
ClosureParameterClauseSyntax(
228+
parameters: ClosureParameterListSyntax {
229+
ClosureParameterSyntax(
230+
firstName: expressionContextName,
231+
colon: .colonToken().with(\.trailingTrivia, .space),
232+
type: TypeSyntax(
233+
AttributedTypeSyntax(
234+
specifiers: [
235+
.init(
236+
SimpleTypeSpecifierSyntax(specifier: .keyword(.inout))
237+
.with(\.trailingTrivia, .space)
238+
)
239+
],
240+
baseType: MemberTypeSyntax(
241+
baseType: IdentifierTypeSyntax(name: .identifier("Testing")),
242+
name: .identifier("__ExpectationContext")
243+
)
244+
)
245+
)
246+
)
247+
}
248+
)
249+
),
250+
inKeyword: .keyword(.in)
251+
.with(\.leadingTrivia, .space)
252+
.with(\.trailingTrivia, .newline)
253+
),
254+
statements: codeBlockItems
255+
)
256+
checkArguments.append(Argument(expression: closureExpr))
216257

217258
// Sort the rewritten nodes. This isn't strictly necessary for
218259
// correctness but it does make the produced code more consistent.

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

+30-23
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,16 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
188188
///
189189
/// - Parameters:
190190
/// - node: The node to rewrite.
191+
/// - functionName: If not `nil`, the name of the function to call (as a
192+
/// member function of the expression context.)
193+
/// - additionalArguments: Any additional arguments to pass to the function.
191194
///
192195
/// - Returns: A rewritten copy of `node` that calls into the expression
193196
/// context when it is evaluated at runtime.
194197
///
195198
/// This function is equivalent to `_rewrite(node, originalWas: node)`.
196-
private func _rewrite<E>(_ node: E) -> ExprSyntax where E: ExprSyntaxProtocol {
197-
_rewrite(node, originalWas: node)
199+
private func _rewrite<E>(_ node: E, calling functionName: TokenSyntax? = nil, passing additionalArguments: [Argument] = []) -> ExprSyntax where E: ExprSyntaxProtocol {
200+
_rewrite(node, originalWas: node, calling: functionName, passing: additionalArguments)
198201
}
199202

200203
/// Whether or not the parent node of the given node is capable of containing
@@ -355,13 +358,11 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
355358
return _rewrite(
356359
ClosureExprSyntax {
357360
InfixOperatorExprSyntax(
358-
leftOperand: DeclReferenceExprSyntax(
359-
baseName: .dollarIdentifier("$0")
360-
).with(\.trailingTrivia, .space),
361+
leftOperand: DeclReferenceExprSyntax(baseName: .dollarIdentifier("$0"))
362+
.with(\.trailingTrivia, .space),
361363
operator: BinaryOperatorExprSyntax(text: op),
362-
rightOperand: DeclReferenceExprSyntax(
363-
baseName: .dollarIdentifier("$1")
364-
).with(\.leadingTrivia, .space)
364+
rightOperand: DeclReferenceExprSyntax(baseName: .dollarIdentifier("$1"))
365+
.with(\.leadingTrivia, .space)
365366
)
366367
},
367368
originalWas: node,
@@ -391,11 +392,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
391392

392393
let teardownItem = CodeBlockItemSyntax(
393394
item: .expr(
394-
_rewrite(
395-
node.expression,
396-
originalWas: node,
397-
calling: .identifier("__inoutAfter")
398-
)
395+
_rewrite(node.expression, calling: .identifier("__inoutAfter"))
399396
)
400397
)
401398
teardownItems.append(teardownItem)
@@ -571,9 +568,7 @@ func insertCalls(
571568
item: .stmt(
572569
StmtSyntax(
573570
DeferStmtSyntax {
574-
for teardownItem in contextInserter.teardownItems {
575-
teardownItem
576-
}
571+
contextInserter.teardownItems
577572
}
578573
)
579574
)
@@ -667,6 +662,24 @@ private final class _DollarIdentifierReplacer: SyntaxRewriter {
667662
/// The dollar identifier tokens that have been rewritten.
668663
var dollarIdentifierTokens = Set<TokenSyntax>()
669664

665+
/// The node to treat as the root node when expanding expressions.
666+
var effectiveRootNode: Syntax
667+
668+
init(rootedAt effectiveRootNode: Syntax) {
669+
self.effectiveRootNode = effectiveRootNode
670+
}
671+
672+
override func visitAny(_ node: Syntax) -> Syntax? {
673+
// Do not recurse into closure expressions (except the root node) because
674+
// they will have their own argument/capture lists that won't conflict with
675+
// the enclosing scope's.
676+
if node.is(ClosureExprSyntax.self) && node != effectiveRootNode {
677+
return Syntax(node)
678+
}
679+
680+
return nil
681+
}
682+
670683
override func visit(_ node: TokenSyntax) -> TokenSyntax {
671684
if case let .dollarIdentifier(id) = node.tokenKind, id.dropFirst().allSatisfy(\.isWholeNumber) {
672685
// This dollar identifier is numeric, so it's a closure argument.
@@ -676,12 +689,6 @@ private final class _DollarIdentifierReplacer: SyntaxRewriter {
676689

677690
return node
678691
}
679-
680-
override func visit(_ node: ClosureExprSyntax) -> ExprSyntax {
681-
// Do not recurse into closure expressions because they will have their own
682-
// argument lists that won't conflict with the enclosing scope's.
683-
return ExprSyntax(node)
684-
}
685692
}
686693

687694
/// Rewrite any implicit closure arguments (dollar identifiers such as `$0`) in
@@ -694,7 +701,7 @@ private final class _DollarIdentifierReplacer: SyntaxRewriter {
694701
/// can be used to transform the original dollar identifiers to their
695702
/// rewritten counterparts in a nested closure invocation.
696703
func rewriteClosureArguments(in node: some SyntaxProtocol) -> (rewrittenNode: Syntax, captureList: ClosureCaptureClauseSyntax)? {
697-
let replacer = _DollarIdentifierReplacer()
704+
let replacer = _DollarIdentifierReplacer(rootedAt: Syntax(node))
698705
let result = replacer.rewrite(node)
699706
if replacer.dollarIdentifierTokens.isEmpty {
700707
return nil

Tests/SubexpressionShowcase/SubexpressionShowcase.swift

-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ struct T {
3030
static func j(_ d: Double) -> Bool { false }
3131
}
3232

33-
3433
@Test func runSubexpressionShowcase() async {
3534
await withKnownIssue {
3635
try await subexpressionShowcase()

Tests/TestingMacrosTests/ConditionMacroTests.swift

+6-6
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ struct ConditionMacroTests {
6161
##"#expect(a.b(c, d: e))"##:
6262
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in __ec(__ec(a.self, "6").b(__ec(c, "700"), d: __ec(e, "12100")), "") }, sourceCode: ["": "a.b(c, d: e)", "6": "a", "700": "c", "12100": "e"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
6363
##"#expect(a.b(&c))"##:
64-
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "700") } return __ec(__ec(a.self, "6").b(&c), "") }, sourceCode: ["": "a.b(&c)", "6": "a", "700": "&c"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
64+
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "1700") } return __ec(__ec(a.self, "6").b(&c), "") }, sourceCode: ["": "a.b(&c)", "6": "a", "1700": "c"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
6565
##"#expect(a.b(&c, &d.e))"##:
66-
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "700") __ec.__inoutAfter(d.e, "18100") } return __ec(__ec(a.self, "6").b(&c, &d.e), "") }, sourceCode: ["": "a.b(&c, &d.e)", "6": "a", "700": "&c", "18100": "&d.e"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
66+
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "1700") __ec.__inoutAfter(d.e, "58100") } return __ec(__ec(a.self, "6").b(&c, &d.e), "") }, sourceCode: ["": "a.b(&c, &d.e)", "6": "a", "1700": "c", "58100": "d.e"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
6767
##"#expect(a.b(&c, d))"##:
68-
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "700") } return __ec(__ec(a.self, "6").b(&c, __ec(d, "18100")), "") }, sourceCode: ["": "a.b(&c, d)", "6": "a", "700": "&c", "18100": "d"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
68+
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "1700") } return __ec(__ec(a.self, "6").b(&c, __ec(d, "18100")), "") }, sourceCode: ["": "a.b(&c, d)", "6": "a", "1700": "c", "18100": "d"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
6969
##"#expect(a.b(try c()))"##:
7070
##"try Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(__ec(__ec(a.self, "6").b(try __ec(c(), "1700")), "")) }, sourceCode: ["": "a.b(try c())", "6": "a", "1700": "c()"], comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##,
7171
##"#expect(a?.b(c))"##:
@@ -142,11 +142,11 @@ struct ConditionMacroTests {
142142
##"#require(a.b(c, d: e))"##:
143143
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(__ec(__ec(a.self, "6").b(__ec(c, "700"), d: __ec(e, "12100")), "")) }, sourceCode: ["": "a.b(c, d: e)", "6": "a", "700": "c", "12100": "e"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
144144
##"#require(a.b(&c))"##:
145-
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "700") } return try Testing.__requiringTry(__ec(__ec(a.self, "6").b(&c), "")) }, sourceCode: ["": "a.b(&c)", "6": "a", "700": "&c"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
145+
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "1700") } return try Testing.__requiringTry(__ec(__ec(a.self, "6").b(&c), "")) }, sourceCode: ["": "a.b(&c)", "6": "a", "1700": "c"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
146146
##"#require(a.b(&c, &d.e))"##:
147-
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "700") __ec.__inoutAfter(d.e, "18100") } return try Testing.__requiringTry(__ec(__ec(a.self, "6").b(&c, &d.e), "")) }, sourceCode: ["": "a.b(&c, &d.e)", "6": "a", "700": "&c", "18100": "&d.e"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
147+
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "1700") __ec.__inoutAfter(d.e, "58100") } return try Testing.__requiringTry(__ec(__ec(a.self, "6").b(&c, &d.e), "")) }, sourceCode: ["": "a.b(&c, &d.e)", "6": "a", "1700": "c", "58100": "d.e"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
148148
##"#require(a.b(&c, d))"##:
149-
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "700") } return try Testing.__requiringTry(__ec(__ec(a.self, "6").b(&c, __ec(d, "18100")), "")) }, sourceCode: ["": "a.b(&c, d)", "6": "a", "700": "&c", "18100": "d"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
149+
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in defer { __ec.__inoutAfter(c, "1700") } return try Testing.__requiringTry(__ec(__ec(a.self, "6").b(&c, __ec(d, "18100")), "")) }, sourceCode: ["": "a.b(&c, d)", "6": "a", "1700": "c", "18100": "d"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
150150
##"#require(a.b(try c()))"##:
151151
##"Testing.__checkCondition({ (__ec: inout Testing.__ExpectationContext) in try Testing.__requiringTry(__ec(__ec(a.self, "6").b(try __ec(c(), "1700")), "")) }, sourceCode: ["": "a.b(try c())", "6": "a", "1700": "c()"], comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##,
152152
##"#require(a?.b(c))"##:

0 commit comments

Comments
 (0)