diff --git a/Sources/KnitCodeGen/HeaderSourceFile.swift b/Sources/KnitCodeGen/HeaderSourceFile.swift index 4e854b0..f884802 100644 --- a/Sources/KnitCodeGen/HeaderSourceFile.swift +++ b/Sources/KnitCodeGen/HeaderSourceFile.swift @@ -23,24 +23,11 @@ public enum HeaderSourceFile { leadingTrivia: TriviaProvider.headerTrivia, statementsBuilder: { for moduleImport in imports { - importDecl(moduleImport: moduleImport) + moduleImport.decl + .maybeWithCondition(ifConfigCondition: moduleImport.ifConfigCondition) } }, trailingTrivia: trivia ) } - - private static func importDecl(moduleImport: ModuleImport) -> DeclSyntaxProtocol { - // Wrap the output in an #if where needed - guard let ifConfigCondition = moduleImport.ifConfigCondition else { - return moduleImport.decl - } - let codeBlock = CodeBlockItemListSyntax([.init(item: .init(moduleImport.decl))]) - let clause = IfConfigClauseSyntax( - poundKeyword: .poundIfToken(), - condition: ifConfigCondition, - elements: .statements(codeBlock) - ) - return IfConfigDeclSyntax(clauses: [clause]) - } } diff --git a/Sources/KnitCodeGen/NamedRegistrationGroup.swift b/Sources/KnitCodeGen/NamedRegistrationGroup.swift index bc113ca..2ebda18 100644 --- a/Sources/KnitCodeGen/NamedRegistrationGroup.swift +++ b/Sources/KnitCodeGen/NamedRegistrationGroup.swift @@ -3,6 +3,7 @@ // import Foundation +import SwiftSyntax /// Collection of named registrations for a single service. struct NamedRegistrationGroup { @@ -17,6 +18,7 @@ struct NamedRegistrationGroup { return dict.map { key, value in return NamedRegistrationGroup(service: key, registrations: value) } + .sorted(by: { $0.service < $1.service}) } var accessLevel: AccessLevel { @@ -31,4 +33,14 @@ struct NamedRegistrationGroup { return "\(sanitizedType)_ResolutionKey" } + // The if config condition wrapping the entire group + var ifConfigCondition: ExprSyntax? { + guard let firstCondition = registrations.first?.ifConfigCondition else { + return nil + } + // Only wrap the entire group if all conditions within the group match + let allMatch = registrations.allSatisfy { $0.ifConfigCondition?.description == firstCondition.description } + return allMatch ? firstCondition : nil + } + } diff --git a/Sources/KnitCodeGen/SourceGen/NamedRegistrationGroup+SourceCode.swift b/Sources/KnitCodeGen/SourceGen/NamedRegistrationGroup+SourceCode.swift new file mode 100644 index 0000000..c9ceabc --- /dev/null +++ b/Sources/KnitCodeGen/SourceGen/NamedRegistrationGroup+SourceCode.swift @@ -0,0 +1,21 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +import Foundation +import SwiftSyntax + +extension NamedRegistrationGroup { + // Generate the enum for this group of named registrations + func enumSourceCode(assemblyName: String) throws -> DeclSyntaxProtocol { + let modifier = accessLevel == .public ? "public " : "" + let enumSyntax = try EnumDeclSyntax("\(raw: modifier)enum \(raw: enumName): String, CaseIterable") { + for reg in registrations { + ("case \(raw: reg.name!)" as DeclSyntax).maybeWithCondition( + ifConfigCondition: ifConfigCondition == nil ? reg.ifConfigCondition : nil + ) + } + } + return enumSyntax.maybeWithCondition(ifConfigCondition: ifConfigCondition) + } +} diff --git a/Sources/KnitCodeGen/SwiftSyntax+Helpers.swift b/Sources/KnitCodeGen/SwiftSyntax+Helpers.swift index 3fac7b3..e27a7be 100644 --- a/Sources/KnitCodeGen/SwiftSyntax+Helpers.swift +++ b/Sources/KnitCodeGen/SwiftSyntax+Helpers.swift @@ -52,3 +52,20 @@ extension VariableDeclSyntax { ) } } + +extension DeclSyntaxProtocol { + + // Wrap the declaration in an #if where needed + func maybeWithCondition(ifConfigCondition: ExprSyntax?) -> DeclSyntaxProtocol { + guard let ifConfigCondition else { + return self + } + let codeBlock = CodeBlockItemListSyntax([.init(item: .init(self))]) + let clause = IfConfigClauseSyntax( + poundKeyword: .poundIfToken(), + condition: ifConfigCondition, + elements: .statements(codeBlock) + ) + return IfConfigDeclSyntax(clauses: [clause]) + } +} diff --git a/Sources/KnitCodeGen/TypeSafetySourceFile.swift b/Sources/KnitCodeGen/TypeSafetySourceFile.swift index 902e417..9dccf3d 100644 --- a/Sources/KnitCodeGen/TypeSafetySourceFile.swift +++ b/Sources/KnitCodeGen/TypeSafetySourceFile.swift @@ -24,12 +24,17 @@ public enum TypeSafetySourceFile { """) { for registration in unnamedRegistrations { - try makeResolver(registration: registration, getterAlias: registration.getterAlias) + try makeResolver( + registration: registration, + ifConfigCondition: registration.ifConfigCondition, + getterAlias: registration.getterAlias + ) } for namedGroup in namedGroups { let firstGetterAlias = namedGroup.registrations[0].getterAlias try makeResolver( registration: namedGroup.registrations[0], + ifConfigCondition: namedGroup.ifConfigCondition, enumName: "\(config.assemblyName).\(namedGroup.enumName)", getterAlias: firstGetterAlias ) @@ -50,6 +55,7 @@ public enum TypeSafetySourceFile { /// Create the type safe resolver function for this registration static func makeResolver( registration: Registration, + ifConfigCondition: ExprSyntax?, enumName: String? = nil, getterAlias: String? = nil ) throws -> DeclSyntaxProtocol { @@ -77,17 +83,7 @@ public enum TypeSafetySourceFile { "knitUnwrap(resolve(\(raw: usages)), callsiteFile: file, callsiteFunction: function, callsiteLine: line)" } - // Wrap the output in an #if where needed - guard let ifConfigCondition = registration.ifConfigCondition else { - return function - } - let codeBlock = CodeBlockItemListSyntax([.init(item: .init(function))]) - let clause = IfConfigClauseSyntax( - poundKeyword: .poundIfToken(), - condition: ifConfigCondition, - elements: .statements(codeBlock) - ) - return IfConfigDeclSyntax(clauses: [clause]) + return function.maybeWithCondition(ifConfigCondition: ifConfigCondition) } private static func argumentString(registration: Registration) -> (input: String?, usage: String?) { @@ -107,12 +103,7 @@ public enum TypeSafetySourceFile { ) throws -> ExtensionDeclSyntax { try ExtensionDeclSyntax("extension \(raw: assemblyName)") { for namedGroup in namedGroups { - let modifier = namedGroup.accessLevel == .public ? "public " : "" - try EnumDeclSyntax("\(raw: modifier)enum \(raw: namedGroup.enumName): String, CaseIterable") { - for test in namedGroup.registrations { - "case \(raw: test.name!)" as DeclSyntax - } - } + try namedGroup.enumSourceCode(assemblyName: assemblyName) } } } diff --git a/Tests/KnitCodeGenTests/NamedRegistrationGroupSourceTests.swift b/Tests/KnitCodeGenTests/NamedRegistrationGroupSourceTests.swift new file mode 100644 index 0000000..d23b447 --- /dev/null +++ b/Tests/KnitCodeGenTests/NamedRegistrationGroupSourceTests.swift @@ -0,0 +1,49 @@ +// +// Copyright © Block, Inc. All rights reserved. +// + +@testable import KnitCodeGen +import XCTest +import SwiftSyntax + +final class NamedRegistrationGroupSourceTests: XCTestCase { + + func testResolutionKeyWithInternalMacro() throws { + let registration1 = Registration(service: "ServiceB", name: "name", ifConfigCondition: ExprSyntax("DEBUG")) + let registration2 = Registration(service: "ServiceB", name: "name2") + let registration3 = Registration(service: "ServiceB", name: "name3", ifConfigCondition: ExprSyntax("DEBUG")) + let group = NamedRegistrationGroup.make(from: [registration1, registration2, registration3])[0] + let result = try group.enumSourceCode(assemblyName: "Assembly") + + let expected = """ + enum ServiceB_ResolutionKey: String, CaseIterable { + #if DEBUG + case name + #endif + case name2 + #if DEBUG + case name3 + #endif + } + """ + + XCTAssertEqual(expected, result.formatted().description) + + } + + func testResolutionKeyWithExternalMacro() throws { + let registration1 = Registration(service: "ServiceB", name: "name2", ifConfigCondition: ExprSyntax("DEBUG")) + let group = NamedRegistrationGroup.make(from: [registration1])[0] + let result = try group.enumSourceCode(assemblyName: "Assembly") + + let expected = """ + #if DEBUG + enum ServiceB_ResolutionKey: String, CaseIterable { + case name2 + } + #endif + """ + + XCTAssertEqual(expected, result.formatted().description) + } +} diff --git a/Tests/KnitCodeGenTests/NamedRegistrationGroupTests.swift b/Tests/KnitCodeGenTests/NamedRegistrationGroupTests.swift index 99aa619..c6b6fc0 100644 --- a/Tests/KnitCodeGenTests/NamedRegistrationGroupTests.swift +++ b/Tests/KnitCodeGenTests/NamedRegistrationGroupTests.swift @@ -4,6 +4,7 @@ @testable import KnitCodeGen import Foundation +import SwiftSyntax import XCTest final class NamedRegistrationGroupTests: XCTestCase { @@ -43,4 +44,19 @@ final class NamedRegistrationGroupTests: XCTestCase { ) } + func testIfConfigCondition() throws { + let registration1 = Registration(service: "ServiceA", name: "name1", ifConfigCondition: ExprSyntax("DEBUG")) + let registration2 = Registration(service: "ServiceA", name: "name2") + let registration3 = Registration(service: "ServiceA", name: "name3", ifConfigCondition: ExprSyntax("RELEASE")) + + let namedGroup1 = NamedRegistrationGroup.make(from: [registration1])[0] + XCTAssertEqual(namedGroup1.ifConfigCondition?.description, ExprSyntax("DEBUG").description) + + let namedGroup2 = NamedRegistrationGroup.make(from: [registration1, registration2])[0] + XCTAssertNil(namedGroup2.ifConfigCondition) + + let namedGroup3 = NamedRegistrationGroup.make(from: [registration1, registration3])[0] + XCTAssertNil(namedGroup3.ifConfigCondition) + } + } diff --git a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift index 5839fab..6f920ca 100644 --- a/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift +++ b/Tests/KnitCodeGenTests/TypeSafetySourceFileTests.swift @@ -77,6 +77,7 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual( try TypeSafetySourceFile.makeResolver( registration: registration, + ifConfigCondition: registration.ifConfigCondition, enumName: nil ).formatted().description, """ @@ -92,6 +93,7 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual( try TypeSafetySourceFile.makeResolver( registration: registration, + ifConfigCondition: registration.ifConfigCondition, enumName: nil ).formatted().description, """ @@ -107,6 +109,7 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual( try TypeSafetySourceFile.makeResolver( registration: registration, + ifConfigCondition: registration.ifConfigCondition, enumName: nil ).formatted().description, """ @@ -122,6 +125,7 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual( try TypeSafetySourceFile.makeResolver( registration: registration, + ifConfigCondition: registration.ifConfigCondition, enumName: "MyAssembly.A_ResolutionKey" ).formatted().description, """ @@ -137,6 +141,7 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual( try TypeSafetySourceFile.makeResolver( registration: registration, + ifConfigCondition: registration.ifConfigCondition, enumName: nil ).formatted().description, """ @@ -153,6 +158,7 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual( try TypeSafetySourceFile.makeResolver( registration: registration, + ifConfigCondition: registration.ifConfigCondition, enumName: nil ).formatted().description, """ @@ -170,6 +176,7 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual( try TypeSafetySourceFile.makeResolver( registration: registration, + ifConfigCondition: registration.ifConfigCondition, enumName: nil ).formatted().description, """ @@ -375,4 +382,56 @@ final class TypeSafetySourceFileTests: XCTestCase { XCTAssertEqual(expected, result.formatted().description) } + func testResolutionKeyWithMacros() throws { + let registration1 = Registration(service: "ServiceA", name: "name") + let registration2 = Registration(service: "ServiceA", name: "name2", ifConfigCondition: ExprSyntax("DEBUG")) + let registration3 = Registration(service: "ServiceB", name: "name2", ifConfigCondition: ExprSyntax("RELEASE")) + let result = try TypeSafetySourceFile.make( + from: Configuration( + assemblyName: "ModuleAssembly", + moduleName: "Module", + registrations: [registration2, registration1, registration3], + targetResolver: "Resolver" + ) + ) + + let expected = """ + /// Generated from ``ModuleAssembly`` + extension Resolver { + func serviceA(name: ModuleAssembly.ServiceA_ResolutionKey, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceA { + knitUnwrap(resolve(ServiceA.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + } + #if RELEASE + func serviceB(name: ModuleAssembly.ServiceB_ResolutionKey, file: StaticString = #fileID, function: StaticString = #function, line: UInt = #line) -> ServiceB { + knitUnwrap(resolve(ServiceB.self, name: name.rawValue), callsiteFile: file, callsiteFunction: function, callsiteLine: line) + } + #endif + } + extension ModuleAssembly { + enum ServiceA_ResolutionKey: String, CaseIterable { + #if DEBUG + case name2 + #endif + case name + } + #if RELEASE + enum ServiceB_ResolutionKey: String, CaseIterable { + case name2 + } + #endif + } + extension ModuleAssembly { + public static var _assemblyFlags: [ModuleAssemblyFlags] { + [] + } + public static func _autoInstantiate() -> (any ModuleAssembly)? { + nil + } + } + """ + + XCTAssertEqual(expected, result.formatted().description) + + } + }