Skip to content

Commit 4b493e3

Browse files
committed
Borrow/consume some args in the macro expansion
1 parent 47f1b1a commit 4b493e3

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

Diff for: Sources/Testing/Expectations/ExpectationContext.swift

+25-14
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ extension __ExpectationContext {
167167
///
168168
/// This function helps overloads of `callAsFunction(_:_:)` disambiguate
169169
/// themselves and avoid accidental recursion.
170-
@usableFromInline func captureValue<T>(_ value: T, _ id: __ExpressionID) -> T {
170+
@usableFromInline func captureValue<T>(_ value: consuming T, _ id: __ExpressionID) -> T {
171+
let value = copy value
171172
runtimeValues[id] = { Expression.Value(reflecting: value) }
172173
return value
173174
}
@@ -185,7 +186,7 @@ extension __ExpectationContext {
185186
/// - Warning: This function is used to implement the `#expect()` and
186187
/// `#require()` macros. Do not call it directly.
187188
@_disfavoredOverload
188-
@inlinable public func callAsFunction<T>(_ value: T, _ id: __ExpressionID) -> T {
189+
@inlinable public func callAsFunction<T>(_ value: consuming T, _ id: __ExpressionID) -> T {
189190
captureValue(value, id)
190191
}
191192

@@ -275,12 +276,14 @@ extension __ExpectationContext {
275276
@inlinable public func __cmp<T, U>(
276277
_ op: (T, U) throws -> Bool,
277278
_ opID: __ExpressionID,
278-
_ lhs: T,
279+
_ lhs: borrowing T,
279280
_ lhsID: __ExpressionID,
280-
_ rhs: U,
281+
_ rhs: borrowing U,
281282
_ rhsID: __ExpressionID
282283
) rethrows -> Bool {
283-
try captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
284+
let lhs = copy lhs
285+
let rhs = copy rhs
286+
return try captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
284287
}
285288

286289
/// Compare two bidirectional collections using `==` or `!=`.
@@ -293,11 +296,13 @@ extension __ExpectationContext {
293296
public func __cmp<C>(
294297
_ op: (C, C) -> Bool,
295298
_ opID: __ExpressionID,
296-
_ lhs: C,
299+
_ lhs: borrowing C,
297300
_ lhsID: __ExpressionID,
298-
_ rhs: C,
301+
_ rhs: borrowing C,
299302
_ rhsID: __ExpressionID
300303
) -> Bool where C: BidirectionalCollection, C.Element: Equatable {
304+
let lhs = copy lhs
305+
let rhs = copy rhs
301306
let result = captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
302307

303308
if !result {
@@ -318,12 +323,14 @@ extension __ExpectationContext {
318323
@inlinable public func __cmp<R>(
319324
_ op: (R, R) -> Bool,
320325
_ opID: __ExpressionID,
321-
_ lhs: R,
326+
_ lhs: borrowing R,
322327
_ lhsID: __ExpressionID,
323-
_ rhs: R,
328+
_ rhs: borrowing R,
324329
_ rhsID: __ExpressionID
325330
) -> Bool where R: RangeExpression & BidirectionalCollection, R.Element: Equatable {
326-
captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
331+
let lhs = copy lhs
332+
let rhs = copy rhs
333+
return captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
327334
}
328335

329336
/// Compare two strings using `==` or `!=`.
@@ -337,11 +344,13 @@ extension __ExpectationContext {
337344
public func __cmp<S>(
338345
_ op: (S, S) -> Bool,
339346
_ opID: __ExpressionID,
340-
_ lhs: S,
347+
_ lhs: borrowing S,
341348
_ lhsID: __ExpressionID,
342-
_ rhs: S,
349+
_ rhs: borrowing S,
343350
_ rhsID: __ExpressionID
344351
) -> Bool where S: StringProtocol {
352+
let lhs = copy lhs
353+
let rhs = copy rhs
345354
let result = captureValue(op(captureValue(lhs, lhsID), captureValue(rhs, rhsID)), opID)
346355

347356
if !result {
@@ -392,7 +401,8 @@ extension __ExpectationContext {
392401
///
393402
/// - Warning: This function is used to implement the `#expect()` and
394403
/// `#require()` macros. Do not call it directly.
395-
@inlinable public func __as<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> U? {
404+
@inlinable public func __as<T, U>(_ value: consuming T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> U? {
405+
let value = copy value
396406
let result = captureValue(value, valueID) as? U
397407

398408
if result == nil {
@@ -421,7 +431,8 @@ extension __ExpectationContext {
421431
///
422432
/// - Warning: This function is used to implement the `#expect()` and
423433
/// `#require()` macros. Do not call it directly.
424-
@inlinable public func __is<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> Bool {
434+
@inlinable public func __is<T, U>(_ value: borrowing T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> Bool {
435+
let value = copy value
425436
let result = captureValue(value, valueID) is U
426437

427438
if !result {

0 commit comments

Comments
 (0)