Skip to content

Commit

Permalink
Add support for macro conditions to named enums
Browse files Browse the repository at this point in the history
  • Loading branch information
skorulis-ap committed Feb 12, 2025
1 parent cf86daa commit 78088be
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 20 deletions.
11 changes: 11 additions & 0 deletions Sources/KnitCodeGen/NamedRegistrationGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

import Foundation
import SwiftSyntax

/// Collection of named registrations for a single service.
struct NamedRegistrationGroup {
Expand All @@ -17,6 +18,7 @@ struct NamedRegistrationGroup {
return dict.map { key, value in
return NamedRegistrationGroup(service: key, registrations: value)
}
.sorted(by: { $0.service < $1.service})
}

var accessLevel: AccessLevel {
Expand All @@ -31,4 +33,13 @@ struct NamedRegistrationGroup {
return "\(sanitizedType)_ResolutionKey"
}

// The if config condition wrapping the entire group
var ifConfigCondition: ExprSyntax? {
guard var firstCondition = registrations.first?.ifConfigCondition else {
return nil
}
let allMatch = registrations.allSatisfy { $0.ifConfigCondition == firstCondition }
return allMatch ? firstCondition : nil
}

}
4 changes: 3 additions & 1 deletion Sources/KnitCodeGen/Registration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ public struct Registration: Equatable, Codable, Sendable {
concurrencyModifier: String? = nil,
getterConfig: Set<GetterConfig> = GetterConfig.default,
functionName: FunctionName = .register,
spi: String? = nil
spi: String? = nil,
ifConfigCondition: ExprSyntax? = nil
) {
self.service = service
self.name = name
Expand All @@ -47,6 +48,7 @@ public struct Registration: Equatable, Codable, Sendable {
self.getterConfig = getterConfig
self.functionName = functionName
self.spi = spi
self.ifConfigCondition = ifConfigCondition
}

/// This registration is forwarded to another service entry.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//
// Copyright © Block, Inc. All rights reserved.
//

import Foundation
import SwiftSyntax

extension NamedRegistrationGroup {
func enumSourceCode(assemblyName: String) throws -> DeclSyntaxProtocol {
let modifier = accessLevel == .public ? "public " : ""
let enumSyntax = try EnumDeclSyntax("\(raw: modifier)enum \(raw: enumName): String, CaseIterable") {
for reg in registrations {
("case \(raw: reg.name!)" as DeclSyntax).maybeWithCondition(
ifConfigCondition: ifConfigCondition == nil ? reg.ifConfigCondition : nil
)
}
}
return enumSyntax.maybeWithCondition(ifConfigCondition: ifConfigCondition)
}
}
17 changes: 17 additions & 0 deletions Sources/KnitCodeGen/SwiftSyntax+Helpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,20 @@ extension VariableDeclSyntax {
)
}
}

