Skip to content

Commit

Permalink
Add support for macro conditions to named enums
Browse files Browse the repository at this point in the history
  • Loading branch information
skorulis-ap committed Feb 20, 2025
1 parent a59aebc commit 406686c
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 33 deletions.
17 changes: 2 additions & 15 deletions Sources/KnitCodeGen/HeaderSourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,11 @@ public enum HeaderSourceFile {
leadingTrivia: TriviaProvider.headerTrivia,
statementsBuilder: {
for moduleImport in imports {
importDecl(moduleImport: moduleImport)
moduleImport.decl
.maybeWithCondition(ifConfigCondition: moduleImport.ifConfigCondition)
}
},
trailingTrivia: trivia
)
}

private static func importDecl(moduleImport: ModuleImport) -> DeclSyntaxProtocol {
// Wrap the output in an #if where needed
guard let ifConfigCondition = moduleImport.ifConfigCondition else {
return moduleImport.decl
}
let codeBlock = CodeBlockItemListSyntax([.init(item: .init(moduleImport.decl))])
let clause = IfConfigClauseSyntax(
poundKeyword: .poundIfToken(),
condition: ifConfigCondition,
elements: .statements(codeBlock)
)
return IfConfigDeclSyntax(clauses: [clause])
}
}
12 changes: 12 additions & 0 deletions Sources/KnitCodeGen/NamedRegistrationGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

import Foundation
import SwiftSyntax

