-
Notifications
You must be signed in to change notification settings - Fork 901
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into rcp/ToolCalling-P1
- Loading branch information
Showing
38 changed files
with
1,510 additions
and
265 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,12 @@ __pycache__/ | |
env/ | ||
.env | ||
.google-adc | ||
|
||
# Testing | ||
.coverage | ||
|
||
# pyenv | ||
.python-version | ||
|
||
.DS_Store | ||
**/.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import os | ||
import cohere | ||
|
||
from aisuite.framework import ChatCompletionResponse | ||
from aisuite.provider import Provider | ||
|
||
|
||
class CohereProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the Cohere provider with the given configuration. | ||
Pass the entire configuration dictionary to the Cohere client constructor. | ||
""" | ||
# Ensure API key is provided either in config or via environment variable | ||
config.setdefault("api_key", os.getenv("CO_API_KEY")) | ||
if not config["api_key"]: | ||
raise ValueError( | ||
" API key is missing. Please provide it in the config or set the CO_API_KEY environment variable." | ||
) | ||
self.client = cohere.ClientV2(**config) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
response = self.client.chat( | ||
model=model, | ||
messages=messages, | ||
**kwargs # Pass any additional arguments to the Cohere API | ||
) | ||
|
||
return self.normalize_response(response) | ||
|
||
def normalize_response(self, response): | ||
"""Normalize the reponse from Cohere API to match OpenAI's response format.""" | ||
normalized_response = ChatCompletionResponse() | ||
normalized_response.choices[0].message.content = response.message.content[ | ||
0 | ||
].text | ||
return normalized_response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import openai | ||
import os | ||
from aisuite.provider import Provider, LLMError | ||
|
||
|
||
class DeepseekProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the DeepSeek provider with the given configuration. | ||
Pass the entire configuration dictionary to the OpenAI client constructor. | ||
""" | ||
# Ensure API key is provided either in config or via environment variable | ||
config.setdefault("api_key", os.getenv("DEEPSEEK_API_KEY")) | ||
if not config["api_key"]: | ||
raise ValueError( | ||
"DeepSeek API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable." | ||
) | ||
config["base_url"] = "https://api.deepseek.com" | ||
|
||
# NOTE: We could choose to remove above lines for api_key since OpenAI will automatically | ||
# infer certain values from the environment variables. | ||
# Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID. Except for OPEN_AI_BASE_URL which has to be the deepseek url | ||
|
||
# Pass the entire config to the OpenAI client constructor | ||
self.client = openai.OpenAI(**config) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
# Any exception raised by OpenAI will be returned to the caller. | ||
# Maybe we should catch them and raise a custom LLMError. | ||
return self.client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
**kwargs # Pass any additional arguments to the OpenAI API | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import os | ||
from aisuite.provider import Provider | ||
from openai import Client | ||
|
||
|
||
BASE_URL = "https://api.studio.nebius.ai/v1" | ||
|
||
|
||
class NebiusProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the Nebius AI Studio provider with the given configuration. | ||
Pass the entire configuration dictionary to the OpenAI client constructor. | ||
""" | ||
# Ensure API key is provided either in config or via environment variable | ||
config.setdefault("api_key", os.getenv("NEBIUS_API_KEY")) | ||
if not config["api_key"]: | ||
raise ValueError( | ||
"Nebius AI Studio API key is missing. Please provide it in the config or set the NEBIUS_API_KEY environment variable. You can get your API key at https://studio.nebius.ai/settings/api-keys" | ||
) | ||
|
||
config["base_url"] = BASE_URL | ||
# Pass the entire config to the OpenAI client constructor | ||
self.client = Client(**config) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
return self.client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
**kwargs # Pass any additional arguments to the Nebius API | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import os | ||
from aisuite.provider import Provider | ||
from openai import OpenAI | ||
|
||
|
||
class SambanovaProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the SambaNova provider with the given configuration. | ||
Pass the entire configuration dictionary to the OpenAI client constructor. | ||
""" | ||
# Ensure API key is provided either in config or via environment variable | ||
config.setdefault("api_key", os.getenv("SAMBANOVA_API_KEY")) | ||
if not config["api_key"]: | ||
raise ValueError( | ||
"Sambanova API key is missing. Please provide it in the config or set the SAMBANOVA_API_KEY environment variable." | ||
) | ||
|
||
config["base_url"] = "https://api.sambanova.ai/v1/" | ||
# Pass the entire config to the OpenAI client constructor | ||
self.client = OpenAI(**config) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
# Any exception raised by Sambanova will be returned to the caller. | ||
# Maybe we should catch them and raise a custom LLMError. | ||
return self.client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
**kwargs # Pass any additional arguments to the Sambanova API | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from aisuite.provider import Provider | ||
import os | ||
from ibm_watsonx_ai import Credentials | ||
from ibm_watsonx_ai.foundation_models import ModelInference | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
|
||
class WatsonxProvider(Provider): | ||
def __init__(self, **config): | ||
self.service_url = config.get("service_url") or os.getenv("WATSONX_SERVICE_URL") | ||
self.api_key = config.get("api_key") or os.getenv("WATSONX_API_KEY") | ||
self.project_id = config.get("project_id") or os.getenv("WATSONX_PROJECT_ID") | ||
|
||
if not self.service_url or not self.api_key or not self.project_id: | ||
raise EnvironmentError( | ||
"Missing one or more required WatsonX environment variables: " | ||
"WATSONX_SERVICE_URL, WATSONX_API_KEY, WATSONX_PROJECT_ID. " | ||
"Please refer to the setup guide: /guides/watsonx.md." | ||
) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
model = ModelInference( | ||
model_id=model, | ||
credentials=Credentials( | ||
api_key=self.api_key, | ||
url=self.service_url, | ||
), | ||
project_id=self.project_id, | ||
) | ||
|
||
res = model.chat(messages=messages, params=kwargs) | ||
return self.normalize_response(res) | ||
|
||
def normalize_response(self, response): | ||
openai_response = ChatCompletionResponse() | ||
openai_response.choices[0].message.content = response["choices"][0]["message"][ | ||
"content" | ||
] | ||
return openai_response |
Oops, something went wrong.