1+ from .provider import ProviderFactory , ProviderNames
2+
3+
4+ class Client :
5+ def __init__ (self , provider_configs : dict = {}):
6+ """
7+ Initialize the client with provider configurations.
8+ Use the ProviderFactory to create provider instances.
9+ """
10+ self .providers = {}
11+ self .provider_configs = provider_configs
12+ for provider_key , config in provider_configs .items ():
13+ # Check if the provider key is a valid ProviderNames enum
14+ if not isinstance (provider_key , ProviderNames ):
15+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
16+ # Store the value of the enum in the providers dictionary
17+ self .providers [provider_key .value ] = ProviderFactory .create_provider (provider_key , config )
18+
19+ self ._chat = None
20+
21+ def configure (self , provider_configs : dict = None ):
22+ """
23+ Configure the client with provider configurations.
24+ """
25+ if provider_configs is None :
26+ return
27+
28+ self .provider_configs .update (provider_configs )
29+
30+ for provider_key , config in self .provider_configs .items ():
31+ if not isinstance (provider_key , ProviderNames ):
32+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
33+ self .providers [provider_key .value ] = ProviderFactory .create_provider (provider_key , config )
34+
35+ @property
36+ def chat (self ):
37+ """Return the chat API interface."""
38+ if not self ._chat :
39+ self ._chat = Chat (self )
40+ return self ._chat
41+
42+
43+ class Chat :
44+ def __init__ (self , client : 'Client' ):
45+ self .client = client
46+
47+ @property
48+ def completions (self ):
49+ """Return the completions interface."""
50+ return Completions (self .client )
51+
52+
53+ class Completions :
54+ def __init__ (self , client : 'Client' ):
55+ self .client = client
56+
57+ def create (self , model : str , messages : list , ** kwargs ):
58+ """
59+ Create chat completion based on the model, messages, and any extra arguments.
60+ """
61+ # Check that correct format is used
62+ if ':' not in model :
63+ raise ValueError (f"Invalid model format. Expected 'provider:model', got '{ model } '" )
64+
65+ # Extract the provider key from the model identifier, e.g., "aws-bedrock:model-name"
66+ provider_key , model_name = model .split (":" , 1 )
67+
68+ if provider_key not in ProviderNames ._value2member_map_ :
69+ raise ValueError (f"Provider { provider_key } is not a valid ProviderNames enum" )
70+
71+ if provider_key not in self .client .providers :
72+ config = {}
73+ if provider_key in self .client .provider_configs :
74+ config = self .client .provider_configs [provider_key ]
75+ self .client .providers [provider_key ] = ProviderFactory .create_provider (ProviderNames (provider_key ), config )
76+
77+ provider = self .client .providers .get (provider_key )
78+ if not provider :
79+ raise ValueError (f"Could not load provider for { provider_key } ." )
80+
81+ # Delegate the chat completion to the correct provider's implementation
82+ # Any additional arguments will be passed to the provider's implementation.
83+ # Eg: max_tokens, temperature, etc.
84+ return provider .chat_completions_create (model_name , messages , ** kwargs )
0 commit comments