-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathmulti_fm_client.py
82 lines (66 loc) · 2.51 KB
/
multi_fm_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""MultiFMClient manages a Chat across multiple provider interfaces."""
from .chat import Chat
from ..providers import (
AnthropicInterface,
FireworksInterface,
GroqInterface,
MistralInterface,
OllamaInterface,
OpenAIInterface,
ReplicateInterface,
)
class MultiFMClient:
    """Manages a Chat session plus lazily-created provider interfaces.

    Provider interfaces are created on first use from ``all_factories`` and
    cached in ``all_interfaces``, so each provider is instantiated at most once
    per client instance.
    """

    # Message for the ValueError raised when a model identifier is not of the
    # form "provider:model".
    _MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE = (
        "Expected ':' in model identifier to specify provider:model. Got {model}."
    )
    # Message for the Exception raised when no factory is registered for the
    # requested provider name.
    _NO_FACTORY_ERROR_MESSAGE_TEMPLATE = (
        "Could not find factory to create interface for provider '{provider}'."
    )

    def __init__(self):
        """Initialize the MultiFMClient instance.

        Attributes
        ----------
        chat (Chat): The chat session, constructed with this client.
        all_interfaces (dict): Interface instances already created, keyed by
            provider name. Starts empty and is filled lazily.
        all_factories (dict): Maps provider names to the interface classes
            called (with no arguments) to create them.
        """
        self.chat = Chat(self)
        self.all_interfaces = {}
        self.all_factories = {
            "anthropic": AnthropicInterface,
            "fireworks": FireworksInterface,
            "groq": GroqInterface,
            "mistral": MistralInterface,
            "ollama": OllamaInterface,
            "openai": OpenAIInterface,
            "replicate": ReplicateInterface,
        }

    def get_provider_interface(self, model):
        """Retrieve or create a provider interface based on a model identifier.

        Args:
        ----
            model (str): The model identifier in the format 'provider:model'.
                Only the first ':' separates the two parts, so the model name
                itself may contain further colons.

        Raises:
        ------
            ValueError: If the model identifier does not colon-separate a
                non-empty provider and a non-empty model name.
            Exception: If no factory is registered for the provider.

        Returns:
        -------
            tuple: The interface instance for the provider and the model name.
        """
        # partition splits on the first ':' only, matching split(":", 1).
        provider, separator, model_name = model.partition(":")
        # Both halves must be present: "gpt-4o", ":gpt-4o" and "openai:" are
        # all malformed identifiers.
        if not (separator and provider and model_name):
            raise ValueError(
                self._MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE.format(model=model)
            )

        if provider not in self.all_interfaces:
            if provider not in self.all_factories:
                raise Exception(
                    self._NO_FACTORY_ERROR_MESSAGE_TEMPLATE.format(provider=provider)
                )
            # Instantiate lazily and cache so later calls reuse the same object.
            self.all_interfaces[provider] = self.all_factories[provider]()
        return self.all_interfaces[provider], model_name