Skip to content

Commit

Permalink
Merge pull request #138 from squareup/skorulis/multiple-assemblies
Browse files Browse the repository at this point in the history
Allow parsing multiple assemblies in a single file
  • Loading branch information
skorulis-ap authored Apr 9, 2024
2 parents e187b22 + 4950c3e commit 76acfbd
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 83 deletions.
75 changes: 48 additions & 27 deletions Sources/KnitCodeGen/AssemblyParser.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ public struct AssemblyParser {
at paths: [String],
externalTestingAssemblies: [String]
) throws -> ConfigurationSet {
let configs = try paths.compactMap { path in
let configs = try paths.flatMap { path in
return try parse(
path: path,
defaultTargetResolver: defaultTargetResolver,
useTargetResolver: useTargetResolver
)
}
let additionalConfigs = try externalTestingAssemblies.compactMap { path in
let additionalConfigs = try externalTestingAssemblies.flatMap { path in
return try parse(
path: path,
defaultTargetResolver: defaultTargetResolver,
Expand All @@ -44,7 +44,7 @@ public struct AssemblyParser {
return ConfigurationSet(assemblies: configs, externalTestingAssemblies: additionalConfigs)
}

private func parse(path: String, defaultTargetResolver: String, useTargetResolver: Bool) throws -> Configuration? {
private func parse(path: String, defaultTargetResolver: String, useTargetResolver: Bool) throws -> [Configuration] {
let url = URL(fileURLWithPath: path, isDirectory: false)
var errorsToPrint = [Error]()

Expand All @@ -55,7 +55,7 @@ public struct AssemblyParser {
throw AssemblyParsingError.fileReadError(error, path: path)
}
let syntaxTree = Parser.parse(source: source)
let configuration = try parseSyntaxTree(
let configurations = try parseSyntaxTree(
syntaxTree,
path: path,
errorsToPrint: &errorsToPrint
Expand All @@ -64,55 +64,76 @@ public struct AssemblyParser {
if errorsToPrint.count > 0 {
throw AssemblyParsingError.parsingError
}
return configuration
return configurations
}

func parseSyntaxTree(
_ syntaxTree: SyntaxProtocol,
path: String? = nil,
errorsToPrint: inout [Error]
) throws -> Configuration? {
var extractedModuleName: String?
if let path {
extractedModuleName = nameExtractor.extractModuleName(path: path)
}
) throws -> [Configuration] {

let assemblyFileVisitor = AssemblyFileVisitor()
assemblyFileVisitor.walk(syntaxTree)

if assemblyFileVisitor.directives.accessLevel == .ignore { return nil }
errorsToPrint.append(contentsOf: assemblyFileVisitor.assemblyErrors)
errorsToPrint.append(contentsOf: assemblyFileVisitor.registrationErrors)

// If the file doesn't contain assemblies in a valid format, throw to let the developer know
if assemblyFileVisitor.classDeclVisitors.isEmpty && !assemblyFileVisitor.hasIgnoredConfigurations {
throw AssemblyParsingError.noAssembliesFound
}


guard let assemblyName = assemblyFileVisitor.assemblyName else {
throw AssemblyParsingError.missingAssemblyName
let configurations = try assemblyFileVisitor.classDeclVisitors.compactMap { classVisitor in
return try makeConfiguration(
classDeclVisitor: classVisitor,
assemblyFileVisitor: assemblyFileVisitor,
path: path
)
}
let moduleName = assemblyFileVisitor.directives.moduleName ?? extractedModuleName ?? assemblyFileVisitor.moduleName
guard let moduleName else {
throw AssemblyParsingError.missingModuleName
let moduleNames = Set(configurations.map { $0.moduleName })
if moduleNames.count > 1 {
throw AssemblyParsingError.moduleNameMismatch
}

guard let assemblyType = assemblyFileVisitor.assemblyType else {
throw AssemblyParsingError.missingAssemblyType
}
return configurations
}

errorsToPrint.append(contentsOf: assemblyFileVisitor.assemblyErrors)
errorsToPrint.append(contentsOf: assemblyFileVisitor.registrationErrors)
private func makeConfiguration(
classDeclVisitor: ClassDeclVisitor,
assemblyFileVisitor: AssemblyFileVisitor,
path: String?
) throws -> Configuration? {
if classDeclVisitor.directives.accessLevel == .ignore {
return nil
}
var extractedModuleName: String?
if let path {
extractedModuleName = nameExtractor.extractModuleName(path: path)
}
let moduleName = classDeclVisitor.directives.moduleName ?? extractedModuleName ?? classDeclVisitor.moduleName

let targetResolver: String
if useTargetResolver {
targetResolver = assemblyFileVisitor.targetResolver ?? defaultTargetResolver
targetResolver = classDeclVisitor.targetResolver ?? defaultTargetResolver
} else {
targetResolver = defaultTargetResolver
}

guard let assemblyType = classDeclVisitor.assemblyType else {
throw AssemblyParsingError.missingAssemblyType
}

return Configuration(
assemblyName: assemblyName,
assemblyName: classDeclVisitor.assemblyName,
moduleName: moduleName,
directives: assemblyFileVisitor.directives,
directives: classDeclVisitor.directives,
assemblyType: assemblyType,
registrations: assemblyFileVisitor.registrations,
registrationsIntoCollections: assemblyFileVisitor.registrationsIntoCollections,
registrations: classDeclVisitor.registrations,
registrationsIntoCollections: classDeclVisitor.registrationsIntoCollections,
imports: assemblyFileVisitor.imports,
implements: assemblyFileVisitor.implements,
implements: classDeclVisitor.implements,
targetResolver: targetResolver
)
}
Expand Down
90 changes: 41 additions & 49 deletions Sources/KnitCodeGen/AssemblyParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,17 @@ class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
/// The imports that were found in the tree.
private(set) var imports = [ModuleImport]()

private(set) var assemblyName: String?
private(set) var classDeclVisitors: [ClassDeclVisitor] = []

private(set) var moduleName: String?

private(set) var assemblyType: Configuration.AssemblyType?

private(set) var directives: KnitDirectives = .empty

private var classDeclVisitor: ClassDeclVisitor?
private(set) var hasIgnoredConfigurations: Bool = false

private(set) var assemblyErrors: [Error] = []

/// For any imports parsed, this #if condition should be applied when it is used
var currentIfConfigCondition: IfConfigVisitorCondition?

var registrations: [Registration] {
return classDeclVisitor?.registrations ?? []
}

var implements: [String] {
return classDeclVisitor?.implements ?? []
}

var registrationsIntoCollections: [RegistrationIntoCollection] {
return classDeclVisitor?.registrationsIntoCollections ?? []
}

var registrationErrors: [Error] {
return classDeclVisitor?.registrationErrors ?? []
}

var targetResolver: String? {
return classDeclVisitor?.targetResolver
return classDeclVisitors.flatMap { $0.registrationErrors }
}

init() {
Expand Down Expand Up @@ -77,25 +55,24 @@ class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
}

private func visitAssemblyType(_ node: NamedDeclSyntax, _ inheritance: InheritanceClauseSyntax?) -> SyntaxVisitorContinueKind {
guard classDeclVisitor == nil else {
// Only the first class declaration should be visited
return .skipChildren
}
var directives: KnitDirectives = .empty
do {
directives = try KnitDirectives.parse(leadingTrivia: node.leadingTrivia)

if directives.accessLevel == .ignore {
// Entire assembly is marked as ignore, stop parsing
self.hasIgnoredConfigurations = true
return .skipChildren
}

} catch {
assemblyErrors.append(error)
}

let names = node.namesForAssembly
assemblyName = names?.0
moduleName = node.namesForAssembly?.1
guard let assemblyName = names?.0,
let moduleName = node.namesForAssembly?.1 else {
return .skipChildren
}

let inheritedTypes = inheritance?.inheritedTypes.compactMap {
if let identifier = $0.type.as(IdentifierTypeSyntax.self) {
Expand All @@ -106,20 +83,31 @@ class AssemblyFileVisitor: SyntaxVisitor, IfConfigVisitor {
return nil
}
}
self.assemblyType = inheritedTypes?
let assemblyType = inheritedTypes?
.first { $0.hasSuffix(Configuration.AssemblyType.baseAssembly.rawValue) }
.flatMap { Configuration.AssemblyType(rawValue: $0) }
classDeclVisitor = ClassDeclVisitor(viewMode: .fixedUp, directives: directives, assemblyType: assemblyType)
classDeclVisitor?.walk(node)

let classDeclVisitor = ClassDeclVisitor(
viewMode: .fixedUp,
directives: directives,
assemblyName: assemblyName,
moduleName: moduleName,
assemblyType: assemblyType
)
classDeclVisitor.walk(node)
self.classDeclVisitors.append(classDeclVisitor)

return .skipChildren
}

}

private class ClassDeclVisitor: SyntaxVisitor, IfConfigVisitor {
class ClassDeclVisitor: SyntaxVisitor, IfConfigVisitor {

private let directives: KnitDirectives
private let assemblyType: Configuration.AssemblyType?
let directives: KnitDirectives
let assemblyType: Configuration.AssemblyType?
let assemblyName: String
let moduleName: String

/// The registrations that were found in the tree.
private(set) var registrations = [Registration]()
Expand All @@ -136,8 +124,16 @@ private class ClassDeclVisitor: SyntaxVisitor, IfConfigVisitor {
/// For any registrations parsed, this #if condition should be applied when it is used
var currentIfConfigCondition: IfConfigVisitorCondition?

init(viewMode: SyntaxTreeViewMode, directives: KnitDirectives, assemblyType: Configuration.AssemblyType?) {
init(
viewMode: SyntaxTreeViewMode,
directives: KnitDirectives,
assemblyName: String,
moduleName: String,
assemblyType: Configuration.AssemblyType?
) {
self.directives = directives
self.assemblyName = assemblyName
self.moduleName = moduleName
self.assemblyType = assemblyType
super.init(viewMode: viewMode)
}
Expand Down Expand Up @@ -245,10 +241,10 @@ extension NamedDeclSyntax {

enum AssemblyParsingError: Error {
case fileReadError(Error, path: String)
case missingAssemblyName
case missingModuleName
case missingAssemblyType
case parsingError
case noAssembliesFound
case moduleNameMismatch
}

extension AssemblyParsingError: LocalizedError {
Expand All @@ -260,20 +256,16 @@ extension AssemblyParsingError: LocalizedError {
Error reading file: \(error.localizedDescription)
File path: \(path)
"""

case .missingAssemblyName:
return "Cannot generate unit test source file without an assembly name. " +
"Is your Assembly file setup correctly?"
case .missingModuleName:
return "Cannot generate unit test source file without a module name. " +
"Is your Assembly file setup correctly?"
case .parsingError:
return "There were one or more errors parsing the assembly file"
case .missingAssemblyType:
return "Assembly files must inherit from an *Assembly type"
case .noAssembliesFound:
return "The given file did not contain any valid assemblies"
case .moduleNameMismatch:
return "Assemblies in a single file have different modules"
}
}

}

enum ImplementsParsingError: LocalizedError, SyntaxError {
Expand Down
Loading

0 comments on commit 76acfbd

Please sign in to comment.