/// Collection of named registrations for a single service.
struct NamedRegistrationGroup {
Expand All @@ -17,6 +18,7 @@ struct NamedRegistrationGroup {
return dict.map { key, value in
return NamedRegistrationGroup(service: key, registrations: value)
}
.sorted(by: { $0.service < $1.service})
}

var accessLevel: AccessLevel {
Expand All @@ -31,4 +33,14 @@ struct NamedRegistrationGroup {
return "\(sanitizedType)_ResolutionKey"
}

// The if config condition wrapping the entire group
var ifConfigCondition: ExprSyntax? {
guard let firstCondition = registrations.first?.ifConfigCondition else {
return nil
}
// Only wrap the entire group if all conditions within the group match
let allMatch = registrations.allSatisfy { $0.ifConfigCondition?.description == firstCondition.description }
return allMatch ? firstCondition : nil
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//
// Copyright © Block, Inc. All rights reserved.
//

import Foundation
import SwiftSyntax

extension NamedRegistrationGroup {
// Generate the enum for this group of named registrations
func enumSourceCode(assemblyName: String) throws -> DeclSyntaxProtocol {
let modifier = accessLevel == .public ? "public " : ""
let enumSyntax = try EnumDeclSyntax("\(raw: modifier)enum \(raw: enumName): String, CaseIterable") {
for reg in registrations {
("case \(raw: reg.name!)" as DeclSyntax).maybeWithCondition(
ifConfigCondition: ifConfigCondition == nil ? reg.ifConfigCondition : nil
)
}
}
return enumSyntax.maybeWithCondition(ifConfigCondition: ifConfigCondition)
}
}
17 changes: 17 additions & 0 deletions Sources/KnitCodeGen/SwiftSyntax+Helpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,20 @@ extension VariableDeclSyntax {
)
}
}

extension DeclSyntaxProtocol {

// Wrap the declaration in an #if where needed
func maybeWithCondition(ifConfigCondition: ExprSyntax?) -> DeclSyntaxProtocol {
guard let ifConfigCondition else {
return self
}
let codeBlock = CodeBlockItemListSyntax([.init(item: .init(self))])
let clause = IfConfigClauseSyntax(
poundKeyword: .poundIfToken(),
condition: ifConfigCondition,
elements: .statements(codeBlock)
)
return IfConfigDeclSyntax(clauses: [clause])
}
}
27 changes: 9 additions & 18 deletions Sources/KnitCodeGen/TypeSafetySourceFile.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@ public enum TypeSafetySourceFile {
""") {

for registration in unnamedRegistrations {
try makeResolver(registration: registration, getterAlias: registration.getterAlias)
try makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
getterAlias: registration.getterAlias
)
}
for namedGroup in namedGroups {
let firstGetterAlias = namedGroup.registrations[0].getterAlias
try makeResolver(
registration: namedGroup.registrations[0],
ifConfigCondition: namedGroup.ifConfigCondition,
enumName: "\(config.assemblyName).\(namedGroup.enumName)",
getterAlias: firstGetterAlias
)
Expand All @@ -50,6 +55,7 @@ public enum TypeSafetySourceFile {
/// Create the type safe resolver function for this registration
static func makeResolver(
registration: Registration,
ifConfigCondition: ExprSyntax?,
enumName: String? = nil,
getterAlias: String? = nil
) throws -> DeclSyntaxProtocol {
Expand Down Expand Up @@ -77,17 +83,7 @@ public enum TypeSafetySourceFile {
"knitUnwrap(resolve(\(raw: usages)), callsiteFile: file, callsiteFunction: function, callsiteLine: line)"
}

// Wrap the output in an #if where needed
guard let ifConfigCondition = registration.ifConfigCondition else {
return function
}
let codeBlock = CodeBlockItemListSyntax([.init(item: .init(function))])
let clause = IfConfigClauseSyntax(
poundKeyword: .poundIfToken(),
condition: ifConfigCondition,
elements: .statements(codeBlock)
)
return IfConfigDeclSyntax(clauses: [clause])
return function.maybeWithCondition(ifConfigCondition: ifConfigCondition)
}

private static func argumentString(registration: Registration) -> (input: String?, usage: String?) {
Expand All @@ -107,12 +103,7 @@ public enum TypeSafetySourceFile {
) throws -> ExtensionDeclSyntax {
try ExtensionDeclSyntax("extension \(raw: assemblyName)") {
for namedGroup in namedGroups {
let modifier = namedGroup.accessLevel == .public ? "public " : ""
try EnumDeclSyntax("\(raw: modifier)enum \(raw: namedGroup.enumName): String, CaseIterable") {
for test in namedGroup.registrations {
"case \(raw: test.name!)" as DeclSyntax
}
}
try namedGroup.enumSourceCode(assemblyName: assemblyName)
}
}
}
Expand Down
49 changes: 49 additions & 0 deletions Tests/KnitCodeGenTests/NamedRegistrationGroupSourceTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//
// Copyright © Block, Inc. All rights reserved.
//

@testable import KnitCodeGen
import XCTest
import SwiftSyntax

final class NamedRegistrationGroupSourceTests: XCTestCase {

func testResolutionKeyWithInternalMacro() throws {
let registration1 = Registration(service: "ServiceB", name: "name", ifConfigCondition: ExprSyntax("DEBUG"))
let registration2 = Registration(service: "ServiceB", name: "name2")
let registration3 = Registration(service: "ServiceB", name: "name3", ifConfigCondition: ExprSyntax("DEBUG"))
let group = NamedRegistrationGroup.make(from: [registration1, registration2, registration3])[0]
let result = try group.enumSourceCode(assemblyName: "Assembly")

let expected = """
enum ServiceB_ResolutionKey: String, CaseIterable {
#if DEBUG
case name
#endif
case name2
#if DEBUG
case name3
#endif
}
"""

XCTAssertEqual(expected, result.formatted().description)

}

func testResolutionKeyWithExternalMacro() throws {
let registration1 = Registration(service: "ServiceB", name: "name2", ifConfigCondition: ExprSyntax("DEBUG"))
let group = NamedRegistrationGroup.make(from: [registration1])[0]
let result = try group.enumSourceCode(assemblyName: "Assembly")

let expected = """
#if DEBUG
enum ServiceB_ResolutionKey: String, CaseIterable {
case name2
}
#endif
"""

XCTAssertEqual(expected, result.formatted().description)
}
}
16 changes: 16 additions & 0 deletions Tests/KnitCodeGenTests/NamedRegistrationGroupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

@testable import KnitCodeGen
import Foundation
import SwiftSyntax
import XCTest

final class NamedRegistrationGroupTests: XCTestCase {
Expand Down Expand Up @@ -43,4 +44,19 @@ final class NamedRegistrationGroupTests: XCTestCase {
)
}

func testIfConfigCondition() throws {
let registration1 = Registration(service: "ServiceA", name: "name1", ifConfigCondition: ExprSyntax("DEBUG"))
let registration2 = Registration(service: "ServiceA", name: "name2")
let registration3 = Registration(service: "ServiceA", name: "name3", ifConfigCondition: ExprSyntax("RELEASE"))

let namedGroup1 = NamedRegistrationGroup.make(from: [registration1])[0]
XCTAssertEqual(namedGroup1.ifConfigCondition?.description, ExprSyntax("DEBUG").description)

let namedGroup2 = NamedRegistrationGroup.make(from: [registration1, registration2])[0]
XCTAssertNil(namedGroup2.ifConfigCondition)

let namedGroup3 = NamedRegistrationGroup.make(from: [registration1, registration3])[0]
XCTAssertNil(namedGroup3.ifConfigCondition)
}

}
59 changes: 59 additions & 0 deletions Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -92,6 +93,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -107,6 +109,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -122,6 +125,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: "MyAssembly.A_ResolutionKey"
).formatted().description,
"""
Expand All @@ -137,6 +141,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -153,6 +158,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand All @@ -170,6 +176,7 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(
try TypeSafetySourceFile.makeResolver(
registration: registration,
ifConfigCondition: registration.ifConfigCondition,
enumName: nil
).formatted().description,
"""
Expand Down Expand Up @@ -375,4 +382,56 @@ final class TypeSafetySourceFileTests: XCTestCase {
XCTAssertEqual(expected, result.formatted().description)
}

func testResolutionKeyWithMacros() throws {
let registration1 = Registration(service: "ServiceA", name: "name")
let registration2 = Registration(service: "ServiceA", name: "name2", ifConfigCondition: ExprSyntax("DEBUG"))
let registration3 = Registration(service: "ServiceB", name: "name2", ifConfigCondition: ExprSyntax("RELEASE"))
let result = try TypeSafetySourceFile.make(
from: Configuration(
assemblyName: "ModuleAssembly",
moduleName: "Module",
registrations: [registration2, registration1, registration3],
targetResolver: "Resolver"
)
)

let expected = """
/// Generated from ``ModuleAssembly``
extension Resolver {
func serviceA(name: ModuleAssembly.ServiceA_ResolutionKey, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceA {
knitUnwrap(resolve(ServiceA.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
#if RELEASE
func serviceB(name: ModuleAssembly.ServiceB_ResolutionKey, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceB {
knitUnwrap(resolve(ServiceB.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line)
}
#endif
}
extension ModuleAssembly {
enum ServiceA_ResolutionKey: String, CaseIterable {
#if DEBUG
case name2
#endif
case name
}
#if RELEASE
enum ServiceB_ResolutionKey: String, CaseIterable {
case name2
}
#endif
}
extension ModuleAssembly {
public static var _assemblyFlags: [ModuleAssemblyFlags] {
[]
}
public static func _autoInstantiate() -> (any ModuleAssembly)? {
nil
}
}
"""

XCTAssertEqual(expected, result.formatted().description)

}

}

0 comments on commit 406686c

Please sign in to comment.