2
2
3
3
4
4
class Client :
5
- def __init__ (self , provider_configs : dict ):
5
+ def __init__ (self , provider_configs : dict = {} ):
6
6
"""
7
7
Initialize the client with provider configurations.
8
8
Use the ProviderFactory to create provider instances.
9
9
"""
10
10
self .providers = {}
11
+ self .provider_configs = provider_configs
11
12
for provider_key , config in provider_configs .items ():
12
13
# Check if the provider key is a valid ProviderNames enum
13
14
if not isinstance (provider_key , ProviderNames ):
@@ -17,6 +18,20 @@ def __init__(self, provider_configs: dict):
17
18
18
19
self ._chat = None
19
20
21
+ def configure (self , provider_configs : dict = None ):
22
+ """
23
+ Configure the client with provider configurations.
24
+ """
25
+ if provider_configs is None :
26
+ return
27
+
28
+ self .provider_configs .update (provider_configs )
29
+
30
+ for provider_key , config in self .provider_configs .items ():
31
+ if not isinstance (provider_key , ProviderNames ):
32
+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
33
+ self .providers [provider_key .value ] = ProviderFactory .create_provider (provider_key , config )
34
+
20
35
@property
21
36
def chat (self ):
22
37
"""Return the chat API interface."""
@@ -39,18 +54,33 @@ class Completions:
39
54
def __init__ (self , client : 'Client' ):
40
55
self .client = client
41
56
42
- def create (self , model : str , messages : list ):
43
- """Create chat completion based on the model."""
57
+ def create (self , model : str , messages : list , ** kwargs ):
58
+ """
59
+ Create chat completion based on the model, messages, and any extra arguments.
60
+ """
61
+ # Check that correct format is used
62
+ if ':' not in model :
63
+ raise ValueError (f"Invalid model format. Expected 'provider:model', got '{ model } '" )
64
+
44
65
# Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
45
66
provider_key , model_name = model .split (":" , 1 )
46
67
47
- # Use the correct provider instance created by the factory
68
+ if provider_key not in ProviderNames ._value2member_map_ :
69
+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
70
+
71
+ if provider_key not in self .client .providers :
72
+ config = {}
73
+ if provider_key in self .client .provider_configs :
74
+ config = self .client .provider_configs [provider_key ]
75
+ self .client .providers [provider_key ] = ProviderFactory .create_provider (ProviderNames (provider_key ), config )
76
+
48
77
provider = self .client .providers .get (provider_key )
49
78
if not provider :
50
- # Add the providers to the ValueError
51
- raise ValueError (f"Provider { provider_key } is not present in the client. Here are the providers: { self .client .providers } " )
79
+ raise ValueError (f"Could not load provider for { provider_key } ." )
52
80
53
81
# Delegate the chat completion to the correct provider's implementation
54
- return provider .chat_completions_create (model_name , messages )
82
+ # Any additional arguments will be passed to the provider's implementation.
83
+ # Eg: max_tokens, temperature, etc.
84
+ return provider .chat_completions_create (model_name , messages , ** kwargs )
55
85
56
86
0 commit comments