extension DeclSyntaxProtocol {

// Wrap the output in an #if where needed
func maybeWithCondition(ifConfigCondition: ExprSyntax?) -> DeclSyntaxProtocol {
guard let ifConfigCondition else {
return self
}
let codeBlock = CodeBlockItemListSyntax([.init(item: .init(self))])
let clause = IfConfigClauseSyntax(
poundKeyword: .poundIfToken(),
condition: ifConfigCondition,
elements: .statements(codeBlock)
)
return IfConfigDeclSyntax(clauses: [clause])
}
}
33 changes: 14 additions & 19 deletions Sources/KnitCodeGen/TypeSafetySourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,25 @@ public enum TypeSafetySourceFile {

for registration in unnamedRegistrations {
if registration.getterConfig.contains(.callAsFunction) {
try makeResolver(registration: registration, getterType: .callAsFunction)
try makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
getterType: .callAsFunction
)
}
if let namedGetter = registration.namedGetterConfig {
try makeResolver(registration: registration, getterType: namedGetter)
try makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
getterType: namedGetter
)
}
}
for namedGroup in namedGroups {
let firstGetterConfig = namedGroup.registrations[0].getterConfig.first ?? .callAsFunction
try makeResolver(
registration: namedGroup.registrations[0],
ifConfigCondition: namedGroup.ifConfigCondition,
enumName: "\(config.assemblyName).\(namedGroup.enumName)",
getterType: firstGetterConfig
)
Expand All @@ -55,6 +64,7 @@ public enum TypeSafetySourceFile {
/// Create the type safe resolver function for this registration
static func makeResolver(
registration: Registration,
ifConfigCondition: ExprSyntax?,
enumName: String? = nil,
getterType: GetterConfig = .callAsFunction
) throws -> DeclSyntaxProtocol {
Expand Down Expand Up @@ -88,17 +98,7 @@ public enum TypeSafetySourceFile {
"knitUnwrap(resolve(\(raw: usages)), callsiteFile: file, callsiteFunction: function, callsiteLine: line)"
}

// Wrap the output in an #if where needed
guard let ifConfigCondition = registration.ifConfigCondition else {
return function
}
let codeBlock = CodeBlockItemListSyntax([.init(item: .init(function))])
let clause = IfConfigClauseSyntax(
poundKeyword: .poundIfToken(),
condition: ifConfigCondition,
elements: .statements(codeBlock)
)
return IfConfigDeclSyntax(clauses: [clause])
return function.maybeWithCondition(ifConfigCondition: ifConfigCondition)
}

private static func argumentString(registration: Registration) -> (input: String?, usage: String?) {
Expand All @@ -118,12 +118,7 @@ public enum TypeSafetySourceFile {
) throws -> ExtensionDeclSyntax {
try ExtensionDeclSyntax("extension \(raw: assemblyName)") {
for namedGroup in namedGroups {
let modifier = namedGroup.accessLevel == .public ? "public " : ""
try EnumDeclSyntax("\(raw: modifier)enum \(raw: namedGroup.enumName): String, CaseIterable") {
for test in namedGroup.registrations {
"case \(raw: test.name!)" as DeclSyntax
}
}
try namedGroup.enumSourceCode(assemblyName: assemblyName)
}
}
}
Expand Down
49 changes: 49 additions & 0 deletions Tests/KnitCodeGenTests/NamedRegistrationGroupSourceTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//
// Copyright © Block, Inc. All rights reserved.
//

@testable import KnitCodeGen
import XCTest
import SwiftSyntax

final class NamedRegistrationGroupSourceTests: XCTestCase {

func testResolutionKeyWithInternalMacro() throws {
let registration1 = Registration(service: "ServiceB", name: "name", ifConfigCondition: ExprSyntax("DEBUG"))
let registration2 = Registration(service: "ServiceB", name: "name2")
let registration3 = Registration(service: "ServiceB", name: "name3", ifConfigCondition: ExprSyntax("DEBUG"))
let group = NamedRegistrationGroup.make(from: [registration1, registration2, registration3])[0]
let result = try group.enumSourceCode(assemblyName: "Assembly")

let expected = """
enum ServiceB_ResolutionKey: String, CaseIterable {
#if DEBUG
case name
#endif
case name2
#if DEBUG
case name3
#endif
}
"""

XCTAssertEqual(expected, result.formatted().description)

}

func testResolutionKeyWithExternalMacro() throws {
let registration1 = Registration(service: "ServiceB", name: "name2", ifConfigCondition: ExprSyntax("DEBUG"))
let group = NamedRegistrationGroup.make(from: [registration1])[0]
let result = try group.enumSourceCode(assemblyName: "Assembly")

let expected = """
#if DEBUG
enum ServiceB_ResolutionKey: String, CaseIterable {
case name2
}
#endif
"""

XCTAssertEqual(expected, result.formatted().description)
}
}
16 changes: 16 additions & 0 deletions Tests/KnitCodeGenTests/NamedRegistrationGroupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

@testable import KnitCodeGen
import Foundation
import SwiftSyntax
import XCTest

final class NamedRegistrationGroupTests: XCTestCase {
Expand Down Expand Up @@ -43,4 +44,19 @@ final class NamedRegistrationGroupTests: XCTestCase {
)
}

func testIfConfigCondition() throws {
let registration1 = Registration(service: "ServiceA", name: "name1", ifConfigCondition: ExprSyntax("DEBUG"))
let registration2 = Registration(service: "ServiceA", name: "name2")
let registration3 = Registration(service: "ServiceA", name: "name3", ifConfigCondition: ExprSyntax("RELEASE"))

let namedGroup1 = NamedRegistrationGroup.make(from: [registration1])[0]
XCTAssertEqual(namedGroup1.ifConfigCondition?.description, ExprSyntax("DEBUG").description)

let namedGroup2 = NamedRegistrationGroup.make(from: [registration1, registration2])[0]
XCTAssertNil(namedGroup2.ifConfigCondition)

let namedGroup3 = NamedRegistrationGroup.make(from: [registration1, registration3])[0]
XCTAssertNil(namedGroup2.ifConfigCondition)
}

}
59 changes: 59 additions & 0 deletions Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -95,6 +96,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -110,6 +112,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -125,6 +128,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: "MyAssembly.A_ResolutionKey"
).formatted().description,
"""
Expand All @@ -140,6 +144,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -156,6 +161,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -173,6 +179,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand Down Expand Up @@ -378,4 +385,56 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(expected, result.formatted().description)
}

func testResolutionKeyWithMacros() throws {
let registration1 = Registration(service: "ServiceA", name: "name")
let registration2 = Registration(service: "ServiceA", name: "name2", ifConfigCondition: ExprSyntax("DEBUG"))
let registration3 = Registration(service: "ServiceB", name: "name2", ifConfigCondition: ExprSyntax("RELEASE"))
let result = try TypeSafetySourceFile.make(
from: Configuration(
assemblyName: "ModuleAssembly",
moduleName: "Module",
registrations: [registration2, registration1, registration3],
targetResolver: "Resolver"
)
)

let expected = """
/// Generated from ``ModuleAssembly``
extension Resolver {
func serviceA(name: ModuleAssembly.ServiceA_ResolutionKey, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceA {
knitUnwrap(resolve(ServiceA.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
#if RELEASE
func serviceB(name: ModuleAssembly.ServiceB_ResolutionKey, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceB {
knitUnwrap(resolve(ServiceB.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
#endif
}
extension ModuleAssembly {
enum ServiceA_ResolutionKey: String, CaseIterable {
#if DEBUG
case name2
#endif
case name
}
#if RELEASE
enum ServiceB_ResolutionKey: String, CaseIterable {
case name2
}
#endif
}
extension ModuleAssembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
nil
}
}
"""

XCTAssertEqual(expected, result.formatted().description)

}

}

0 comments on commit 78088be

Please sign in to comment.