diff --git a/.env.sample b/.env.sample
index 00826f4b..cc933d9b 100644
--- a/.env.sample
+++ b/.env.sample
@@ -18,10 +18,21 @@ GOOGLE_REGION=
 GOOGLE_PROJECT_ID=
 
 # Hugging Face token
-HUGGINGFACE_TOKEN=
+HF_TOKEN=
 
 # Fireworks
 FIREWORKS_API_KEY=
 
 # Together AI
 TOGETHER_API_KEY=
+
+# WatsonX
+WATSONX_SERVICE_URL=
+WATSONX_API_KEY=
+WATSONX_PROJECT_ID=
+
+# xAI
+XAI_API_KEY=
+
+# Sambanova
+SAMBANOVA_API_KEY=
diff --git a/.github/workflows/run_pytest.yml b/.github/workflows/run_pytest.yml
index 0093c348..d873172b 100644
--- a/.github/workflows/run_pytest.yml
+++ b/.github/workflows/run_pytest.yml
@@ -18,7 +18,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install poetry
-        poetry install
+        poetry install --all-extras --with test
     - name: Test with pytest
-      run: poetry run pytest
+      run: poetry run pytest -m "not integration"
diff --git a/.gitignore b/.gitignore
index 5b651c66..a0974550 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,12 @@ __pycache__/
 env/
 .env
 .google-adc
+
+# Testing
+.coverage
+
+# pyenv
+.python-version
+
+.DS_Store
+**/.DS_Store
diff --git a/README.md b/README.md
index bd7df95e..add8b851 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
 # aisuite
 
+[![PyPI](https://img.shields.io/pypi/v/aisuite)](https://pypi.org/project/aisuite/)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 
 Simple, unified interface to multiple Generative AI providers.
@@ -7,7 +8,7 @@ Simple, unified interface to multiple Generative AI providers.
 `aisuite` makes it easy for developers to use multiple LLMs through a standardized interface. Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focused on chat completions. We will expand it to cover more use cases in the near future.
 
 Currently supported providers are -
-OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace and Ollama.
+OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace, Ollama, SambaNova and Watsonx.
 To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider.
 
 ## Installation
@@ -21,11 +22,13 @@
 pip install aisuite
 ```
 
 This installs aisuite along with Anthropic's library.
+
 ```shell
 pip install 'aisuite[anthropic]'
 ```
 
 This installs all the provider-specific libraries
+
 ```shell
 pip install 'aisuite[all]'
 ```
@@ -41,12 +44,14 @@ You can use tools like [`python-dotenv`](https://pypi.org/project/python-dotenv/
 Here is a short example of using `aisuite` to generate chat completion responses from gpt-4o and claude-3-5-sonnet.
 
 Set the API keys.
+
 ```shell
 export OPENAI_API_KEY="your-openai-api-key"
 export ANTHROPIC_API_KEY="your-anthropic-api-key"
 ```
 
 Use the python client.
+
 ```python
 import aisuite as ai
 client = ai.Client()
@@ -67,6 +72,7 @@ for model in models:
     print(response.choices[0].message.content)
 
 ```
+
 Note that the model name in the create() call uses the format - `<provider>:<model-name>`. `aisuite` will call the appropriate provider with the right parameters based on the provider value. For a list of provider values, you can look at the directory - `aisuite/providers/`. The supported providers are of the format - `<provider>_provider.py` in that directory.
 We welcome providers adding support to this library by adding an implementation file in this directory. Please see the section below for how to contribute.
@@ -79,9 +85,10 @@ aisuite is released under the MIT License. You are free to use, modify, and dist
 ## Contributing
 
-If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md) and join our [Discord](https://discord.gg/T6Nvn8ExSb) server!
+If you would like to contribute, please read our [Contributing Guide](https://github.com/andrewyng/aisuite/blob/main/CONTRIBUTING.md) and join our [Discord](https://discord.gg/T6Nvn8ExSb) server!
 
 ## Adding support for a provider
+
 We have made it easy for a provider or volunteer to add support for a new platform.
 
 ### Naming Convention for Provider Modules
 
@@ -91,20 +98,24 @@ We follow a convention-based approach for loading providers, which relies on str
 
 - The provider's module file must be named in the format `<provider>_provider.py`.
 - The class inside this module must follow the format: the provider name with the first letter capitalized, followed by the suffix `Provider`.
 
-#### Examples:
+#### Examples
 
 - **Hugging Face**: The provider class should be defined as:
+
 ```python
 class HuggingfaceProvider(BaseProvider)
 ```
+
 in providers/huggingface_provider.py.
 
 - **OpenAI**: The provider class should be defined as:
+
 ```python
 class OpenaiProvider(BaseProvider)
 ```
+
 in providers/openai_provider.py
 
 This convention simplifies the addition of new providers and ensures consistency across provider implementations.
diff --git a/aisuite/framework/message.py b/aisuite/framework/message.py
index 26be291f..eaa611ef 100644
--- a/aisuite/framework/message.py
+++ b/aisuite/framework/message.py
@@ -1,4 +1,4 @@
-"""Interface to hold contents of api responses when they do not conform to the OpenAI style response"""
+"""Interface to hold contents of api responses when they do not conform to the OpenAI style response"""
 
 from pydantic import BaseModel
 from typing import Literal, Optional
diff --git a/aisuite/providers/aws_provider.py b/aisuite/providers/aws_provider.py
index 1aa54051..01b2526a 100644
--- a/aisuite/providers/aws_provider.py
+++ b/aisuite/providers/aws_provider.py
@@ -14,7 +14,7 @@ class BedrockConfig:
     def __init__(self, **config):
         self.region_name = config.get(
-            "region_name", os.getenv("AWS_REGION_NAME", "us-west-2")
+            "region_name", os.getenv("AWS_REGION", "us-west-2")
         )
 
     def create_client(self):
diff --git a/aisuite/providers/cohere_provider.py b/aisuite/providers/cohere_provider.py
new file mode 100644
index 00000000..5886f24b
--- /dev/null
+++ b/aisuite/providers/cohere_provider.py
@@ -0,0 +1,37 @@
+import os
+import cohere
+
+from aisuite.framework import ChatCompletionResponse
+from aisuite.provider import Provider
+
+
+class CohereProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the Cohere provider with the given configuration.
+        Pass the entire configuration dictionary to the Cohere client constructor.
+        """
+        # Ensure API key is provided either in config or via environment variable
+        config.setdefault("api_key", os.getenv("CO_API_KEY"))
+        if not config["api_key"]:
+            raise ValueError(
+                "Cohere API key is missing. Please provide it in the config or set the CO_API_KEY environment variable."
+            )
+        self.client = cohere.ClientV2(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        response = self.client.chat(
+            model=model,
+            messages=messages,
+            **kwargs  # Pass any additional arguments to the Cohere API
+        )
+
+        return self.normalize_response(response)
+
+    def normalize_response(self, response):
+        """Normalize the response from the Cohere API to match OpenAI's response format."""
+        normalized_response = ChatCompletionResponse()
+        normalized_response.choices[0].message.content = response.message.content[
+            0
+        ].text
+        return normalized_response
diff --git a/aisuite/providers/deepseek_provider.py b/aisuite/providers/deepseek_provider.py
new file mode 100644
index 00000000..16327c57
--- /dev/null
+++ b/aisuite/providers/deepseek_provider.py
@@ -0,0 +1,34 @@
+import openai
+import os
+from aisuite.provider import Provider, LLMError
+
+
+class DeepseekProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the DeepSeek provider with the given configuration.
+        Pass the entire configuration dictionary to the OpenAI client constructor.
+        """
+        # Ensure API key is provided either in config or via environment variable
+        config.setdefault("api_key", os.getenv("DEEPSEEK_API_KEY"))
+        if not config["api_key"]:
+            raise ValueError(
+                "DeepSeek API key is missing. Please provide it in the config or set the DEEPSEEK_API_KEY environment variable."
+            )
+        config["base_url"] = "https://api.deepseek.com"
+
+        # NOTE: We could choose to remove the above lines for api_key since OpenAI will
+        # automatically infer certain values from the environment variables,
+        # e.g. OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID. The exception is the
+        # base URL, which has to be the DeepSeek URL.
+
+        # Pass the entire config to the OpenAI client constructor
+        self.client = openai.OpenAI(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Any exception raised by OpenAI will be returned to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        return self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            **kwargs  # Pass any additional arguments to the OpenAI API
+        )
diff --git a/aisuite/providers/huggingface_provider.py b/aisuite/providers/huggingface_provider.py
index dd9be17e..ac8af9ce 100644
--- a/aisuite/providers/huggingface_provider.py
+++ b/aisuite/providers/huggingface_provider.py
@@ -21,10 +21,10 @@ def __init__(self, **config):
         The token is fetched from the config or environment variables.
         """
         # Ensure API key is provided either in config or via environment variable
-        self.token = config.get("token") or os.getenv("HUGGINGFACE_TOKEN")
+        self.token = config.get("token") or os.getenv("HF_TOKEN")
         if not self.token:
             raise ValueError(
-                "Hugging Face token is missing. Please provide it in the config or set the HUGGINGFACE_TOKEN environment variable."
+                "Hugging Face token is missing. Please provide it in the config or set the HF_TOKEN environment variable."
             )
 
         # Initialize the InferenceClient with the specified model and timeout if provided
diff --git a/aisuite/providers/nebius_provider.py b/aisuite/providers/nebius_provider.py
new file mode 100644
index 00000000..c558a9ce
--- /dev/null
+++ b/aisuite/providers/nebius_provider.py
@@ -0,0 +1,31 @@
+import os
+from aisuite.provider import Provider
+from openai import Client
+
+
+BASE_URL = "https://api.studio.nebius.ai/v1"
+
+
+class NebiusProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the Nebius AI Studio provider with the given configuration.
+ Pass the entire configuration dictionary to the OpenAI client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("NEBIUS_API_KEY")) + if not config["api_key"]: + raise ValueError( + "Nebius AI Studio API key is missing. Please provide it in the config or set the NEBIUS_API_KEY environment variable. You can get your API key at https://studio.nebius.ai/settings/api-keys" + ) + + config["base_url"] = BASE_URL + # Pass the entire config to the OpenAI client constructor + self.client = Client(**config) + + def chat_completions_create(self, model, messages, **kwargs): + return self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the Nebius API + ) diff --git a/aisuite/providers/sambanova_provider.py b/aisuite/providers/sambanova_provider.py new file mode 100644 index 00000000..75a97311 --- /dev/null +++ b/aisuite/providers/sambanova_provider.py @@ -0,0 +1,30 @@ +import os +from aisuite.provider import Provider +from openai import OpenAI + + +class SambanovaProvider(Provider): + def __init__(self, **config): + """ + Initialize the SambaNova provider with the given configuration. + Pass the entire configuration dictionary to the OpenAI client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("SAMBANOVA_API_KEY")) + if not config["api_key"]: + raise ValueError( + "Sambanova API key is missing. Please provide it in the config or set the SAMBANOVA_API_KEY environment variable." + ) + + config["base_url"] = "https://api.sambanova.ai/v1/" + # Pass the entire config to the OpenAI client constructor + self.client = OpenAI(**config) + + def chat_completions_create(self, model, messages, **kwargs): + # Any exception raised by Sambanova will be returned to the caller. + # Maybe we should catch them and raise a custom LLMError. + return self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the Sambanova API + ) diff --git a/aisuite/providers/watsonx_provider.py b/aisuite/providers/watsonx_provider.py new file mode 100644 index 00000000..5a4be042 --- /dev/null +++ b/aisuite/providers/watsonx_provider.py @@ -0,0 +1,39 @@ +from aisuite.provider import Provider +import os +from ibm_watsonx_ai import Credentials +from ibm_watsonx_ai.foundation_models import ModelInference +from aisuite.framework import ChatCompletionResponse + + +class WatsonxProvider(Provider): + def __init__(self, **config): + self.service_url = config.get("service_url") or os.getenv("WATSONX_SERVICE_URL") + self.api_key = config.get("api_key") or os.getenv("WATSONX_API_KEY") + self.project_id = config.get("project_id") or os.getenv("WATSONX_PROJECT_ID") + + if not self.service_url or not self.api_key or not self.project_id: + raise EnvironmentError( + "Missing one or more required WatsonX environment variables: " + "WATSONX_SERVICE_URL, WATSONX_API_KEY, WATSONX_PROJECT_ID. " + "Please refer to the setup guide: /guides/watsonx.md." 
+ ) + + def chat_completions_create(self, model, messages, **kwargs): + model = ModelInference( + model_id=model, + credentials=Credentials( + api_key=self.api_key, + url=self.service_url, + ), + project_id=self.project_id, + ) + + res = model.chat(messages=messages, params=kwargs) + return self.normalize_response(res) + + def normalize_response(self, response): + openai_response = ChatCompletionResponse() + openai_response.choices[0].message.content = response["choices"][0]["message"][ + "content" + ] + return openai_response diff --git a/aisuite/providers/xai_provider.py b/aisuite/providers/xai_provider.py new file mode 100644 index 00000000..53e8d831 --- /dev/null +++ b/aisuite/providers/xai_provider.py @@ -0,0 +1,65 @@ +import os +import httpx +from aisuite.provider import Provider, LLMError +from aisuite.framework import ChatCompletionResponse + + +class XaiProvider(Provider): + """ + xAI Provider using httpx for direct API calls. + """ + + BASE_URL = "https://api.x.ai/v1/chat/completions" + + def __init__(self, **config): + """ + Initialize the xAI provider with the given configuration. + The API key is fetched from the config or environment variables. + """ + self.api_key = config.get("api_key", os.getenv("XAI_API_KEY")) + if not self.api_key: + raise ValueError( + "xAI API key is missing. Please provide it in the config or set the XAI_API_KEY environment variable." + ) + + # Optionally set a custom timeout (default to 30s) + self.timeout = config.get("timeout", 30) + + def chat_completions_create(self, model, messages, **kwargs): + """ + Makes a request to the xAI chat completions endpoint using httpx. + """ + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + data = { + "model": model, + "messages": messages, + **kwargs, # Pass any additional arguments to the API + } + + try: + # Make the request to xAI endpoint. + response = httpx.post( + self.BASE_URL, json=data, headers=headers, timeout=self.timeout + ) + response.raise_for_status() + except httpx.HTTPStatusError as http_err: + raise LLMError(f"xAI request failed: {http_err}") + except Exception as e: + raise LLMError(f"An error occurred: {e}") + + # Return the normalized response + return self._normalize_response(response.json()) + + def _normalize_response(self, response_data): + """ + Normalize the response to a common format (ChatCompletionResponse). + """ + normalized_response = ChatCompletionResponse() + normalized_response.choices[0].message.content = response_data["choices"][0][ + "message" + ]["content"] + return normalized_response diff --git a/examples/QnA_with_pdf.ipynb b/examples/QnA_with_pdf.ipynb index 4fbf0ba0..bfcb8b78 100644 --- a/examples/QnA_with_pdf.ipynb +++ b/examples/QnA_with_pdf.ipynb @@ -102,7 +102,6 @@ "metadata": {}, "outputs": [], "source": [ - "import aisuite as ai\n", "client = ai.Client()\n", "messages = [\n", " {\"role\": \"system\", \"content\": \"You are a helpful assistant. Answer the question only based on the below text.\"},\n", @@ -180,6 +179,25 @@ "print(response.choices[0].message.content)" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Yes, the price of organic avocados is higher than non-organic avocados. 
According to the text, the average price of organic avocados is generally 35-40% higher than conventional avocados.\n" + ] + } + ], + "source": [ + "nebius_model = \"nebius:meta-llama/Meta-Llama-3.1-8B-Instruct-fast\"\n", + "response = client.chat.completions.create(model=nebius_model, messages=messages, top_p=0.01)\n", + "print(response.choices[0].message.content)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/client.ipynb b/examples/client.ipynb index 4e7e3182..839f06e4 100644 --- a/examples/client.ipynb +++ b/examples/client.ipynb @@ -61,6 +61,7 @@ " 'AWS_ACCESS_KEY_ID': 'xxx',\n", " 'AWS_SECRET_ACCESS_KEY': 'xxx',\n", " 'ANTHROPIC_API_KEY': 'xxx',\n", + " 'NEBIUS_API_KEY': 'xxx',\n", "}\n", "\n", "# Configure environment\n", @@ -122,7 +123,7 @@ "source": [ "# IMP NOTE: Azure expects model endpoint to be passed in the format of \"azure:\".\n", "# The model name is the deployment name in Project/Deployments.\n", - "# In the exmaple below, the model is \"mistral-large-2407\", but the name given to the\n", + "# In the example below, the model is \"mistral-large-2407\", but the name given to the\n", "# deployment is \"aisuite-mistral-large-2407\" under the deployments section in Azure.\n", "client.configure({\"azure\" : {\n", " \"api_key\": os.environ[\"AZURE_API_KEY\"],\n", @@ -142,7 +143,7 @@ "source": [ "# HuggingFace expects the model to be passed in the format of \"huggingface:\".\n", "# The model name is the full name of the model in HuggingFace.\n", - "# In the exmaple below, the model is \"mistralai/Mistral-7B-Instruct-v0.3\".\n", + "# In the example below, the model is \"mistralai/Mistral-7B-Instruct-v0.3\".\n", "# The model is deployed as serverless inference endpoint in HuggingFace.\n", "hf_model = \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\"\n", "response = client.chat.completions.create(model=hf_model, messages=messages)\n", @@ -159,7 +160,7 @@ "\n", "# Groq expects the model to be passed in the format of \"groq:\".\n", "# The model name is the full name of the model in Groq.\n", - "# In the exmaple below, the model is \"llama3-8b-8192\".\n", + "# In the example below, the model is \"llama3-8b-8192\".\n", "groq_llama3_8b = \"groq:llama3-8b-8192\"\n", "# groq_llama3_70b = \"groq:llama3-70b-8192\"\n", "response = client.chat.completions.create(model=groq_llama3_8b, messages=messages)\n", @@ -208,6 +209,18 @@ "print(response.choices[0].message.content)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f38d033a-a580-4239-9176-27f3d53e7fe1", + "metadata": {}, + "outputs": [], + "source": [ + "nebius_model = \"nebius:Qwen/Qwen2.5-1.5B-Instruct\"\n", + "response = client.chat.completions.create(model=nebius_model, messages=messages, top_p=0.01)\n", + "print(response.choices[0].message.content)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -266,4 +279,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/guides/README.md b/guides/README.md index 3079c29c..50774586 100644 --- a/guides/README.md +++ b/guides/README.md @@ -2,13 +2,17 @@ These guides give directions for obtaining API keys from different providers. 
-Here're the instructions for:
+Here are the instructions for:
 
 - [Anthropic](anthropic.md)
 - [AWS](aws.md)
 - [Azure](azure.md)
+- [Cohere](cohere.md)
+- [DeepSeek](deepseek.md)
 - [Google](google.md)
 - [Hugging Face](huggingface.md)
 - [OpenAI](openai.md)
+- [SambaNova](sambanova.md)
+- [xAI](xai.md)
 
 Unless otherwise stated, these guides have not been endorsed by the providers.
diff --git a/guides/anthropic.md b/guides/anthropic.md
index 8d70cb5e..0f674c17 100644
--- a/guides/anthropic.md
+++ b/guides/anthropic.md
@@ -44,4 +44,4 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```
 
-Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/aws.md b/guides/aws.md
index a01d6eb6..531117ca 100644
--- a/guides/aws.md
+++ b/guides/aws.md
@@ -23,9 +23,9 @@ Once that has been enabled set your Access Key and Secret in the env variables:
 ```shell
 export AWS_ACCESS_KEY="your-access-key"
 export AWS_SECRET_KEY="your-secret-key"
-export AWS_REGION_NAME="region-name"
+export AWS_REGION="region-name"
 ```
 
-*Note: AWS_REGION_NAME is optional, a default of `us-west-2` has been set for easy of use*
+*Note: AWS_REGION is optional; a default of `us-west-2` has been set for ease of use.*
 
 ## Create a Chat Completion
 
@@ -63,7 +63,7 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```
 
-Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/azure.md b/guides/azure.md
index e9a71fe0..8246b7ad 100644
--- a/guides/azure.md
+++ b/guides/azure.md
@@ -56,3 +56,5 @@ response = client.chat.completions.create(
 
 print(response.choices[0].message.content)
 ```
+
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
\ No newline at end of file
diff --git a/guides/cohere.md b/guides/cohere.md
new file mode 100644
index 00000000..4f7320cf
--- /dev/null
+++ b/guides/cohere.md
@@ -0,0 +1,44 @@
+# Cohere
+
+To use Cohere with `aisuite`, you’ll need a [Cohere account](https://cohere.com/). After logging in, go to the [API Keys](https://dashboard.cohere.com/api-keys) section in your account settings, agree to the terms of service, connect your card, and generate a new key. Once you have your key, add it to your environment as follows:
+
+```shell
+export CO_API_KEY="your-cohere-api-key"
+```
+
+## Create a Chat Completion
+
+Install the `cohere` Python client:
+
+Example with pip:
+```shell
+pip install cohere
+```
+
+Example with poetry:
+```shell
+poetry add cohere
+```
+
+In your code:
+```python
+import aisuite as ai
+
+client = ai.Client()
+
+provider = "cohere"
+model_id = "command-r-plus-08-2024"
+
+messages = [
+    {"role": "user", "content": "Hi, how are you?"}
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
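Since the new `CohereProvider` forwards its entire config dict to `cohere.ClientV2(**config)`, the key can also be supplied programmatically instead of through `CO_API_KEY`. A minimal sketch, assuming the `client.configure` pattern shown in `examples/client.ipynb`; the key string is a placeholder:

```python
import aisuite as ai

client = ai.Client()
# The nested dict is handed to CohereProvider, which passes it on to
# cohere.ClientV2(**config); "your-cohere-api-key" is a stand-in for a real key.
client.configure({"cohere": {"api_key": "your-cohere-api-key"}})

response = client.chat.completions.create(
    model="cohere:command-r-plus-08-2024",
    messages=[{"role": "user", "content": "Hi, how are you?"}],
)
print(response.choices[0].message.content)
```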
diff --git a/guides/deepseek.md b/guides/deepseek.md
new file mode 100644
index 00000000..9985a11f
--- /dev/null
+++ b/guides/deepseek.md
@@ -0,0 +1,46 @@
+# DeepSeek
+
+To use DeepSeek with `aisuite`, you’ll need a [DeepSeek account](https://platform.deepseek.com). After logging in, go to the [API Keys](https://platform.deepseek.com/api_keys) section in your account settings and generate a new key. Once you have your key, add it to your environment as follows:
+
+```shell
+export DEEPSEEK_API_KEY="your-deepseek-api-key"
+```
+
+## Create a Chat Completion
+
+(Note: DeepSeek's API format is consistent with OpenAI's, which is why we install the `openai` client; there is no dedicated DeepSeek library, at least for now.)
+
+Install the `openai` Python client:
+
+Example with pip:
+```shell
+pip install openai
+```
+
+Example with poetry:
+```shell
+poetry add openai
+```
+
+In your code:
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "deepseek"
+model_id = "deepseek-chat"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "What’s the weather like in San Francisco?"},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/google.md b/guides/google.md
index eb351bd0..e357679e 100644
--- a/guides/google.md
+++ b/guides/google.md
@@ -22,7 +22,7 @@ Set the `GOOGLE_PROJECT_ID` environment variable to the ID of your project. You
 
 ### Set your preferred region in an environment variable.
 
-Set the `GOOGLE_REGION` environment variable to the ID of your project. You can find the Project ID by visiting the project dashboard in the "Project Info" section toward the top of the page.
+Set the `GOOGLE_REGION` environment variable. You can find the region by going to the Project Dashboard under the VertexAI side navigation menu and scrolling to the bottom of the page.
 
 ## Create a Service Account For API Access
 
@@ -74,7 +74,7 @@ In your code:
 import aisuite as ai
 client = ai.Client()
 
-model="vertex:gemini-1.5-pro-001"
+model="google:gemini-1.5-pro-001"
 
 messages = [
     {"role": "system", "content": "Respond in Pirate English."},
@@ -89,4 +89,4 @@ response = client.chat.completions.create(
 
 print(response.choices[0].message.content)
 ```
 
-Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
\ No newline at end of file
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/groq.md b/guides/groq.md
new file mode 100644
index 00000000..96b50f1e
--- /dev/null
+++ b/guides/groq.md
@@ -0,0 +1,39 @@
+# Groq
+
+To use Groq with `aisuite`, you’ll need a free [Groq account](https://console.groq.com/). After logging in, go to the [API Keys](https://console.groq.com/keys) section in your account settings and generate a new Groq API key. Once you have your key, add it to your environment as follows:
+
+```shell
+export GROQ_API_KEY="your-groq-api-key"
+```
+
+## Create a Python Chat Completion
+
+1. First, install the `groq` Python client library:
+
+```shell
+pip install groq
+```
+
+2. Now you can create your first chat completion with the example code below, or customize it by swapping out the `model_id` for any of the other available [models powered by Groq](https://console.groq.com/docs/models) and the `messages` array for whatever you'd like:
+
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "groq"
+model_id = "llama-3.2-3b-preview"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "What’s the weather like in San Francisco?"},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/huggingface.md b/guides/huggingface.md
index 11bd9297..029fde45 100644
--- a/guides/huggingface.md
+++ b/guides/huggingface.md
@@ -18,7 +18,7 @@ After setting up your model, you'll need to gather the following information:
 Set the following environment variables to make authentication and requests easy:
 
 ```shell
-export HUGGINGFACE_TOKEN="your-api-token"
+export HF_TOKEN="your-api-token"
 ```
 
 ## Create a Chat Completion
 
@@ -53,3 +53,5 @@ print(response.choices[0].message.content)
 - Ensure that the `model` variable matches the identifier of your model as seen in the Hugging Face Model Hub.
 - If you encounter any rate limits or API access restrictions, you may have to upgrade your Hugging Face plan to enable higher usage limits.
 """
+
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
\ No newline at end of file
diff --git a/guides/nebius.md b/guides/nebius.md
new file mode 100644
index 00000000..2343d503
--- /dev/null
+++ b/guides/nebius.md
@@ -0,0 +1,44 @@
+# Nebius AI Studio
+
+To use Nebius AI Studio with `aisuite`, you need an AI Studio account. Go to [AI Studio](https://studio.nebius.ai/) and press "Log in to AI Studio" in the top right corner. After logging in, go to the [API Keys](https://studio.nebius.ai/settings/api-keys) section and generate a new key. Once you have a key, add it to your environment as follows:
+
+```shell
+export NEBIUS_API_KEY="your-nebius-api-key"
+```
+
+## Create a Chat Completion
+
+Install the `openai` Python client:
+
+Example with pip:
+```shell
+pip install openai
+```
+
+Example with poetry:
+```shell
+poetry add openai
+```
+
+In your code:
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "nebius"
+model_id = "meta-llama/Llama-3.3-70B-Instruct"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "How many times has Jurgen Klopp won the Champions League?"},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/openai.md b/guides/openai.md
index 6dc9ce97..ab297490 100644
--- a/guides/openai.md
+++ b/guides/openai.md
@@ -41,4 +41,4 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```
 
-Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
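Because every guide above uses the same `provider:model` naming convention, one prompt can be fanned out across several of the newly added providers in a single loop. A minimal sketch, assuming the relevant keys (`NEBIUS_API_KEY`, `GROQ_API_KEY`) are already exported, and reusing the model ids from the guides:

```python
import aisuite as ai

client = ai.Client()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather like in San Francisco?"},
]

# Model ids taken from the nebius and groq guides; any provider:model pair works.
for model in [
    "nebius:meta-llama/Llama-3.3-70B-Instruct",
    "groq:llama-3.2-3b-preview",
]:
    response = client.chat.completions.create(model=model, messages=messages)
    print(f"{model} -> {response.choices[0].message.content}")
```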
diff --git a/guides/sambanova.md b/guides/sambanova.md
new file mode 100644
index 00000000..6b331c2f
--- /dev/null
+++ b/guides/sambanova.md
@@ -0,0 +1,44 @@
+# SambaNova
+
+To use SambaNova with `aisuite`, you’ll need a [SambaNova Cloud](https://cloud.sambanova.ai/) account. After logging in, go to the [API](https://cloud.sambanova.ai/apis) section and generate a new key. Once you have your key, add it to your environment as follows:
+
+```shell
+export SAMBANOVA_API_KEY="your-sambanova-api-key"
+```
+
+## Create a Chat Completion
+
+Install the `openai` Python client:
+
+Example with pip:
+```shell
+pip install openai
+```
+
+Example with poetry:
+```shell
+poetry add openai
+```
+
+In your code:
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "sambanova"
+model_id = "Meta-Llama-3.1-405B-Instruct"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "What’s the weather like in San Francisco?"},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/watsonx.md b/guides/watsonx.md
new file mode 100644
index 00000000..c2d4121b
--- /dev/null
+++ b/guides/watsonx.md
@@ -0,0 +1,83 @@
+# Watsonx with `aisuite`
+
+A step-by-step guide to setting up Watsonx with the `aisuite` library, enabling you to use IBM Watsonx's powerful AI models for various tasks.
+
+## Setup Instructions
+
+### Step 1: Create a Watsonx Account
+
+1. Visit [IBM Watsonx](https://www.ibm.com/watsonx).
+2. Sign up for a new account or log in with your existing IBM credentials.
+3. Once logged in, navigate to the **Watsonx Dashboard**.
+
+---
+
+### Step 2: Obtain API Credentials
+
+1. **Generate an API Key**:
+   - Go to IAM > API keys and create a new API key.
+   - Copy the API key. This is your `WATSONX_API_KEY`.
+
+2. **Locate the Service URL**:
+   - Your service URL is based on the region where your service is hosted.
+   - Pick the one for your region from IBM's list of Watsonx service endpoints.
+   - Copy the service URL. This is your `WATSONX_SERVICE_URL`.
+
+3. **Get the Project ID**:
+   - Go to the **Watsonx Dashboard**.
+   - Under the **Projects** section, if you don't have a sandbox project, create a new one.
+   - Navigate to the **Manage** tab and find the **Project ID**.
+   - Copy the **Project ID**. This will serve as your `WATSONX_PROJECT_ID`.
+
+---
+
+### Step 3: Set Environment Variables
+
+To simplify authentication, set the following environment variables by running these commands in your terminal:
+
+```bash
+export WATSONX_API_KEY="your-watsonx-api-key"
+export WATSONX_SERVICE_URL="your-watsonx-service-url"
+export WATSONX_PROJECT_ID="your-watsonx-project-id"
+```
+
+
+## Create a Chat Completion
+
+Install the `ibm-watsonx-ai` Python client:
+
+Example with pip:
+
+```shell
+pip install ibm-watsonx-ai
+```
+
+Example with poetry:
+
+```shell
+poetry add ibm-watsonx-ai
+```
+
+In your code:
+
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "watsonx"
+model_id = "meta-llama/llama-3-70b-instruct"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "Tell me a joke."},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
\ No newline at end of file
diff --git a/guides/xai.md b/guides/xai.md
new file mode 100644
index 00000000..7129dd99
--- /dev/null
+++ b/guides/xai.md
@@ -0,0 +1,33 @@
+# xAI
+
+To use xAI with `aisuite`, you’ll need an [API key](https://console.x.ai/). Generate a new key, and once you have it, add it to your environment as follows:
+
+```shell
+export XAI_API_KEY="your-xai-api-key"
+```
+
+## Create a Chat Completion
+
+Sample code:
+```python
+import aisuite as ai
+client = ai.Client()
+
+models = ["xai:grok-beta"]
+
+messages = [
+    {"role": "system", "content": "Respond in Pirate English."},
+    {"role": "user", "content": "Tell me a joke."},
+]
+
+for model in models:
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=0.75
+    )
+    print(response.choices[0].message.content)
+```
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/poetry.lock b/poetry.lock
index 27aa1cc1..c55d3a27 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]] name = "aiohttp" @@ -838,6 +838,33 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cohere" +version = "5.13.3" +description = "" +optional = true +python-versions = "<4.0,>=3.9" +files = [ + {file = "cohere-5.13.3-py3-none-any.whl", hash = "sha256:076c88fdd3d670b6577eb8e813a9072bf18b59648d4092c6f0263af3c27bf81f"}, + {file = "cohere-5.13.3.tar.gz", hash = "sha256:70d87e0d5ce48aaee5ba70ead5efbade226cb2a4b11bfcfb676f6a2db3642819"}, +] + +[package.dependencies] +fastavro = ">=1.9.4,<2.0.0" +httpx = ">=0.21.2" +httpx-sse = "0.4.0" +numpy = ">=1.26,<2.0" +parameterized = ">=0.9.0,<0.10.0" +pydantic = ">=1.9.2" +pydantic-core = ">=2.18.2,<3.0.0" +requests = ">=2.0.0,<3.0.0" +tokenizers = ">=0.15,<1" +types-requests = ">=2.0.0,<3.0.0" +typing_extensions = ">=4.0.0" + +[package.extras] +aws = ["boto3 (>=1.34.0,<2.0.0)", "sagemaker (>=2.232.1,<3.0.0)"] + [[package]] name = "colorama" version = "0.4.6" @@ -883,6 +910,83 @@ traitlets = ">=4" [package.extras] test = ["pytest"] +[[package]] +name = "coverage" +version = "7.6.8" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "coverage-7.6.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b39e6011cd06822eb964d038d5dff5da5d98652b81f5ecd439277b32361a3a50"}, + {file = "coverage-7.6.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:63c19702db10ad79151a059d2d6336fe0c470f2e18d0d4d1a57f7f9713875dcf"}, + {file = "coverage-7.6.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3985b9be361d8fb6b2d1adc9924d01dec575a1d7453a14cccd73225cb79243ee"}, + {file = "coverage-7.6.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:644ec81edec0f4ad17d51c838a7d01e42811054543b76d4ba2c5d6af741ce2a6"}, + {file = "coverage-7.6.8-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f188a2402f8359cf0c4b1fe89eea40dc13b52e7b4fd4812450da9fcd210181d"}, + {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e19122296822deafce89a0c5e8685704c067ae65d45e79718c92df7b3ec3d331"}, + {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:13618bed0c38acc418896005732e565b317aa9e98d855a0e9f211a7ffc2d6638"}, + {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:193e3bffca48ad74b8c764fb4492dd875038a2f9925530cb094db92bb5e47bed"}, + {file = "coverage-7.6.8-cp310-cp310-win32.whl", hash = "sha256:3988665ee376abce49613701336544041f2117de7b7fbfe91b93d8ff8b151c8e"}, + {file = "coverage-7.6.8-cp310-cp310-win_amd64.whl", hash = "sha256:f56f49b2553d7dd85fd86e029515a221e5c1f8cb3d9c38b470bc38bde7b8445a"}, + {file = "coverage-7.6.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86cffe9c6dfcfe22e28027069725c7f57f4b868a3f86e81d1c62462764dc46d4"}, + {file = "coverage-7.6.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d82ab6816c3277dc962cfcdc85b1efa0e5f50fb2c449432deaf2398a2928ab94"}, + {file = "coverage-7.6.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13690e923a3932e4fad4c0ebfb9cb5988e03d9dcb4c5150b5fcbf58fd8bddfc4"}, + {file = "coverage-7.6.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4be32da0c3827ac9132bb488d331cb32e8d9638dd41a0557c5569d57cf22c9c1"}, + {file = 
"coverage-7.6.8-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e6c85bbdc809383b509d732b06419fb4544dca29ebe18480379633623baafb"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:768939f7c4353c0fac2f7c37897e10b1414b571fd85dd9fc49e6a87e37a2e0d8"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e44961e36cb13c495806d4cac67640ac2866cb99044e210895b506c26ee63d3a"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3ea8bb1ab9558374c0ab591783808511d135a833c3ca64a18ec927f20c4030f0"}, + {file = "coverage-7.6.8-cp311-cp311-win32.whl", hash = "sha256:629a1ba2115dce8bf75a5cce9f2486ae483cb89c0145795603d6554bdc83e801"}, + {file = "coverage-7.6.8-cp311-cp311-win_amd64.whl", hash = "sha256:fb9fc32399dca861584d96eccd6c980b69bbcd7c228d06fb74fe53e007aa8ef9"}, + {file = "coverage-7.6.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e683e6ecc587643f8cde8f5da6768e9d165cd31edf39ee90ed7034f9ca0eefee"}, + {file = "coverage-7.6.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1defe91d41ce1bd44b40fabf071e6a01a5aa14de4a31b986aa9dfd1b3e3e414a"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7ad66e8e50225ebf4236368cc43c37f59d5e6728f15f6e258c8639fa0dd8e6d"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fe47da3e4fda5f1abb5709c156eca207eacf8007304ce3019eb001e7a7204cb"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:202a2d645c5a46b84992f55b0a3affe4f0ba6b4c611abec32ee88358db4bb649"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4674f0daa1823c295845b6a740d98a840d7a1c11df00d1fd62614545c1583787"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:74610105ebd6f33d7c10f8907afed696e79c59e3043c5f20eaa3a46fddf33b4c"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37cda8712145917105e07aab96388ae76e787270ec04bcb9d5cc786d7cbb8443"}, + {file = "coverage-7.6.8-cp312-cp312-win32.whl", hash = "sha256:9e89d5c8509fbd6c03d0dd1972925b22f50db0792ce06324ba069f10787429ad"}, + {file = "coverage-7.6.8-cp312-cp312-win_amd64.whl", hash = "sha256:379c111d3558272a2cae3d8e57e6b6e6f4fe652905692d54bad5ea0ca37c5ad4"}, + {file = "coverage-7.6.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b0c69f4f724c64dfbfe79f5dfb503b42fe6127b8d479b2677f2b227478db2eb"}, + {file = "coverage-7.6.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c15b32a7aca8038ed7644f854bf17b663bc38e1671b5d6f43f9a2b2bd0c46f63"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63068a11171e4276f6ece913bde059e77c713b48c3a848814a6537f35afb8365"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f4548c5ead23ad13fb7a2c8ea541357474ec13c2b736feb02e19a3085fac002"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b4b4299dd0d2c67caaaf286d58aef5e75b125b95615dda4542561a5a566a1e3"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c9ebfb2507751f7196995142f057d1324afdab56db1d9743aab7f50289abd022"}, + {file = 
"coverage-7.6.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c1b4474beee02ede1eef86c25ad4600a424fe36cff01a6103cb4533c6bf0169e"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d9fd2547e6decdbf985d579cf3fc78e4c1d662b9b0ff7cc7862baaab71c9cc5b"}, + {file = "coverage-7.6.8-cp313-cp313-win32.whl", hash = "sha256:8aae5aea53cbfe024919715eca696b1a3201886ce83790537d1c3668459c7146"}, + {file = "coverage-7.6.8-cp313-cp313-win_amd64.whl", hash = "sha256:ae270e79f7e169ccfe23284ff5ea2d52a6f401dc01b337efb54b3783e2ce3f28"}, + {file = "coverage-7.6.8-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:de38add67a0af869b0d79c525d3e4588ac1ffa92f39116dbe0ed9753f26eba7d"}, + {file = "coverage-7.6.8-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b07c25d52b1c16ce5de088046cd2432b30f9ad5e224ff17c8f496d9cb7d1d451"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62a66ff235e4c2e37ed3b6104d8b478d767ff73838d1222132a7a026aa548764"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09b9f848b28081e7b975a3626e9081574a7b9196cde26604540582da60235fdf"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:093896e530c38c8e9c996901858ac63f3d4171268db2c9c8b373a228f459bbc5"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9a7b8ac36fd688c8361cbc7bf1cb5866977ece6e0b17c34aa0df58bda4fa18a4"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:38c51297b35b3ed91670e1e4efb702b790002e3245a28c76e627478aa3c10d83"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2e4e0f60cb4bd7396108823548e82fdab72d4d8a65e58e2c19bbbc2f1e2bfa4b"}, + {file = "coverage-7.6.8-cp313-cp313t-win32.whl", hash = "sha256:6535d996f6537ecb298b4e287a855f37deaf64ff007162ec0afb9ab8ba3b8b71"}, + {file = "coverage-7.6.8-cp313-cp313t-win_amd64.whl", hash = "sha256:c79c0685f142ca53256722a384540832420dff4ab15fec1863d7e5bc8691bdcc"}, + {file = "coverage-7.6.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ac47fa29d8d41059ea3df65bd3ade92f97ee4910ed638e87075b8e8ce69599e"}, + {file = "coverage-7.6.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:24eda3a24a38157eee639ca9afe45eefa8d2420d49468819ac5f88b10de84f4c"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4c81ed2820b9023a9a90717020315e63b17b18c274a332e3b6437d7ff70abe0"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd55f8fc8fa494958772a2a7302b0354ab16e0b9272b3c3d83cdb5bec5bd1779"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f39e2f3530ed1626c66e7493be7a8423b023ca852aacdc91fb30162c350d2a92"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:716a78a342679cd1177bc8c2fe957e0ab91405bd43a17094324845200b2fddf4"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:177f01eeaa3aee4a5ffb0d1439c5952b53d5010f86e9d2667963e632e30082cc"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:912e95017ff51dc3d7b6e2be158dedc889d9a5cc3382445589ce554f1a34c0ea"}, + {file = "coverage-7.6.8-cp39-cp39-win32.whl", hash = 
"sha256:4db3ed6a907b555e57cc2e6f14dc3a4c2458cdad8919e40b5357ab9b6db6c43e"}, + {file = "coverage-7.6.8-cp39-cp39-win_amd64.whl", hash = "sha256:428ac484592f780e8cd7b6b14eb568f7c85460c92e2a37cb0c0e5186e1a0d076"}, + {file = "coverage-7.6.8-pp39.pp310-none-any.whl", hash = "sha256:5c52a036535d12590c32c49209e79cabaad9f9ad8aa4cbd875b68c4d67a9cbce"}, + {file = "coverage-7.6.8.tar.gz", hash = "sha256:8b2b8503edb06822c86d82fa64a4a5cb0760bb8f31f26e138ec743f422f37cfc"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + [[package]] name = "datasets" version = "2.20.0" @@ -1150,6 +1254,52 @@ typer = ">=0.12.3" [package.extras] standard = ["fastapi", "uvicorn[standard] (>=0.15.0)"] +[[package]] +name = "fastavro" +version = "1.9.7" +description = "Fast read/write of AVRO files" +optional = true +python-versions = ">=3.8" +files = [ + {file = "fastavro-1.9.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc811fb4f7b5ae95f969cda910241ceacf82e53014c7c7224df6f6e0ca97f52f"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb8749e419a85f251bf1ac87d463311874972554d25d4a0b19f6bdc56036d7cf"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b2f9bafa167cb4d1c3dd17565cb5bf3d8c0759e42620280d1760f1e778e07fc"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e87d04b235b29f7774d226b120da2ca4e60b9e6fdf6747daef7f13f218b3517a"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b525c363e267ed11810aaad8fbdbd1c3bd8837d05f7360977d72a65ab8c6e1fa"}, + {file = "fastavro-1.9.7-cp310-cp310-win_amd64.whl", hash = "sha256:6312fa99deecc319820216b5e1b1bd2d7ebb7d6f221373c74acfddaee64e8e60"}, + {file = "fastavro-1.9.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ec8499dc276c2d2ef0a68c0f1ad11782b2b956a921790a36bf4c18df2b8d4020"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d9d96f98052615ab465c63ba8b76ed59baf2e3341b7b169058db104cbe2aa0"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919f3549e07a8a8645a2146f23905955c35264ac809f6c2ac18142bc5b9b6022"}, + {file = "fastavro-1.9.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9de1fa832a4d9016724cd6facab8034dc90d820b71a5d57c7e9830ffe90f31e4"}, + {file = "fastavro-1.9.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1d09227d1f48f13281bd5ceac958650805aef9a4ef4f95810128c1f9be1df736"}, + {file = "fastavro-1.9.7-cp311-cp311-win_amd64.whl", hash = "sha256:2db993ae6cdc63e25eadf9f93c9e8036f9b097a3e61d19dca42536dcc5c4d8b3"}, + {file = "fastavro-1.9.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4e1289b731214a7315884c74b2ec058b6e84380ce9b18b8af5d387e64b18fc44"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eac69666270a76a3a1d0444f39752061195e79e146271a568777048ffbd91a27"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9be089be8c00f68e343bbc64ca6d9a13e5e5b0ba8aa52bcb231a762484fb270e"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d576eccfd60a18ffa028259500df67d338b93562c6700e10ef68bbd88e499731"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:ee9bf23c157bd7dcc91ea2c700fa3bd924d9ec198bb428ff0b47fa37fe160659"}, + {file = "fastavro-1.9.7-cp312-cp312-win_amd64.whl", hash = "sha256:b6b2ccdc78f6afc18c52e403ee68c00478da12142815c1bd8a00973138a166d0"}, + {file = "fastavro-1.9.7-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:7313def3aea3dacface0a8b83f6d66e49a311149aa925c89184a06c1ef99785d"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:536f5644737ad21d18af97d909dba099b9e7118c237be7e4bd087c7abde7e4f0"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2af559f30383b79cf7d020a6b644c42ffaed3595f775fe8f3d7f80b1c43dfdc5"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:edc28ab305e3c424de5ac5eb87b48d1e07eddb6aa08ef5948fcda33cc4d995ce"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ec2e96bdabd58427fe683329b3d79f42c7b4f4ff6b3644664a345a655ac2c0a1"}, + {file = "fastavro-1.9.7-cp38-cp38-win_amd64.whl", hash = "sha256:3b683693c8a85ede496ebebe115be5d7870c150986e34a0442a20d88d7771224"}, + {file = "fastavro-1.9.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:58f76a5c9a312fbd37b84e49d08eb23094d36e10d43bc5df5187bc04af463feb"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56304401d2f4f69f5b498bdd1552c13ef9a644d522d5de0dc1d789cf82f47f73"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fcce036c6aa06269fc6a0428050fcb6255189997f5e1a728fc461e8b9d3e26b"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:17de68aae8c2525f5631d80f2b447a53395cdc49134f51b0329a5497277fc2d2"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7c911366c625d0a997eafe0aa83ffbc6fd00d8fd4543cb39a97c6f3b8120ea87"}, + {file = "fastavro-1.9.7-cp39-cp39-win_amd64.whl", hash = "sha256:912283ed48578a103f523817fdf0c19b1755cea9b4a6387b73c79ecb8f8f84fc"}, + {file = "fastavro-1.9.7.tar.gz", hash = "sha256:13e11c6cb28626da85290933027cd419ce3f9ab8e45410ef24ce6b89d20a1f6c"}, +] + +[package.extras] +codecs = ["cramjam", "lz4", "zstandard"] +lz4 = ["lz4"] +snappy = ["cramjam"] +zstandard = ["zstandard"] + [[package]] name = "fastjsonschema" version = "2.20.0" @@ -1845,13 +1995,13 @@ test = ["Cython (>=0.29.24,<0.30.0)"] [[package]] name = "httpx" -version = "0.27.0" +version = "0.27.2" description = "The next generation HTTP client." 
optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, - {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, ] [package.dependencies] @@ -1866,6 +2016,7 @@ brotli = ["brotli", "brotlicffi"] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] [[package]] name = "httpx-sse" @@ -1926,6 +2077,80 @@ files = [ [package.dependencies] pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} +[[package]] +name = "ibm-cos-sdk" +version = "2.13.6" +description = "IBM SDK for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ibm-cos-sdk-2.13.6.tar.gz", hash = "sha256:171cf2ae4ab662a4b8ab58dcf4ac994b0577d6c92d78490295fd7704a83978f6"}, +] + +[package.dependencies] +ibm-cos-sdk-core = "2.13.6" +ibm-cos-sdk-s3transfer = "2.13.6" +jmespath = ">=0.10.0,<=1.0.1" + +[[package]] +name = "ibm-cos-sdk-core" +version = "2.13.6" +description = "Low-level, data-driven core of IBM SDK for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "ibm-cos-sdk-core-2.13.6.tar.gz", hash = "sha256:dd41fb789eeb65546501afabcd50e78846ab4513b6ad4042e410b6a14ff88413"}, +] + +[package.dependencies] +jmespath = ">=0.10.0,<=1.0.1" +python-dateutil = ">=2.9.0,<3.0.0" +requests = ">=2.32.0,<2.32.3" +urllib3 = ">=1.26.18,<3" + +[[package]] +name = "ibm-cos-sdk-s3transfer" +version = "2.13.6" +description = "IBM S3 Transfer Manager" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ibm-cos-sdk-s3transfer-2.13.6.tar.gz", hash = "sha256:e0acce6f380c47d11e07c6765b684b4ababbf5c66cc0503bc246469a1e2b9790"}, +] + +[package.dependencies] +ibm-cos-sdk-core = "2.13.6" + +[[package]] +name = "ibm-watsonx-ai" +version = "1.1.16" +description = "IBM watsonx.ai API Client" +optional = false +python-versions = ">=3.10" +files = [ + {file = "ibm_watsonx_ai-1.1.16-py3-none-any.whl", hash = "sha256:c703adda2588c85606f74c230afe3ce31202815de369301df19f14ce21bd093a"}, + {file = "ibm_watsonx_ai-1.1.16.tar.gz", hash = "sha256:ab79ed5dedd57fd574c5c6c5ceca50a89b8423562646255c910ed74a8d8811a5"}, +] + +[package.dependencies] +certifi = "*" +httpx = "*" +ibm-cos-sdk = ">=2.12.0,<2.14.0" +importlib-metadata = "*" +lomond = "*" +packaging = "*" +pandas = ">=0.24.2,<2.2.0" +requests = "*" +tabulate = "*" +urllib3 = "*" + +[package.extras] +fl-crypto = ["pyhelayers (==1.5.0.3)"] +fl-crypto-rt24-1 = ["pyhelayers (==1.5.3.1)"] +fl-rt23-1-py3-10 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.1.1)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.23.5)", "pandas (==1.5.3)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.1.1)", "scipy (==1.10.1)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.12.0)", "torch (==2.0.1)", "websockets (==10.1)"] 
+fl-rt24-1-py3-11 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.3.2)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.26.4)", "pandas (==2.1.4)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.3.0)", "scipy (==1.11.4)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.14.1)", "torch (==2.1.2)", "websockets (==10.1)"] +rag = ["beautifulsoup4 (==4.12.3)", "grpcio (>=1.60.0)", "langchain (>=0.2.15,<0.3)", "langchain-chroma (==0.1.1)", "langchain-community (>=0.2.4,<0.3)", "langchain-core (>=0.2.37,<0.3)", "langchain-elasticsearch (==0.2.2)", "langchain-ibm", "langchain-milvus (==0.1.1)", "markdown (==3.4.1)", "pypdf (==4.2.0)", "python-docx (==1.1.2)"] + [[package]] name = "identify" version = "2.6.0" @@ -2517,6 +2742,20 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0" [package.extras] adal = ["adal (>=1.0.2)"] +[[package]] +name = "lomond" +version = "0.3.3" +description = "Websocket Client Library" +optional = false +python-versions = "*" +files = [ + {file = "lomond-0.3.3-py2.py3-none-any.whl", hash = "sha256:df1dd4dd7b802a12b71907ab1abb08b8ce9950195311207579379eb3b1553de7"}, + {file = "lomond-0.3.3.tar.gz", hash = "sha256:427936596b144b4ec387ead99aac1560b77c8a78107d3d49415d3abbe79acbd3"}, +] + +[package.dependencies] +six = ">=1.10.0" + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -3612,76 +3851,71 @@ files = [ [[package]] name = "pandas" -version = "2.2.2" +version = "2.1.4" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" files = [ - {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, - {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, - {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, - {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, - {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, - {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, + {file = "pandas-2.1.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bdec823dc6ec53f7a6339a0e34c68b144a7a1fd28d80c260534c39c62c5bf8c9"}, + {file = "pandas-2.1.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:294d96cfaf28d688f30c918a765ea2ae2e0e71d3536754f4b6de0ea4a496d034"}, + {file = "pandas-2.1.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b728fb8deba8905b319f96447a27033969f3ea1fea09d07d296c9030ab2ed1d"}, + {file = "pandas-2.1.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00028e6737c594feac3c2df15636d73ace46b8314d236100b57ed7e4b9ebe8d9"}, + {file = "pandas-2.1.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:426dc0f1b187523c4db06f96fb5c8d1a845e259c99bda74f7de97bd8a3bb3139"}, + {file = "pandas-2.1.4-cp310-cp310-win_amd64.whl", hash = 
"sha256:f237e6ca6421265643608813ce9793610ad09b40154a3344a088159590469e46"}, + {file = "pandas-2.1.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b7d852d16c270e4331f6f59b3e9aa23f935f5c4b0ed2d0bc77637a8890a5d092"}, + {file = "pandas-2.1.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7d5f2f54f78164b3d7a40f33bf79a74cdee72c31affec86bfcabe7e0789821"}, + {file = "pandas-2.1.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0aa6e92e639da0d6e2017d9ccff563222f4eb31e4b2c3cf32a2a392fc3103c0d"}, + {file = "pandas-2.1.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d797591b6846b9db79e65dc2d0d48e61f7db8d10b2a9480b4e3faaddc421a171"}, + {file = "pandas-2.1.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d2d3e7b00f703aea3945995ee63375c61b2e6aa5aa7871c5d622870e5e137623"}, + {file = "pandas-2.1.4-cp311-cp311-win_amd64.whl", hash = "sha256:dc9bf7ade01143cddc0074aa6995edd05323974e6e40d9dbde081021ded8510e"}, + {file = "pandas-2.1.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:482d5076e1791777e1571f2e2d789e940dedd927325cc3cb6d0800c6304082f6"}, + {file = "pandas-2.1.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8a706cfe7955c4ca59af8c7a0517370eafbd98593155b48f10f9811da440248b"}, + {file = "pandas-2.1.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0513a132a15977b4a5b89aabd304647919bc2169eac4c8536afb29c07c23540"}, + {file = "pandas-2.1.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9f17f2b6fc076b2a0078862547595d66244db0f41bf79fc5f64a5c4d635bead"}, + {file = "pandas-2.1.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:45d63d2a9b1b37fa6c84a68ba2422dc9ed018bdaa668c7f47566a01188ceeec1"}, + {file = "pandas-2.1.4-cp312-cp312-win_amd64.whl", hash = "sha256:f69b0c9bb174a2342818d3e2778584e18c740d56857fc5cdb944ec8bbe4082cf"}, + {file = "pandas-2.1.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3f06bda01a143020bad20f7a85dd5f4a1600112145f126bc9e3e42077c24ef34"}, + {file = "pandas-2.1.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab5796839eb1fd62a39eec2916d3e979ec3130509930fea17fe6f81e18108f6a"}, + {file = "pandas-2.1.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edbaf9e8d3a63a9276d707b4d25930a262341bca9874fcb22eff5e3da5394732"}, + {file = "pandas-2.1.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ebfd771110b50055712b3b711b51bee5d50135429364d0498e1213a7adc2be8"}, + {file = "pandas-2.1.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8ea107e0be2aba1da619cc6ba3f999b2bfc9669a83554b1904ce3dd9507f0860"}, + {file = "pandas-2.1.4-cp39-cp39-win_amd64.whl", hash = "sha256:d65148b14788b3758daf57bf42725caa536575da2b64df9964c563b015230984"}, + {file = "pandas-2.1.4.tar.gz", hash = "sha256:fcb68203c833cc735321512e13861358079a96c174a61f5116a1de89c58c0ef7"}, ] [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, + {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" -tzdata = ">=2022.7" +tzdata = ">=2022.1" [package.extras] -all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite 
(>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] -aws = ["s3fs (>=2022.11.0)"] -clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] -compression = ["zstandard (>=0.19.0)"] -computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +all = ["PyQt5 (>=5.15.6)", "SQLAlchemy (>=1.4.36)", "beautifulsoup4 (>=4.11.1)", "bottleneck (>=1.3.4)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=0.8.1)", "fsspec (>=2022.05.0)", "gcsfs (>=2022.05.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.8.0)", "matplotlib (>=3.6.1)", "numba (>=0.55.2)", "numexpr (>=2.8.0)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pandas-gbq (>=0.17.5)", "psycopg2 (>=2.9.3)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.5)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "pyxlsb (>=1.0.9)", "qtpy (>=2.2.0)", "s3fs (>=2022.05.0)", "scipy (>=1.8.1)", "tables (>=3.7.0)", "tabulate (>=0.8.10)", "xarray (>=2022.03.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)", "zstandard (>=0.17.0)"] +aws = ["s3fs (>=2022.05.0)"] +clipboard = ["PyQt5 (>=5.15.6)", "qtpy (>=2.2.0)"] +compression = ["zstandard (>=0.17.0)"] +computation = ["scipy (>=1.8.1)", "xarray (>=2022.03.0)"] consortium-standard = ["dataframe-api-compat (>=0.1.7)"] -excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] -feather = ["pyarrow (>=10.0.1)"] -fss = ["fsspec (>=2022.11.0)"] -gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] -hdf5 = ["tables (>=3.8.0)"] -html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] -mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] -output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] -parquet = ["pyarrow (>=10.0.1)"] -performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] -plot = ["matplotlib (>=3.6.3)"] -postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] -pyarrow = ["pyarrow (>=10.0.1)"] -spss = ["pyreadstat (>=1.2.0)"] -sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pyxlsb (>=1.0.9)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)"] +feather = ["pyarrow (>=7.0.0)"] +fss = ["fsspec (>=2022.05.0)"] +gcp = ["gcsfs (>=2022.05.0)", "pandas-gbq (>=0.17.5)"] +hdf5 = ["tables (>=3.7.0)"] +html = ["beautifulsoup4 (>=4.11.1)", "html5lib (>=1.1)", "lxml (>=4.8.0)"] +mysql = ["SQLAlchemy (>=1.4.36)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.8.10)"] +parquet = ["pyarrow (>=7.0.0)"] +performance = ["bottleneck (>=1.3.4)", "numba (>=0.55.2)", "numexpr (>=2.8.0)"] +plot = ["matplotlib (>=3.6.1)"] +postgresql = ["SQLAlchemy (>=1.4.36)", "psycopg2 (>=2.9.3)"] +spss = 
["pyreadstat (>=1.1.5)"] +sql-other = ["SQLAlchemy (>=1.4.36)"] test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] -xml = ["lxml (>=4.9.2)"] +xml = ["lxml (>=4.8.0)"] [[package]] name = "pandocfilters" @@ -3694,6 +3928,20 @@ files = [ {file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"}, ] +[[package]] +name = "parameterized" +version = "0.9.0" +description = "Parameterized testing with any Python test framework" +optional = true +python-versions = ">=3.7" +files = [ + {file = "parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b"}, + {file = "parameterized-0.9.0.tar.gz", hash = "sha256:7fc905272cefa4f364c1a3429cbbe9c0f98b793988efb5bf90aac80f08db09b1"}, +] + +[package.extras] +dev = ["jinja2"] + [[package]] name = "parso" version = "0.8.4" @@ -4311,6 +4559,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-cov" +version = "6.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0"}, + {file = "pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35"}, +] + +[package.dependencies] +coverage = {version = ">=7.5", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -4678,13 +4944,13 @@ files = [ [[package]] name = "requests" -version = "2.32.3" +version = "2.32.2" description = "Python HTTP for Humans." 
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
-    {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
+    {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
+    {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
 ]

 [package.dependencies]
@@ -5322,6 +5588,20 @@ mpmath = ">=1.1.0,<1.4"

 [package.extras]
 dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]

+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+    {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
 [[package]]
 name = "tenacity"
 version = "8.5.0"
@@ -5747,6 +6027,20 @@ files = [
     {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"},
 ]

+[[package]]
+name = "types-requests"
+version = "2.32.0.20241016"
+description = "Typing stubs for requests"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"},
+    {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"},
+]
+
+[package.dependencies]
+urllib3 = ">=2"
+
 [[package]]
 name = "typing-extensions"
 version = "4.12.2"
@@ -6478,18 +6772,20 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke
 test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]

 [extras]
-all = ["anthropic", "groq", "openai"]
+all = ["anthropic", "cohere", "groq", "openai"]
 anthropic = ["anthropic"]
 aws = ["boto3"]
 azure = []
+cohere = ["cohere"]
 google = ["vertexai"]
 groq = ["groq"]
 huggingface = []
 mistral = ["mistralai"]
 ollama = []
 openai = ["openai"]
+watsonx = ["ibm-watsonx-ai"]

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "e455ddfd148ffae77e7b9ac196792b38c7cc545c0866cca2aaae227f4a4201df"
+content-hash = "522cba517a2a3cc94bdee88c3af849d76bc00dc9d777d38a40cda860e7c108cb"
diff --git a/pyproject.toml b/pyproject.toml
index 2a009f74..5d0aa836 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,35 +1,39 @@
 [tool.poetry]
 name = "aisuite"
-version = "0.1.6"
+version = "0.1.7"
 description = "Uniform access layer for LLMs"
-authors = ["Andrew Ng"]
+authors = ["Andrew Ng", "Rohit P"]
 readme = "README.md"

 [tool.poetry.dependencies]
 python = "^3.10"
 anthropic = { version = "^0.30.1", optional = true }
 boto3 = { version = "^1.34.144", optional = true }
+cohere = { version = "^5.12.0", optional = true }
 vertexai = { version = "^1.63.0", optional = true }
 groq = { version = "^0.9.0", optional = true }
 mistralai = { version = "^1.0.3", optional = true }
 openai = { version = "^1.35.8", optional = true }
 docstring-parser = { version = "^0.14.0", optional = true }
+ibm-watsonx-ai = { version = "^1.1.16", optional = true }

 # Optional dependencies for different providers
+httpx = "~0.27.0"

 [tool.poetry.extras]
 anthropic = ["anthropic"]
 aws = ["boto3"]
 azure = []
+cohere = ["cohere"]
 google = ["vertexai"]
 groq = ["groq"]
 huggingface = []
 mistral = ["mistralai"]
 ollama = []
 openai = ["openai"]
-all = ["anthropic", "aws", "google", "groq", "mistral", "openai"] # To install all providers
+watsonx = ["ibm-watsonx-ai"]
+all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers

 [tool.poetry.group.dev.dependencies]
-pytest = "^8.2.2"
 pre-commit = "^3.7.1"
 black = "^24.4.2"
 python-dotenv = "^1.0.1"
@@ -45,7 +49,21 @@ chromadb = "^0.5.4"
 sentence-transformers = "^3.0.1"
 datasets = "^2.20.0"
 vertexai = "^1.63.0"
+ibm-watsonx-ai = "^1.1.16"
+
+[tool.poetry.group.test]
+optional = true
+
+[tool.poetry.group.test.dependencies]
+pytest = "^8.2.2"
+pytest-cov = "^6.0.0"

 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.pytest.ini_options]
+testpaths = "tests"
+markers = [
+    "integration: marks tests as integration tests that interact with external services",
+]
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 2e1949ac..a94b139d 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -1,72 +1,110 @@
-import unittest
-from unittest.mock import patch
+from unittest.mock import Mock, patch
+
+import pytest
+
 from aisuite import Client


-class TestClient(unittest.TestCase):
-    @patch("aisuite.providers.mistral_provider.MistralProvider.chat_completions_create")
-    @patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create")
-    @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create")
-    @patch("aisuite.providers.aws_provider.AwsProvider.chat_completions_create")
-    @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create")
-    @patch(
-        "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create"
-    )
-    @patch("aisuite.providers.google_provider.GoogleProvider.chat_completions_create")
-    @patch(
-        "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create"
-    )
-    def test_client_chat_completions(
-        self,
-        mock_fireworks,
-        mock_google,
-        mock_anthropic,
-        mock_azure,
-        mock_bedrock,
-        mock_openai,
-        mock_groq,
-        mock_mistral,
-    ):
-        # Mock responses from providers
-        mock_openai.return_value = "OpenAI Response"
-        mock_bedrock.return_value = "AWS Bedrock Response"
-        mock_azure.return_value = "Azure Response"
-        mock_anthropic.return_value = "Anthropic Response"
-        mock_groq.return_value = "Groq Response"
-        mock_mistral.return_value = "Mistral Response"
-        mock_google.return_value = "Google Response"
-        mock_fireworks.return_value = "Fireworks Response"
-
-        # Provider configurations
-        provider_configs = {
-            "openai": {"api_key": "test_openai_api_key"},
-            "aws": {
-                "aws_access_key": "test_aws_access_key",
-                "aws_secret_key": "test_aws_secret_key",
-                "aws_session_token": "test_aws_session_token",
-                "aws_region": "us-west-2",
-            },
-            "azure": {
-                "api_key": "azure-api-key",
-                "base_url": "https://model.ai.azure.com",
-            },
-            "groq": {
-                "api_key": "groq-api-key",
-            },
-            "mistral": {
-                "api_key": "mistral-api-key",
-            },
-            "google": {
-                "project_id": "test_google_project_id",
-                "region": "us-west4",
-                "application_credentials": "test_google_application_credentials",
-            },
-            "fireworks":
{ - "api_key": "fireworks-api-key", - }, - } - - # Initialize the client +@pytest.fixture(scope="module") +def provider_configs(): + return { + "openai": {"api_key": "test_openai_api_key"}, + "aws": { + "aws_access_key": "test_aws_access_key", + "aws_secret_key": "test_aws_secret_key", + "aws_session_token": "test_aws_session_token", + "aws_region": "us-west-2", + }, + "azure": { + "api_key": "azure-api-key", + "base_url": "https://model.ai.azure.com", + }, + "groq": { + "api_key": "groq-api-key", + }, + "mistral": { + "api_key": "mistral-api-key", + }, + "google": { + "project_id": "test_google_project_id", + "region": "us-west4", + "application_credentials": "test_google_application_credentials", + }, + "fireworks": { + "api_key": "fireworks-api-key", + }, + "nebius": { + "api_key": "nebius-api-key", + }, + "watsonx": { + "service_url": "https://watsonx-service-url.com", + "api_key": "watsonx-api-key", + "project_id": "watsonx-project-id", + }, + } + + +@pytest.mark.parametrize( + argnames=("patch_target", "provider", "model"), + argvalues=[ + ( + "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create", + "openai", + "gpt-4o", + ), + ( + "aisuite.providers.mistral_provider.MistralProvider.chat_completions_create", + "mistral", + "mistral-model", + ), + ( + "aisuite.providers.groq_provider.GroqProvider.chat_completions_create", + "groq", + "groq-model", + ), + ( + "aisuite.providers.aws_provider.AwsProvider.chat_completions_create", + "aws", + "claude-v3", + ), + ( + "aisuite.providers.azure_provider.AzureProvider.chat_completions_create", + "azure", + "azure-model", + ), + ( + "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create", + "anthropic", + "anthropic-model", + ), + ( + "aisuite.providers.google_provider.GoogleProvider.chat_completions_create", + "google", + "google-model", + ), + ( + "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create", + "fireworks", + "fireworks-model", + ), + ( + "aisuite.providers.nebius_provider.NebiusProvider.chat_completions_create", + "nebius", + "nebius-model", + ), + ( + "aisuite.providers.watsonx_provider.WatsonxProvider.chat_completions_create", + "watsonx", + "watsonx-model", + ), + ], +) +def test_client_chat_completions( + provider_configs: dict, patch_target: str, provider: str, model: str +): + expected_response = f"{patch_target}_{provider}_{model}" + with patch(patch_target) as mock_provider: + mock_provider.return_value = expected_response client = Client() client.configure(provider_configs) messages = [ @@ -74,115 +112,53 @@ def test_client_chat_completions( {"role": "user", "content": "Who won the world series in 2020?"}, ] - # Test OpenAI model - open_ai_model = "openai" + ":" + "gpt-4o" - openai_response = client.chat.completions.create( - open_ai_model, messages=messages - ) - self.assertEqual(openai_response, "OpenAI Response") - mock_openai.assert_called_once() - - # Test AWS Bedrock model - bedrock_model = "aws" + ":" + "claude-v3" - bedrock_response = client.chat.completions.create( - bedrock_model, messages=messages - ) - self.assertEqual(bedrock_response, "AWS Bedrock Response") - mock_bedrock.assert_called_once() - - # Test Azure model - azure_model = "azure" + ":" + "azure-model" - azure_response = client.chat.completions.create(azure_model, messages=messages) - self.assertEqual(azure_response, "Azure Response") - mock_azure.assert_called_once() - - # Test Anthropic model - anthropic_model = "anthropic" + ":" + "anthropic-model" - anthropic_response = 
client.chat.completions.create( - anthropic_model, messages=messages - ) - self.assertEqual(anthropic_response, "Anthropic Response") - mock_anthropic.assert_called_once() - - # Test Groq model - groq_model = "groq" + ":" + "groq-model" - groq_response = client.chat.completions.create(groq_model, messages=messages) - self.assertEqual(groq_response, "Groq Response") - mock_groq.assert_called_once() - - # Test Mistral model - mistral_model = "mistral" + ":" + "mistral-model" - mistral_response = client.chat.completions.create( - mistral_model, messages=messages - ) - self.assertEqual(mistral_response, "Mistral Response") - mock_mistral.assert_called_once() - - # Test Google model - google_model = "google" + ":" + "google-model" - google_response = client.chat.completions.create( - google_model, messages=messages - ) - self.assertEqual(google_response, "Google Response") - mock_google.assert_called_once() - - # Test Fireworks model - fireworks_model = "fireworks" + ":" + "fireworks-model" - fireworks_response = client.chat.completions.create( - fireworks_model, messages=messages - ) - self.assertEqual(fireworks_response, "Fireworks Response") - mock_fireworks.assert_called_once() - - # Test that new instances of Completion are not created each time we make an inference call. - compl_instance = client.chat.completions - next_compl_instance = client.chat.completions - assert compl_instance is next_compl_instance - - def test_invalid_provider_in_client_config(self): - # Testing an invalid provider name in the configuration - invalid_provider_configs = { - "invalid_provider": {"api_key": "invalid_api_key"}, - } - - # Expect ValueError when initializing Client with invalid provider - with self.assertRaises(ValueError) as context: - client = Client(invalid_provider_configs) - - # Verify the error message - self.assertIn( - "Invalid provider key 'invalid_provider'. Supported providers: ", - str(context.exception), - ) - - @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create") - def test_invalid_model_format_in_create(self, mock_openai): - # Valid provider configurations - provider_configs = { - "openai": {"api_key": "test_openai_api_key"}, - } - - # Initialize the client with valid provider - client = Client() - client.configure(provider_configs) + model_str = f"{provider}:{model}" + model_response = client.chat.completions.create(model_str, messages=messages) + assert model_response == expected_response - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a joke."}, - ] - # Invalid model format - invalid_model = "invalidmodel" +def test_invalid_provider_in_client_config(): + # Testing an invalid provider name in the configuration + invalid_provider_configs = { + "invalid_provider": {"api_key": "invalid_api_key"}, + } + + # Expect ValueError when initializing Client with invalid provider and verify message + with pytest.raises( + ValueError, + match=r"Invalid provider key 'invalid_provider'. Supported providers: ", + ): + _ = Client(invalid_provider_configs) - # Expect ValueError when calling create with invalid model format - with self.assertRaises(ValueError) as context: - client.chat.completions.create(invalid_model, messages=messages) - # Verify the error message - self.assertIn( - "Invalid model format. 
Expected 'provider:model'", str(context.exception)
-        )
+def test_invalid_model_format_in_create(monkeypatch):
+    from aisuite.providers.openai_provider import OpenaiProvider
+
+    monkeypatch.setattr(
+        target=OpenaiProvider,
+        name="chat_completions_create",
+        value=Mock(),
+    )
+    # Valid provider configurations
+    provider_configs = {
+        "openai": {"api_key": "test_openai_api_key"},
+    }
-if __name__ == "__main__":
-    unittest.main()
+    # Initialize the client with valid provider
+    client = Client()
+    client.configure(provider_configs)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Tell me a joke."},
+    ]
+
+    # Invalid model format
+    invalid_model = "invalidmodel"
+
+    # Expect ValueError when calling create with invalid model format and verify message
+    with pytest.raises(
+        ValueError, match=r"Invalid model format. Expected 'provider:model'"
+    ):
+        client.chat.completions.create(invalid_model, messages=messages)
diff --git a/tests/client/test_prerelease.py b/tests/client/test_prerelease.py
new file mode 100644
index 00000000..bb5f3285
--- /dev/null
+++ b/tests/client/test_prerelease.py
@@ -0,0 +1,74 @@
+# Run this test before releasing a new version.
+# It will test all the models in the client.
+
+import pytest
+import aisuite as ai
+from typing import List, Dict
+from dotenv import load_dotenv, find_dotenv
+
+
+def setup_client() -> ai.Client:
+    """Initialize the AI client with environment variables."""
+    load_dotenv(find_dotenv())
+    return ai.Client()
+
+
+def get_test_models() -> List[str]:
+    """Return a list of model identifiers to test."""
+    return [
+        "anthropic:claude-3-5-sonnet-20240620",
+        "aws:meta.llama3-1-8b-instruct-v1:0",
+        "huggingface:mistralai/Mistral-7B-Instruct-v0.3",
+        "groq:llama3-8b-8192",
+        "mistral:open-mistral-7b",
+        "openai:gpt-3.5-turbo",
+        "cohere:command-r-plus-08-2024",
+    ]
+
+
+def get_test_messages() -> List[Dict[str, str]]:
+    """Return the test messages to send to each model."""
+    return [
+        {
+            "role": "system",
+            "content": "Respond in Pirate English. Always try to include the phrase - No rum No fun.",
+        },
+        {"role": "user", "content": "Tell me a joke about Captain Jack Sparrow"},
+    ]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("model_id", get_test_models())
+def test_model_pirate_response(model_id: str):
+    """
+    Test that each model responds appropriately to the pirate prompt.
+
+    Args:
+        model_id: The provider:model identifier to test
+    """
+    client = setup_client()
+    messages = get_test_messages()
+
+    try:
+        response = client.chat.completions.create(
+            model=model_id, messages=messages, temperature=0.75
+        )
+
+        content = response.choices[0].message.content.lower()
+
+        # Check if either version of the required phrase is present
+        assert any(
+            phrase in content for phrase in ["no rum no fun", "no rum, no fun"]
+        ), f"Model {model_id} did not include required phrase 'No rum No fun'"
+
+        assert len(content) > 0, f"Model {model_id} returned empty response"
+        assert isinstance(
+            content, str
+        ), f"Model {model_id} returned non-string response"
+
+    except Exception as e:
+        pytest.fail(f"Error testing model {model_id}: {str(e)}")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tests/providers/test_cohere_provider.py b/tests/providers/test_cohere_provider.py
new file mode 100644
index 00000000..d7e10486
--- /dev/null
+++ b/tests/providers/test_cohere_provider.py
@@ -0,0 +1,46 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from aisuite.providers.cohere_provider import CohereProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("CO_API_KEY", "test-api-key")
+
+
+def test_cohere_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = CohereProvider()
+    mock_response = MagicMock()
+    mock_response.message = MagicMock()
+    mock_response.message.content = [MagicMock()]
+    mock_response.message.content[0].text = response_text_content
+
+    with patch.object(
+        provider.client,
+        "chat",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content
diff --git a/tests/providers/test_deepseek_provider.py b/tests/providers/test_deepseek_provider.py
new file mode 100644
index 00000000..1ab6f1c1
--- /dev/null
+++ b/tests/providers/test_deepseek_provider.py
@@ -0,0 +1,46 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from aisuite.providers.deepseek_provider import DeepseekProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("DEEPSEEK_API_KEY", "test-api-key")
+
+
+def test_deepseek_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = DeepseekProvider()
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = response_text_content
+
+    with patch.object(
+        provider.client.chat.completions,
+        "create",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content
diff --git a/tests/providers/test_nebius_provider.py b/tests/providers/test_nebius_provider.py
new file mode 100644
index 00000000..8e969ea5
--- /dev/null
+++ b/tests/providers/test_nebius_provider.py
@@ -0,0 +1,45 @@
+import pytest
+from unittest.mock import patch, MagicMock
+
+from aisuite.providers.nebius_provider import NebiusProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("NEBIUS_API_KEY", "test-api-key")
+
+
+def test_nebius_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = NebiusProvider()
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = response_text_content
+
+    with patch.object(
+        provider.client.chat.completions,
+        "create",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content
diff --git a/tests/providers/test_sambanova_provider.py b/tests/providers/test_sambanova_provider.py
new file mode 100644
index 00000000..b5c649ec
--- /dev/null
+++ b/tests/providers/test_sambanova_provider.py
@@ -0,0 +1,46 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from aisuite.providers.sambanova_provider import SambanovaProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("SAMBANOVA_API_KEY", "test-api-key")
+
+
+def test_sambanova_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = SambanovaProvider()
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = response_text_content
+
+    with patch.object(
+        provider.client.chat.completions,
+        "create",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content
diff --git a/tests/providers/test_watsonx_provider.py b/tests/providers/test_watsonx_provider.py
new file mode 100644
index 00000000..8e7123a7
--- /dev/null
+++ b/tests/providers/test_watsonx_provider.py
@@ -0,0 +1,59 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from ibm_watsonx_ai.metanames import GenChatParamsMetaNames as GenChatParams
+
+from aisuite.providers.watsonx_provider import WatsonxProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("WATSONX_SERVICE_URL", "https://watsonx-service-url.com")
+    monkeypatch.setenv("WATSONX_API_KEY", "test-api-key")
+    monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id")
+
+
+def test_watsonx_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.7
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = WatsonxProvider()
+    mock_response = {"choices": [{"message": {"content": response_text_content}}]}
+
+    with patch(
+        "aisuite.providers.watsonx_provider.ModelInference"
+    ) as mock_model_inference:
+        mock_model = MagicMock()
+        mock_model_inference.return_value = mock_model
+        mock_model.chat.return_value = mock_response
+
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        # Assert that ModelInference was called with correct arguments.
+        mock_model_inference.assert_called_once()
+        args, kwargs = mock_model_inference.call_args
+        assert kwargs["model_id"] == selected_model
+        assert kwargs["project_id"] == provider.project_id
+
+        # Assert that the credentials have the correct API key and service URL.
+        credentials = kwargs["credentials"]
+        assert credentials.api_key == provider.api_key
+        assert credentials.url == provider.service_url
+
+        # Assert that chat was called with correct history and params
+        mock_model.chat.assert_called_once_with(
+            messages=message_history,
+            params={GenChatParams.TEMPERATURE: chosen_temperature},
+        )
+
+        assert response.choices[0].message.content == response_text_content