1
+ from .provider import ProviderFactory , ProviderNames
2
+
3
+
4
+ class Client :
5
+ def __init__ (self , provider_configs : dict = {}):
6
+ """
7
+ Initialize the client with provider configurations.
8
+ Use the ProviderFactory to create provider instances.
9
+ """
10
+ self .providers = {}
11
+ self .provider_configs = provider_configs
12
+ for provider_key , config in provider_configs .items ():
13
+ # Check if the provider key is a valid ProviderNames enum
14
+ if not isinstance (provider_key , ProviderNames ):
15
+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
16
+ # Store the value of the enum in the providers dictionary
17
+ self .providers [provider_key .value ] = ProviderFactory .create_provider (provider_key , config )
18
+
19
+ self ._chat = None
20
+
21
+ def configure (self , provider_configs : dict = None ):
22
+ """
23
+ Configure the client with provider configurations.
24
+ """
25
+ if provider_configs is None :
26
+ return
27
+
28
+ self .provider_configs .update (provider_configs )
29
+
30
+ for provider_key , config in self .provider_configs .items ():
31
+ if not isinstance (provider_key , ProviderNames ):
32
+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
33
+ self .providers [provider_key .value ] = ProviderFactory .create_provider (provider_key , config )
34
+
35
+ @property
36
+ def chat (self ):
37
+ """Return the chat API interface."""
38
+ if not self ._chat :
39
+ self ._chat = Chat (self )
40
+ return self ._chat
41
+
42
+
43
+ class Chat :
44
+ def __init__ (self , client : 'Client' ):
45
+ self .client = client
46
+
47
+ @property
48
+ def completions (self ):
49
+ """Return the completions interface."""
50
+ return Completions (self .client )
51
+
52
+
53
+ class Completions :
54
+ def __init__ (self , client : 'Client' ):
55
+ self .client = client
56
+
57
+ def create (self , model : str , messages : list , ** kwargs ):
58
+ """
59
+ Create chat completion based on the model, messages, and any extra arguments.
60
+ """
61
+ # Check that correct format is used
62
+ if ':' not in model :
63
+ raise ValueError (f"Invalid model format. Expected 'provider:model', got '{ model } '" )
64
+
65
+ # Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
66
+ provider_key , model_name = model .split (":" , 1 )
67
+
68
+ if provider_key not in ProviderNames ._value2member_map_ :
69
+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
70
+
71
+ if provider_key not in self .client .providers :
72
+ config = {}
73
+ if provider_key in self .client .provider_configs :
74
+ config = self .client .provider_configs [provider_key ]
75
+ self .client .providers [provider_key ] = ProviderFactory .create_provider (ProviderNames (provider_key ), config )
76
+
77
+ provider = self .client .providers .get (provider_key )
78
+ if not provider :
79
+ raise ValueError (f"Could not load provider for { provider_key } ." )
80
+
81
+ # Delegate the chat completion to the correct provider's implementation
82
+ # Any additional arguments will be passed to the provider's implementation.
83
+ # Eg: max_tokens, temperature, etc.
84
+ return provider .chat_completions_create (model_name , messages , ** kwargs )
0 commit comments