Skip to content

Commit 11cab4c

Browse files
authored
Add contains(id:) method to AbstractModelRegistry and ModelFactory (#233)
* Add contains(id:) method * Remove duplicate code
1 parent c8164a2 commit 11cab4c

File tree

4 files changed

+34
-14
lines changed

4 files changed

+34
-14
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,6 @@ public class LLMModelFactory: ModelFactory {
250250
/// registry of model id to configuration, e.g. `mlx-community/Llama-3.2-3B-Instruct-4bit`
251251
public let modelRegistry: AbstractModelRegistry
252252

253-
public func configuration(id: String) -> ModelConfiguration {
254-
modelRegistry.configuration(id: id)
255-
}
256-
257253
public func _load(
258254
hub: HubApi, configuration: ModelConfiguration,
259255
progressHandler: @Sendable @escaping (Progress) -> Void

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,7 @@ public struct ModelContext {
4040

4141
public protocol ModelFactory: Sendable {
4242

43-
/// Resolve a model identifier, e.g. "mlx-community/Llama-3.2-3B-Instruct-4bit", into
44-
/// a ``ModelConfiguration``.
45-
///
46-
/// This will either create a new (mostly unconfigured) ``ModelConfiguration`` or
47-
/// return a registered instance that matches the id.
48-
func configuration(id: String) -> ModelConfiguration
43+
var modelRegistry: AbstractModelRegistry { get }
4944

5045
func _load(
5146
hub: HubApi, configuration: ModelConfiguration,
@@ -56,6 +51,28 @@ public protocol ModelFactory: Sendable {
5651
hub: HubApi, configuration: ModelConfiguration,
5752
progressHandler: @Sendable @escaping (Progress) -> Void
5853
) async throws -> ModelContainer
54+
55+
}
56+
57+
extension ModelFactory {
58+
59+
/// Resolve a model identifier, e.g. "mlx-community/Llama-3.2-3B-Instruct-4bit", into
60+
/// a ``ModelConfiguration``.
61+
///
62+
/// This will either create a new (mostly unconfigured) ``ModelConfiguration`` or
63+
/// return a registered instance that matches the id.
64+
///
65+
/// - Note: If the id doesn't exists in the configuration, this will return a new instance of it.
66+
/// If you want to check if the configuration in model registry, you should use ``contains(id:)``.
67+
public func configuration(id: String) -> ModelConfiguration {
68+
modelRegistry.configuration(id: id)
69+
}
70+
71+
/// Returns true if ``modelRegistry`` contains a model with the id. Otherwise, false.
72+
public func contains(id: String) -> Bool {
73+
modelRegistry.contains(id: id)
74+
}
75+
5976
}
6077

6178
extension ModelFactory {

Libraries/MLXLMCommon/Registries/AbstractModelRegistry.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ open class AbstractModelRegistry: @unchecked Sendable {
2525
}
2626
}
2727

28+
/// Returns configuration from ``modelRegistry``.
29+
///
30+
/// - Note: If the id doesn't exists in the configuration, this will return a new instance of it.
31+
/// If you want to check if the configuration in model registry, you should use ``contains(id:)``.
2832
public func configuration(id: String) -> ModelConfiguration {
2933
lock.withLock {
3034
if let c = registry[id] {
@@ -35,6 +39,13 @@ open class AbstractModelRegistry: @unchecked Sendable {
3539
}
3640
}
3741

42+
/// Returns true if the registry contains a model with the id. Otherwise, false.
43+
public func contains(id: String) -> Bool {
44+
lock.withLock {
45+
registry[id] != nil
46+
}
47+
}
48+
3849
public var models: some Collection<ModelConfiguration> & Sendable {
3950
lock.withLock {
4051
return registry.values

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,6 @@ public class VLMModelFactory: ModelFactory {
159159
/// registry of model id to configuration, e.g. `mlx-community/paligemma-3b-mix-448-8bit`
160160
public let modelRegistry: AbstractModelRegistry
161161

162-
public func configuration(id: String) -> ModelConfiguration {
163-
modelRegistry.configuration(id: id)
164-
}
165-
166162
public func _load(
167163
hub: HubApi, configuration: ModelConfiguration,
168164
progressHandler: @Sendable @escaping (Progress) -> Void

0 commit comments

Comments
 (0)