From e12bd49d02fd8f11bb67365a8e4f66c192c0a51e Mon Sep 17 00:00:00 2001
From: Yuan-Man <68322456+Yuan-ManX@users.noreply.github.com>
Date: Tue, 26 Nov 2024 17:11:41 +0800
Subject: [PATCH 01/40] Update README.md

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index bd7df95e..eaabc943 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
 # aisuite

+[![PyPI](https://img.shields.io/pypi/v/aisuite)](https://pypi.org/project/aisuite/)
 [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

 Simple, unified interface to multiple Generative AI providers.

From 03366cbf99d63895dea281ca183202ebf27c8407 Mon Sep 17 00:00:00 2001
From: BRlin-o
Date: Wed, 27 Nov 2024 01:11:39 +0800
Subject: [PATCH 02/40] Unified environment variable name is AWS_REGION to
 improve consistency

---
 aisuite/providers/aws_provider.py | 2 +-
 guides/aws.md                     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/aisuite/providers/aws_provider.py b/aisuite/providers/aws_provider.py
index 10f48afe..f44f5150 100644
--- a/aisuite/providers/aws_provider.py
+++ b/aisuite/providers/aws_provider.py
@@ -34,7 +34,7 @@ def __init__(self, **config):
         """
         self.region_name = config.get(
-            "region_name", os.getenv("AWS_REGION_NAME", "us-west-2")
+            "region_name", os.getenv("AWS_REGION", "us-west-2")
         )
         self.client = boto3.client("bedrock-runtime", region_name=self.region_name)
         self.inference_parameters = [

diff --git a/guides/aws.md b/guides/aws.md
index a01d6eb6..4433fdfb 100644
--- a/guides/aws.md
+++ b/guides/aws.md
@@ -23,9 +23,9 @@ Once that has been enabled set your Access Key and Secret in the env variables:
 ```shell
 export AWS_ACCESS_KEY="your-access-key"
 export AWS_SECRET_KEY="your-secret-key"
-export AWS_REGION_NAME="region-name"
+export AWS_REGION="region-name"
 ```

-*Note: AWS_REGION_NAME is optional, a default of `us-west-2` has been set for easy of use*
+*Note: AWS_REGION is optional, a default of `us-west-2` has been set for ease of use*

From 7d4e093113de3d6f71d9fa8f1308f93b989b5b9f Mon Sep 17 00:00:00 2001
From: foxty
Date: Thu, 28 Nov 2024 17:29:42 +0800
Subject: [PATCH 03/40] doc: fix broken link to contributing guide.

---
 guides/anthropic.md   | 2 +-
 guides/aws.md         | 2 +-
 guides/azure.md       | 2 ++
 guides/google.md      | 2 +-
 guides/huggingface.md | 2 ++
 guides/openai.md      | 2 +-
 6 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/guides/anthropic.md b/guides/anthropic.md
index 8d70cb5e..0f674c17 100644
--- a/guides/anthropic.md
+++ b/guides/anthropic.md
@@ -44,4 +44,4 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```

-Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).

diff --git a/guides/aws.md b/guides/aws.md
index a01d6eb6..35867a69 100644
--- a/guides/aws.md
+++ b/guides/aws.md
@@ -63,7 +63,7 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```

-Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
diff --git a/guides/azure.md b/guides/azure.md
index e9a71fe0..8246b7ad 100644
--- a/guides/azure.md
+++ b/guides/azure.md
@@ -56,3 +56,5 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```
+
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
\ No newline at end of file

diff --git a/guides/google.md b/guides/google.md
index eb351bd0..ffd21948 100644
--- a/guides/google.md
+++ b/guides/google.md
@@ -89,4 +89,4 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```

-Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
\ No newline at end of file
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
\ No newline at end of file

diff --git a/guides/huggingface.md b/guides/huggingface.md
index 11bd9297..e6816fb5 100644
--- a/guides/huggingface.md
+++ b/guides/huggingface.md
@@ -53,3 +53,5 @@ print(response.choices[0].message.content)
 - Ensure that the `model` variable matches the identifier of your model as seen in the Hugging Face Model Hub.
 - If you encounter any rate limits or API access restrictions, you may have to upgrade your Hugging Face plan to enable higher usage limits.
 """
+
+Happy coding! If you would like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
\ No newline at end of file

diff --git a/guides/openai.md b/guides/openai.md
index 6dc9ce97..ab297490 100644
--- a/guides/openai.md
+++ b/guides/openai.md
@@ -41,4 +41,4 @@ response = client.chat.completions.create(
 print(response.choices[0].message.content)
 ```

-Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).

From 5cd518eb81b6c90ac2a4a37034d0205036722692 Mon Sep 17 00:00:00 2001
From: Dax Patel
Date: Thu, 28 Nov 2024 16:36:54 -0500
Subject: [PATCH 04/40] Add Watsonx provider with tests using the python SDK

---
 .env.sample                              |   6 +
 README.md                                |   2 +-
 aisuite/providers/watsonx_provider.py    |  33 ++++
 guides/watsonx.md                        |  83 +++++++++
 poetry.lock                              | 221 ++++++++++++++++-------
 pyproject.toml                           |   5 +-
 tests/client/test_client.py              |  17 ++
 tests/providers/test_watsonx_provider.py |  63 +++++++
 8 files changed, 366 insertions(+), 64 deletions(-)
 create mode 100644 aisuite/providers/watsonx_provider.py
 create mode 100644 guides/watsonx.md
 create mode 100644 tests/providers/test_watsonx_provider.py

diff --git a/.env.sample b/.env.sample
index 00826f4b..e7e9618d 100644
--- a/.env.sample
+++ b/.env.sample
@@ -25,3 +25,9 @@ FIREWORKS_API_KEY=

 # Together AI
 TOGETHER_API_KEY=
+
+# WatsonX
+WATSONX_SERVICE_URL=
+WATSONX_API_KEY=
+WATSONX_PROJECT_ID=
+

diff --git a/README.md b/README.md
index bd7df95e..704c96b0 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ Simple, unified interface to multiple Generative AI providers.

 `aisuite` makes it easy for developers to use multiple LLM through a standardized interface. Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focussed on chat completions. We will expand it cover more use cases in near future.

 Currently supported providers are -
Currently supported providers are - -OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace and Ollama. +OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace, Ollama and Watsonx. To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider. ## Installation diff --git a/aisuite/providers/watsonx_provider.py b/aisuite/providers/watsonx_provider.py new file mode 100644 index 00000000..0900479b --- /dev/null +++ b/aisuite/providers/watsonx_provider.py @@ -0,0 +1,33 @@ +from aisuite.provider import Provider +import os +from ibm_watsonx_ai import Credentials +from ibm_watsonx_ai.foundation_models import ModelInference +from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams + +DEFAULT_TEMPERATURE = 0.7 + + +class WatsonxProvider(Provider): + def __init__(self, **config): + self.service_url = config.get("service_url") or os.getenv("WATSONX_SERVICE_URL") + self.api_key = config.get("api_key") or os.getenv("WATSONX_API_KEY") + self.project_id = config.get("project_id") or os.getenv("WATSONX_PROJECT_ID") + + if not self.service_url or not self.api_key or not self.project_id: + raise EnvironmentError( + "Missing one or more required WatsonX environment variables: " + "WATSONX_SERVICE_URL, WATSONX_API_KEY, WATSONX_PROJECT_ID. " + "Please refer to the setup guide: /guides/watsonx.md." + ) + + def chat_completions_create(self, model, messages, **kwargs): + model = ModelInference( + model_id=model, + params={ + GenParams.TEMPERATURE: kwargs.get("temperature", DEFAULT_TEMPERATURE), + }, + credentials=Credentials(api_key=self.api_key, url=self.service_url), + project_id=self.project_id, + ) + + return model.chat(prompt=messages, **kwargs) diff --git a/guides/watsonx.md b/guides/watsonx.md new file mode 100644 index 00000000..3353666d --- /dev/null +++ b/guides/watsonx.md @@ -0,0 +1,83 @@ +# Watsonx with `aisuite` + +A a step-by-step guide to set up Watsonx with the `aisuite` library, enabling you to use IBM Watsonx's powerful AI models for various tasks. + +## Setup Instructions + +### Step 1: Create a Watsonx Account + +1. Visit [IBM Watsonx](https://www.ibm.com/watsonx). +2. Sign up for a new account or log in with your existing IBM credentials. +3. Once logged in, navigate to the **Watsonx Dashboard**. + +--- + +### Step 2: Obtain API Credentials + +1. **Generate an API Key**: + - Go to the **API Keys** section in your Watsonx account settings. + - Click on **Create API Key**. + - Provide a name for your API key (e.g., `MyWatsonxKey`). + - Click **Generate**, then download or copy the API key. **Keep this key secure!** + +2. **Locate the Service URL**: + - Go to the **Endpoints** section in the Watsonx dashboard. + - Find the URL corresponding to your service and note it. This is your `WATSONX_SERVICE_URL`. + +3. **Get the Project ID**: + - Navigate to the **Projects** tab in the dashboard. + - Select the project you want to use. + - Copy the **Project ID**. This will serve as your `WATSONX_PROJECT_ID`. 
+
+---
+
+### Step 3: Set Environment Variables
+
+To simplify authentication, set the following environment variables:
+
+Run the following commands in your terminal:
+
+```bash
+export WATSONX_API_KEY="your-watsonx-api-key"
+export WATSONX_SERVICE_URL="your-watsonx-service-url"
+export WATSONX_PROJECT_ID="your-watsonx-project-id"
+```
+
+
+## Create a Chat Completion
+
+Install the `ibm-watsonx-ai` Python client:
+
+Example with pip:
+
+```shell
+pip install ibm-watsonx-ai
+```
+
+Example with poetry:
+
+```shell
+poetry add ibm-watsonx-ai
+```
+
+In your code:
+
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "watsonx"
+model_id = "meta-llama/llama-3-70b-instruct"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "Tell me a joke."},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
\ No newline at end of file

diff --git a/poetry.lock b/poetry.lock
index 27aa1cc1..ce54777c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1361,12 +1361,12 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
     {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
     {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 proto-plus = ">=1.22.3,<2.0.0dev"
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
@@ -1926,6 +1926,80 @@ files = [

 [package.dependencies]
 pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}

+[[package]]
+name = "ibm-cos-sdk"
+version = "2.13.6"
+description = "IBM SDK for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "ibm-cos-sdk-2.13.6.tar.gz", hash = "sha256:171cf2ae4ab662a4b8ab58dcf4ac994b0577d6c92d78490295fd7704a83978f6"},
+]
+
+[package.dependencies]
+ibm-cos-sdk-core = "2.13.6"
+ibm-cos-sdk-s3transfer = "2.13.6"
+jmespath = ">=0.10.0,<=1.0.1"
+
+[[package]]
+name = "ibm-cos-sdk-core"
+version = "2.13.6"
+description = "Low-level, data-driven core of IBM SDK for Python"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "ibm-cos-sdk-core-2.13.6.tar.gz", hash = "sha256:dd41fb789eeb65546501afabcd50e78846ab4513b6ad4042e410b6a14ff88413"},
+]
+
+[package.dependencies]
+jmespath = ">=0.10.0,<=1.0.1"
+python-dateutil = ">=2.9.0,<3.0.0"
+requests = ">=2.32.0,<2.32.3"
+urllib3 = ">=1.26.18,<3"
+
+[[package]]
+name = "ibm-cos-sdk-s3transfer"
+version = "2.13.6"
+description = "IBM S3 Transfer Manager"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "ibm-cos-sdk-s3transfer-2.13.6.tar.gz", hash = "sha256:e0acce6f380c47d11e07c6765b684b4ababbf5c66cc0503bc246469a1e2b9790"},
+]
+
+[package.dependencies]
+ibm-cos-sdk-core = "2.13.6"
+
+[[package]]
+name = "ibm-watsonx-ai" +version = "1.1.16" +description = "IBM watsonx.ai API Client" +optional = false +python-versions = ">=3.10" +files = [ + {file = "ibm_watsonx_ai-1.1.16-py3-none-any.whl", hash = "sha256:c703adda2588c85606f74c230afe3ce31202815de369301df19f14ce21bd093a"}, + {file = "ibm_watsonx_ai-1.1.16.tar.gz", hash = "sha256:ab79ed5dedd57fd574c5c6c5ceca50a89b8423562646255c910ed74a8d8811a5"}, +] + +[package.dependencies] +certifi = "*" +httpx = "*" +ibm-cos-sdk = ">=2.12.0,<2.14.0" +importlib-metadata = "*" +lomond = "*" +packaging = "*" +pandas = ">=0.24.2,<2.2.0" +requests = "*" +tabulate = "*" +urllib3 = "*" + +[package.extras] +fl-crypto = ["pyhelayers (==1.5.0.3)"] +fl-crypto-rt24-1 = ["pyhelayers (==1.5.3.1)"] +fl-rt23-1-py3-10 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.1.1)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.23.5)", "pandas (==1.5.3)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.1.1)", "scipy (==1.10.1)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.12.0)", "torch (==2.0.1)", "websockets (==10.1)"] +fl-rt24-1-py3-11 = ["GPUtil", "cryptography (==42.0.5)", "ddsketch (==2.0.4)", "diffprivlib (==0.5.1)", "environs (==9.5.0)", "gym", "image (==1.5.33)", "joblib (==1.3.2)", "lz4", "msgpack (==1.0.7)", "msgpack-numpy (==0.4.8)", "numcompress (==0.1.2)", "numpy (==1.26.4)", "pandas (==2.1.4)", "parse (==1.19.0)", "pathlib2 (==2.3.6)", "protobuf (==4.22.1)", "psutil", "pyYAML (==6.0.1)", "pytest (==6.2.5)", "requests (==2.32.3)", "scikit-learn (==1.3.0)", "scipy (==1.11.4)", "setproctitle", "skops (==0.9.0)", "skorch (==0.12.0)", "tabulate (==0.8.9)", "tensorflow (==2.14.1)", "torch (==2.1.2)", "websockets (==10.1)"] +rag = ["beautifulsoup4 (==4.12.3)", "grpcio (>=1.60.0)", "langchain (>=0.2.15,<0.3)", "langchain-chroma (==0.1.1)", "langchain-community (>=0.2.4,<0.3)", "langchain-core (>=0.2.37,<0.3)", "langchain-elasticsearch (==0.2.2)", "langchain-ibm", "langchain-milvus (==0.1.1)", "markdown (==3.4.1)", "pypdf (==4.2.0)", "python-docx (==1.1.2)"] + [[package]] name = "identify" version = "2.6.0" @@ -2517,6 +2591,20 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0" [package.extras] adal = ["adal (>=1.0.2)"] +[[package]] +name = "lomond" +version = "0.3.3" +description = "Websocket Client Library" +optional = false +python-versions = "*" +files = [ + {file = "lomond-0.3.3-py2.py3-none-any.whl", hash = "sha256:df1dd4dd7b802a12b71907ab1abb08b8ce9950195311207579379eb3b1553de7"}, + {file = "lomond-0.3.3.tar.gz", hash = "sha256:427936596b144b4ec387ead99aac1560b77c8a78107d3d49415d3abbe79acbd3"}, +] + +[package.dependencies] +six = ">=1.10.0" + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -3612,76 +3700,71 @@ files = [ [[package]] name = "pandas" -version = "2.2.2" +version = "2.1.4" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" files = [ - {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, - {file = 
"pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, - {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, - {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, - {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, - {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, - {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, + {file = "pandas-2.1.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bdec823dc6ec53f7a6339a0e34c68b144a7a1fd28d80c260534c39c62c5bf8c9"}, + {file = "pandas-2.1.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:294d96cfaf28d688f30c918a765ea2ae2e0e71d3536754f4b6de0ea4a496d034"}, + {file = "pandas-2.1.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b728fb8deba8905b319f96447a27033969f3ea1fea09d07d296c9030ab2ed1d"}, + {file = "pandas-2.1.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00028e6737c594feac3c2df15636d73ace46b8314d236100b57ed7e4b9ebe8d9"}, + {file = "pandas-2.1.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:426dc0f1b187523c4db06f96fb5c8d1a845e259c99bda74f7de97bd8a3bb3139"}, + {file = "pandas-2.1.4-cp310-cp310-win_amd64.whl", hash = "sha256:f237e6ca6421265643608813ce9793610ad09b40154a3344a088159590469e46"}, + {file = "pandas-2.1.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b7d852d16c270e4331f6f59b3e9aa23f935f5c4b0ed2d0bc77637a8890a5d092"}, + {file = "pandas-2.1.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bd7d5f2f54f78164b3d7a40f33bf79a74cdee72c31affec86bfcabe7e0789821"}, + {file = "pandas-2.1.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0aa6e92e639da0d6e2017d9ccff563222f4eb31e4b2c3cf32a2a392fc3103c0d"}, + {file = "pandas-2.1.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d797591b6846b9db79e65dc2d0d48e61f7db8d10b2a9480b4e3faaddc421a171"}, + {file = "pandas-2.1.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d2d3e7b00f703aea3945995ee63375c61b2e6aa5aa7871c5d622870e5e137623"}, + {file = "pandas-2.1.4-cp311-cp311-win_amd64.whl", hash = "sha256:dc9bf7ade01143cddc0074aa6995edd05323974e6e40d9dbde081021ded8510e"}, + {file = "pandas-2.1.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:482d5076e1791777e1571f2e2d789e940dedd927325cc3cb6d0800c6304082f6"}, + {file = "pandas-2.1.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8a706cfe7955c4ca59af8c7a0517370eafbd98593155b48f10f9811da440248b"}, + {file = "pandas-2.1.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0513a132a15977b4a5b89aabd304647919bc2169eac4c8536afb29c07c23540"}, + {file = "pandas-2.1.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9f17f2b6fc076b2a0078862547595d66244db0f41bf79fc5f64a5c4d635bead"}, + {file = "pandas-2.1.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:45d63d2a9b1b37fa6c84a68ba2422dc9ed018bdaa668c7f47566a01188ceeec1"}, + {file = "pandas-2.1.4-cp312-cp312-win_amd64.whl", hash = "sha256:f69b0c9bb174a2342818d3e2778584e18c740d56857fc5cdb944ec8bbe4082cf"}, + {file = "pandas-2.1.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3f06bda01a143020bad20f7a85dd5f4a1600112145f126bc9e3e42077c24ef34"}, + {file = "pandas-2.1.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab5796839eb1fd62a39eec2916d3e979ec3130509930fea17fe6f81e18108f6a"}, + {file = "pandas-2.1.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:edbaf9e8d3a63a9276d707b4d25930a262341bca9874fcb22eff5e3da5394732"}, + {file = "pandas-2.1.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ebfd771110b50055712b3b711b51bee5d50135429364d0498e1213a7adc2be8"}, + {file = "pandas-2.1.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8ea107e0be2aba1da619cc6ba3f999b2bfc9669a83554b1904ce3dd9507f0860"}, + {file = "pandas-2.1.4-cp39-cp39-win_amd64.whl", hash = "sha256:d65148b14788b3758daf57bf42725caa536575da2b64df9964c563b015230984"}, + {file = "pandas-2.1.4.tar.gz", hash = "sha256:fcb68203c833cc735321512e13861358079a96c174a61f5116a1de89c58c0ef7"}, ] [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" -tzdata = ">=2022.7" +tzdata = ">=2022.1" [package.extras] -all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"] -aws = ["s3fs (>=2022.11.0)"] -clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"] -compression = ["zstandard (>=0.19.0)"] -computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"] +all = ["PyQt5 (>=5.15.6)", "SQLAlchemy (>=1.4.36)", "beautifulsoup4 (>=4.11.1)", "bottleneck (>=1.3.4)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=0.8.1)", "fsspec (>=2022.05.0)", "gcsfs (>=2022.05.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.8.0)", "matplotlib (>=3.6.1)", "numba (>=0.55.2)", "numexpr (>=2.8.0)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pandas-gbq (>=0.17.5)", "psycopg2 (>=2.9.3)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "pyreadstat (>=1.1.5)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "pyxlsb (>=1.0.9)", "qtpy (>=2.2.0)", "s3fs (>=2022.05.0)", "scipy (>=1.8.1)", "tables (>=3.7.0)", "tabulate (>=0.8.10)", "xarray (>=2022.03.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)", "zstandard (>=0.17.0)"] +aws = ["s3fs (>=2022.05.0)"] +clipboard = ["PyQt5 (>=5.15.6)", "qtpy (>=2.2.0)"] +compression = ["zstandard (>=0.17.0)"] +computation = ["scipy (>=1.8.1)", "xarray (>=2022.03.0)"] consortium-standard = ["dataframe-api-compat (>=0.1.7)"] -excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"] -feather = ["pyarrow (>=10.0.1)"] -fss = ["fsspec (>=2022.11.0)"] -gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"] -hdf5 = ["tables (>=3.8.0)"] -html 
= ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"] -mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"] -output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"] -parquet = ["pyarrow (>=10.0.1)"] -performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"] -plot = ["matplotlib (>=3.6.3)"] -postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"] -pyarrow = ["pyarrow (>=10.0.1)"] -spss = ["pyreadstat (>=1.2.0)"] -sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"] +excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.10)", "pyxlsb (>=1.0.9)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.3)"] +feather = ["pyarrow (>=7.0.0)"] +fss = ["fsspec (>=2022.05.0)"] +gcp = ["gcsfs (>=2022.05.0)", "pandas-gbq (>=0.17.5)"] +hdf5 = ["tables (>=3.7.0)"] +html = ["beautifulsoup4 (>=4.11.1)", "html5lib (>=1.1)", "lxml (>=4.8.0)"] +mysql = ["SQLAlchemy (>=1.4.36)", "pymysql (>=1.0.2)"] +output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.8.10)"] +parquet = ["pyarrow (>=7.0.0)"] +performance = ["bottleneck (>=1.3.4)", "numba (>=0.55.2)", "numexpr (>=2.8.0)"] +plot = ["matplotlib (>=3.6.1)"] +postgresql = ["SQLAlchemy (>=1.4.36)", "psycopg2 (>=2.9.3)"] +spss = ["pyreadstat (>=1.1.5)"] +sql-other = ["SQLAlchemy (>=1.4.36)"] test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] -xml = ["lxml (>=4.9.2)"] +xml = ["lxml (>=4.8.0)"] [[package]] name = "pandocfilters" @@ -4135,8 +4218,8 @@ files = [ annotated-types = ">=0.4.0" pydantic-core = "2.20.1" typing-extensions = [ - {version = ">=4.6.1", markers = "python_version < \"3.13\""}, {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, ] [package.extras] @@ -4678,13 +4761,13 @@ files = [ [[package]] name = "requests" -version = "2.32.3" +version = "2.32.2" description = "Python HTTP for Humans." 
optional = false
python-versions = ">=3.8"
files = [
-    {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
-    {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
+    {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
+    {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
 ]

 [package.dependencies]
@@ -5322,6 +5405,20 @@ mpmath = ">=1.1.0,<1.4"

 [package.extras]
 dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]

+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+    {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
 [[package]]
 name = "tenacity"
 version = "8.5.0"
@@ -6492,4 +6589,4 @@ openai = ["openai"]

 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "e455ddfd148ffae77e7b9ac196792b38c7cc545c0866cca2aaae227f4a4201df"
+content-hash = "b66e696e459cffd4f1eadb78693b93d048a2635be88344740d9c11f564d946f5"

diff --git a/pyproject.toml b/pyproject.toml
index 8ae9295b..3b79df44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ vertexai = { version = "^1.63.0", optional = true }
 groq = { version = "^0.9.0", optional = true }
 mistralai = { version = "^1.0.3", optional = true }
 openai = { version = "^1.35.8", optional = true }
+ibm-watsonx-ai = { version = "^1.1.16", optional = true }

 # Optional dependencies for different providers
 [tool.poetry.extras]
@@ -25,7 +26,8 @@ huggingface = []
 mistral = ["mistralai"]
 ollama = []
 openai = ["openai"]
-all = ["anthropic", "aws", "google", "groq", "mistral", "openai"]  # To install all providers
+watsonx = ["ibm-watsonx-ai"]
+all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "watsonx"]  # To install all providers

 [tool.poetry.group.dev.dependencies]
 pytest = "^8.2.2"
@@ -44,6 +46,7 @@ chromadb = "^0.5.4"
 sentence-transformers = "^3.0.1"
 datasets = "^2.20.0"
 vertexai = "^1.63.0"
+ibm-watsonx-ai = "^1.1.16"

 [build-system]
 requires = ["poetry-core"]

diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 2e1949ac..e0229259 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -4,6 +4,7 @@


 class TestClient(unittest.TestCase):
+
     @patch("aisuite.providers.mistral_provider.MistralProvider.chat_completions_create")
     @patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create")
     @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create")
@@ -16,6 +17,7 @@ class TestClient(unittest.TestCase):
     @patch(
         "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create"
     )
+    @patch("aisuite.providers.watsonx_provider.WatsonxProvider.chat_completions_create")
     def test_client_chat_completions(
         self,
         mock_fireworks,
@@ -26,6 +28,7 @@ def test_client_chat_completions(
         mock_openai,
         mock_groq,
         mock_mistral,
+        mock_watsonx,
     ):
         # Mock responses from providers
         mock_openai.return_value = "OpenAI Response"
@@ -36,6 +39,7 @@ def test_client_chat_completions(
         mock_mistral.return_value = "Mistral Response"
         mock_google.return_value = "Google Response"
         mock_fireworks.return_value = "Fireworks Response"
+        mock_watsonx.return_value = "Watsonx Response"

         # Provider configurations
         provider_configs = {
@@ -64,6 +68,11 @@ def test_client_chat_completions(
             "fireworks": {
                 "api_key": "fireworks-api-key",
             },
+            "watsonx": {
+                "service_url": "https://watsonx-service-url.com",
+                "api_key": "watsonx-api-key",
+                "project_id": "watsonx-project-id",
+            },
         }

         # Initialize the client
@@ -134,6 +143,14 @@ def test_client_chat_completions(
         self.assertEqual(fireworks_response, "Fireworks Response")
         mock_fireworks.assert_called_once()

+        # Test Watsonx model
+        watsonx_model = "watsonx" + ":" + "watsonx-model"
+        watsonx_response = client.chat.completions.create(
+            watsonx_model, messages=messages
+        )
+        self.assertEqual(watsonx_response, "Watsonx Response")
+        mock_watsonx.assert_called_once()
+
         # Test that new instances of Completion are not created each time we make an inference call.
         compl_instance = client.chat.completions
         next_compl_instance = client.chat.completions

diff --git a/tests/providers/test_watsonx_provider.py b/tests/providers/test_watsonx_provider.py
new file mode 100644
index 00000000..4fc22555
--- /dev/null
+++ b/tests/providers/test_watsonx_provider.py
@@ -0,0 +1,63 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from ibm_watsonx_ai import Credentials
+from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
+
+from aisuite.providers.watsonx_provider import WatsonxProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("WATSONX_SERVICE_URL", "https://watsonx-service-url.com")
+    monkeypatch.setenv("WATSONX_API_KEY", "test-api-key")
+    monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id")
+
+
+def test_watsonx_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.7
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = WatsonxProvider()
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = response_text_content
+
+    with patch(
+        "aisuite.providers.watsonx_provider.ModelInference"
+    ) as mock_model_inference:
+        mock_model = MagicMock()
+        mock_model_inference.return_value = mock_model
+        mock_model.chat.return_value = mock_response
+
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        # Assert that ModelInference was called with correct arguments.
+        mock_model_inference.assert_called_once()
+        args, kwargs = mock_model_inference.call_args
+        assert kwargs["model_id"] == selected_model
+        assert kwargs["params"] == {GenParams.TEMPERATURE: chosen_temperature}
+
+        # Assert that the credentials have the correct API key and service URL.
+        credentials = kwargs["credentials"]
+        assert credentials.api_key == provider.api_key
+        assert credentials.url == provider.service_url
+
+        # Assert that chat was called with correct history and temperature.
+        mock_model.chat.assert_called_once_with(
+            prompt=message_history,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content

From e96535e6b6e22bf543f93882e2f64f4695c45890 Mon Sep 17 00:00:00 2001
From: chris-stokes
Date: Sat, 30 Nov 2024 22:02:40 +0000
Subject: [PATCH 05/40] Create missing test group, Add test coverage report

---
 .gitignore     |  6 +++
 poetry.lock    | 99 +++++++++++++++++++++++++++++++++++++++++++++++++-
 pyproject.toml | 15 +++++++-
 3 files changed, 117 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5b651c66..718eebba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,9 @@ __pycache__/
 env/
 .env
 .google-adc
+
+# Testing
+.coverage
+
+# pyenv
+.python-version
\ No newline at end of file

diff --git a/poetry.lock b/poetry.lock
index 27aa1cc1..4f1e3c45 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.

 [[package]]
 name = "aiohttp"
@@ -883,6 +883,83 @@ traitlets = ">=4"

 [package.extras]
 test = ["pytest"]

+[[package]]
+name = "coverage"
+version = "7.6.8"
+description = "Code coverage measurement for Python"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "coverage-7.6.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b39e6011cd06822eb964d038d5dff5da5d98652b81f5ecd439277b32361a3a50"},
+    {file = "coverage-7.6.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:63c19702db10ad79151a059d2d6336fe0c470f2e18d0d4d1a57f7f9713875dcf"},
+    {file = "coverage-7.6.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3985b9be361d8fb6b2d1adc9924d01dec575a1d7453a14cccd73225cb79243ee"},
+    {file = "coverage-7.6.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:644ec81edec0f4ad17d51c838a7d01e42811054543b76d4ba2c5d6af741ce2a6"},
+    {file = "coverage-7.6.8-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f188a2402f8359cf0c4b1fe89eea40dc13b52e7b4fd4812450da9fcd210181d"},
+    {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e19122296822deafce89a0c5e8685704c067ae65d45e79718c92df7b3ec3d331"},
+    {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:13618bed0c38acc418896005732e565b317aa9e98d855a0e9f211a7ffc2d6638"},
+    {file = "coverage-7.6.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:193e3bffca48ad74b8c764fb4492dd875038a2f9925530cb094db92bb5e47bed"},
+    {file = "coverage-7.6.8-cp310-cp310-win32.whl", hash = "sha256:3988665ee376abce49613701336544041f2117de7b7fbfe91b93d8ff8b151c8e"},
+    {file = "coverage-7.6.8-cp310-cp310-win_amd64.whl", hash = "sha256:f56f49b2553d7dd85fd86e029515a221e5c1f8cb3d9c38b470bc38bde7b8445a"},
+    {file = "coverage-7.6.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:86cffe9c6dfcfe22e28027069725c7f57f4b868a3f86e81d1c62462764dc46d4"},
+    {file = "coverage-7.6.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d82ab6816c3277dc962cfcdc85b1efa0e5f50fb2c449432deaf2398a2928ab94"},
+    {file = "coverage-7.6.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13690e923a3932e4fad4c0ebfb9cb5988e03d9dcb4c5150b5fcbf58fd8bddfc4"},
+    {file = "coverage-7.6.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:4be32da0c3827ac9132bb488d331cb32e8d9638dd41a0557c5569d57cf22c9c1"}, + {file = "coverage-7.6.8-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e6c85bbdc809383b509d732b06419fb4544dca29ebe18480379633623baafb"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:768939f7c4353c0fac2f7c37897e10b1414b571fd85dd9fc49e6a87e37a2e0d8"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e44961e36cb13c495806d4cac67640ac2866cb99044e210895b506c26ee63d3a"}, + {file = "coverage-7.6.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3ea8bb1ab9558374c0ab591783808511d135a833c3ca64a18ec927f20c4030f0"}, + {file = "coverage-7.6.8-cp311-cp311-win32.whl", hash = "sha256:629a1ba2115dce8bf75a5cce9f2486ae483cb89c0145795603d6554bdc83e801"}, + {file = "coverage-7.6.8-cp311-cp311-win_amd64.whl", hash = "sha256:fb9fc32399dca861584d96eccd6c980b69bbcd7c228d06fb74fe53e007aa8ef9"}, + {file = "coverage-7.6.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e683e6ecc587643f8cde8f5da6768e9d165cd31edf39ee90ed7034f9ca0eefee"}, + {file = "coverage-7.6.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1defe91d41ce1bd44b40fabf071e6a01a5aa14de4a31b986aa9dfd1b3e3e414a"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7ad66e8e50225ebf4236368cc43c37f59d5e6728f15f6e258c8639fa0dd8e6d"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fe47da3e4fda5f1abb5709c156eca207eacf8007304ce3019eb001e7a7204cb"}, + {file = "coverage-7.6.8-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:202a2d645c5a46b84992f55b0a3affe4f0ba6b4c611abec32ee88358db4bb649"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4674f0daa1823c295845b6a740d98a840d7a1c11df00d1fd62614545c1583787"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:74610105ebd6f33d7c10f8907afed696e79c59e3043c5f20eaa3a46fddf33b4c"}, + {file = "coverage-7.6.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37cda8712145917105e07aab96388ae76e787270ec04bcb9d5cc786d7cbb8443"}, + {file = "coverage-7.6.8-cp312-cp312-win32.whl", hash = "sha256:9e89d5c8509fbd6c03d0dd1972925b22f50db0792ce06324ba069f10787429ad"}, + {file = "coverage-7.6.8-cp312-cp312-win_amd64.whl", hash = "sha256:379c111d3558272a2cae3d8e57e6b6e6f4fe652905692d54bad5ea0ca37c5ad4"}, + {file = "coverage-7.6.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b0c69f4f724c64dfbfe79f5dfb503b42fe6127b8d479b2677f2b227478db2eb"}, + {file = "coverage-7.6.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c15b32a7aca8038ed7644f854bf17b663bc38e1671b5d6f43f9a2b2bd0c46f63"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63068a11171e4276f6ece913bde059e77c713b48c3a848814a6537f35afb8365"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6f4548c5ead23ad13fb7a2c8ea541357474ec13c2b736feb02e19a3085fac002"}, + {file = "coverage-7.6.8-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b4b4299dd0d2c67caaaf286d58aef5e75b125b95615dda4542561a5a566a1e3"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:c9ebfb2507751f7196995142f057d1324afdab56db1d9743aab7f50289abd022"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c1b4474beee02ede1eef86c25ad4600a424fe36cff01a6103cb4533c6bf0169e"}, + {file = "coverage-7.6.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d9fd2547e6decdbf985d579cf3fc78e4c1d662b9b0ff7cc7862baaab71c9cc5b"}, + {file = "coverage-7.6.8-cp313-cp313-win32.whl", hash = "sha256:8aae5aea53cbfe024919715eca696b1a3201886ce83790537d1c3668459c7146"}, + {file = "coverage-7.6.8-cp313-cp313-win_amd64.whl", hash = "sha256:ae270e79f7e169ccfe23284ff5ea2d52a6f401dc01b337efb54b3783e2ce3f28"}, + {file = "coverage-7.6.8-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:de38add67a0af869b0d79c525d3e4588ac1ffa92f39116dbe0ed9753f26eba7d"}, + {file = "coverage-7.6.8-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b07c25d52b1c16ce5de088046cd2432b30f9ad5e224ff17c8f496d9cb7d1d451"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62a66ff235e4c2e37ed3b6104d8b478d767ff73838d1222132a7a026aa548764"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09b9f848b28081e7b975a3626e9081574a7b9196cde26604540582da60235fdf"}, + {file = "coverage-7.6.8-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:093896e530c38c8e9c996901858ac63f3d4171268db2c9c8b373a228f459bbc5"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9a7b8ac36fd688c8361cbc7bf1cb5866977ece6e0b17c34aa0df58bda4fa18a4"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:38c51297b35b3ed91670e1e4efb702b790002e3245a28c76e627478aa3c10d83"}, + {file = "coverage-7.6.8-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2e4e0f60cb4bd7396108823548e82fdab72d4d8a65e58e2c19bbbc2f1e2bfa4b"}, + {file = "coverage-7.6.8-cp313-cp313t-win32.whl", hash = "sha256:6535d996f6537ecb298b4e287a855f37deaf64ff007162ec0afb9ab8ba3b8b71"}, + {file = "coverage-7.6.8-cp313-cp313t-win_amd64.whl", hash = "sha256:c79c0685f142ca53256722a384540832420dff4ab15fec1863d7e5bc8691bdcc"}, + {file = "coverage-7.6.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ac47fa29d8d41059ea3df65bd3ade92f97ee4910ed638e87075b8e8ce69599e"}, + {file = "coverage-7.6.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:24eda3a24a38157eee639ca9afe45eefa8d2420d49468819ac5f88b10de84f4c"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4c81ed2820b9023a9a90717020315e63b17b18c274a332e3b6437d7ff70abe0"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd55f8fc8fa494958772a2a7302b0354ab16e0b9272b3c3d83cdb5bec5bd1779"}, + {file = "coverage-7.6.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f39e2f3530ed1626c66e7493be7a8423b023ca852aacdc91fb30162c350d2a92"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:716a78a342679cd1177bc8c2fe957e0ab91405bd43a17094324845200b2fddf4"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:177f01eeaa3aee4a5ffb0d1439c5952b53d5010f86e9d2667963e632e30082cc"}, + {file = "coverage-7.6.8-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:912e95017ff51dc3d7b6e2be158dedc889d9a5cc3382445589ce554f1a34c0ea"}, + {file = 
"coverage-7.6.8-cp39-cp39-win32.whl", hash = "sha256:4db3ed6a907b555e57cc2e6f14dc3a4c2458cdad8919e40b5357ab9b6db6c43e"}, + {file = "coverage-7.6.8-cp39-cp39-win_amd64.whl", hash = "sha256:428ac484592f780e8cd7b6b14eb568f7c85460c92e2a37cb0c0e5186e1a0d076"}, + {file = "coverage-7.6.8-pp39.pp310-none-any.whl", hash = "sha256:5c52a036535d12590c32c49209e79cabaad9f9ad8aa4cbd875b68c4d67a9cbce"}, + {file = "coverage-7.6.8.tar.gz", hash = "sha256:8b2b8503edb06822c86d82fa64a4a5cb0760bb8f31f26e138ec743f422f37cfc"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + [[package]] name = "datasets" version = "2.20.0" @@ -4311,6 +4388,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-cov" +version = "6.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0"}, + {file = "pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35"}, +] + +[package.dependencies] +coverage = {version = ">=7.5", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -6492,4 +6587,4 @@ openai = ["openai"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e455ddfd148ffae77e7b9ac196792b38c7cc545c0866cca2aaae227f4a4201df" +content-hash = "aa504a39622519d5c93f85040a1c7a7708442abf7895f3a76fe842bd1d016072" diff --git a/pyproject.toml b/pyproject.toml index 8ae9295b..ce858dbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ openai = ["openai"] all = ["anthropic", "aws", "google", "groq", "mistral", "openai"] # To install all providers [tool.poetry.group.dev.dependencies] -pytest = "^8.2.2" pre-commit = "^3.7.1" black = "^24.4.2" python-dotenv = "^1.0.1" @@ -45,6 +44,20 @@ sentence-transformers = "^3.0.1" datasets = "^2.20.0" vertexai = "^1.63.0" +[tool.poetry.group.test] +optional = true + +[tool.poetry.group.test.dependencies] +pytest = "^8.2.2" +pytest-cov = "^6.0.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pytest.ini_options] +testpaths="tests" +addopts=[ + "--cov=aisuite", + "--cov-report=term-missing" +] From d678e05b0bfc256f839a6dc3f16559dfbadf6333 Mon Sep 17 00:00:00 2001 From: chris-stokes Date: Sat, 30 Nov 2024 22:03:44 +0000 Subject: [PATCH 06/40] Refactor unittest.TestCase to functional pytest --- .github/workflows/run_pytest.yml | 2 +- tests/client/test_client.py | 302 +++++++++++++------------------ 2 files changed, 131 insertions(+), 173 deletions(-) diff --git a/.github/workflows/run_pytest.yml b/.github/workflows/run_pytest.yml index 0093c348..9630faf7 100644 --- a/.github/workflows/run_pytest.yml +++ b/.github/workflows/run_pytest.yml @@ -18,7 +18,7 @@ jobs: run: | python -m pip install --upgrade pip pip install poetry - poetry install + poetry install --with test - name: Test with pytest run: poetry run pytest diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 2e1949ac..3153dab0 100644 --- 
a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,72 +1,92 @@ -import unittest -from unittest.mock import patch +from unittest.mock import Mock, patch + +import pytest + from aisuite import Client -class TestClient(unittest.TestCase): - @patch("aisuite.providers.mistral_provider.MistralProvider.chat_completions_create") - @patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create") - @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create") - @patch("aisuite.providers.aws_provider.AwsProvider.chat_completions_create") - @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create") - @patch( - "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create" - ) - @patch("aisuite.providers.google_provider.GoogleProvider.chat_completions_create") - @patch( - "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create" - ) - def test_client_chat_completions( - self, - mock_fireworks, - mock_google, - mock_anthropic, - mock_azure, - mock_bedrock, - mock_openai, - mock_groq, - mock_mistral, - ): - # Mock responses from providers - mock_openai.return_value = "OpenAI Response" - mock_bedrock.return_value = "AWS Bedrock Response" - mock_azure.return_value = "Azure Response" - mock_anthropic.return_value = "Anthropic Response" - mock_groq.return_value = "Groq Response" - mock_mistral.return_value = "Mistral Response" - mock_google.return_value = "Google Response" - mock_fireworks.return_value = "Fireworks Response" - - # Provider configurations - provider_configs = { - "openai": {"api_key": "test_openai_api_key"}, - "aws": { - "aws_access_key": "test_aws_access_key", - "aws_secret_key": "test_aws_secret_key", - "aws_session_token": "test_aws_session_token", - "aws_region": "us-west-2", - }, - "azure": { - "api_key": "azure-api-key", - "base_url": "https://model.ai.azure.com", - }, - "groq": { - "api_key": "groq-api-key", - }, - "mistral": { - "api_key": "mistral-api-key", - }, - "google": { - "project_id": "test_google_project_id", - "region": "us-west4", - "application_credentials": "test_google_application_credentials", - }, - "fireworks": { - "api_key": "fireworks-api-key", - }, - } - - # Initialize the client +@pytest.fixture(scope="module") +def provider_configs(): + return { + "openai": {"api_key": "test_openai_api_key"}, + "aws": { + "aws_access_key": "test_aws_access_key", + "aws_secret_key": "test_aws_secret_key", + "aws_session_token": "test_aws_session_token", + "aws_region": "us-west-2", + }, + "azure": { + "api_key": "azure-api-key", + "base_url": "https://model.ai.azure.com", + }, + "groq": { + "api_key": "groq-api-key", + }, + "mistral": { + "api_key": "mistral-api-key", + }, + "google": { + "project_id": "test_google_project_id", + "region": "us-west4", + "application_credentials": "test_google_application_credentials", + }, + "fireworks": { + "api_key": "fireworks-api-key", + }, + } + + +@pytest.mark.parametrize( + argnames=("patch_target", "provider", "model"), + argvalues=[ + ( + "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create", + "openai", + "gpt-4o", + ), + ( + "aisuite.providers.mistral_provider.MistralProvider.chat_completions_create", + "mistral", + "mistral-model", + ), + ( + "aisuite.providers.groq_provider.GroqProvider.chat_completions_create", + "groq", + "groq-model", + ), + ( + "aisuite.providers.aws_provider.AwsProvider.chat_completions_create", + "aws", + "claude-v3", + ), + ( + 
"aisuite.providers.azure_provider.AzureProvider.chat_completions_create", + "azure", + "azure-model", + ), + ( + "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create", + "anthropic", + "anthropic-model", + ), + ( + "aisuite.providers.google_provider.GoogleProvider.chat_completions_create", + "google", + "google-model", + ), + ( + "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create", + "fireworks", + "fireworks-model", + ), + ], +) +def test_client_chat_completions( + provider_configs: dict, patch_target: str, provider: str, model: str +): + expected_response = f"{patch_target}_{provider}_{model}" + with patch(patch_target) as mock_provider: + mock_provider.return_value = expected_response client = Client() client.configure(provider_configs) messages = [ @@ -74,115 +94,53 @@ def test_client_chat_completions( {"role": "user", "content": "Who won the world series in 2020?"}, ] - # Test OpenAI model - open_ai_model = "openai" + ":" + "gpt-4o" - openai_response = client.chat.completions.create( - open_ai_model, messages=messages - ) - self.assertEqual(openai_response, "OpenAI Response") - mock_openai.assert_called_once() - - # Test AWS Bedrock model - bedrock_model = "aws" + ":" + "claude-v3" - bedrock_response = client.chat.completions.create( - bedrock_model, messages=messages - ) - self.assertEqual(bedrock_response, "AWS Bedrock Response") - mock_bedrock.assert_called_once() - - # Test Azure model - azure_model = "azure" + ":" + "azure-model" - azure_response = client.chat.completions.create(azure_model, messages=messages) - self.assertEqual(azure_response, "Azure Response") - mock_azure.assert_called_once() - - # Test Anthropic model - anthropic_model = "anthropic" + ":" + "anthropic-model" - anthropic_response = client.chat.completions.create( - anthropic_model, messages=messages - ) - self.assertEqual(anthropic_response, "Anthropic Response") - mock_anthropic.assert_called_once() - - # Test Groq model - groq_model = "groq" + ":" + "groq-model" - groq_response = client.chat.completions.create(groq_model, messages=messages) - self.assertEqual(groq_response, "Groq Response") - mock_groq.assert_called_once() - - # Test Mistral model - mistral_model = "mistral" + ":" + "mistral-model" - mistral_response = client.chat.completions.create( - mistral_model, messages=messages - ) - self.assertEqual(mistral_response, "Mistral Response") - mock_mistral.assert_called_once() - - # Test Google model - google_model = "google" + ":" + "google-model" - google_response = client.chat.completions.create( - google_model, messages=messages - ) - self.assertEqual(google_response, "Google Response") - mock_google.assert_called_once() - - # Test Fireworks model - fireworks_model = "fireworks" + ":" + "fireworks-model" - fireworks_response = client.chat.completions.create( - fireworks_model, messages=messages - ) - self.assertEqual(fireworks_response, "Fireworks Response") - mock_fireworks.assert_called_once() - - # Test that new instances of Completion are not created each time we make an inference call. 
- compl_instance = client.chat.completions - next_compl_instance = client.chat.completions - assert compl_instance is next_compl_instance - - def test_invalid_provider_in_client_config(self): - # Testing an invalid provider name in the configuration - invalid_provider_configs = { - "invalid_provider": {"api_key": "invalid_api_key"}, - } - - # Expect ValueError when initializing Client with invalid provider - with self.assertRaises(ValueError) as context: - client = Client(invalid_provider_configs) - - # Verify the error message - self.assertIn( - "Invalid provider key 'invalid_provider'. Supported providers: ", - str(context.exception), - ) - - @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create") - def test_invalid_model_format_in_create(self, mock_openai): - # Valid provider configurations - provider_configs = { - "openai": {"api_key": "test_openai_api_key"}, - } - - # Initialize the client with valid provider - client = Client() - client.configure(provider_configs) + model_str = f"{provider}:{model}" + model_response = client.chat.completions.create(model_str, messages=messages) + assert model_response == expected_response - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a joke."}, - ] - # Invalid model format - invalid_model = "invalidmodel" +def test_invalid_provider_in_client_config(): + # Testing an invalid provider name in the configuration + invalid_provider_configs = { + "invalid_provider": {"api_key": "invalid_api_key"}, + } + + # Expect ValueError when initializing Client with invalid provider and verify message + with pytest.raises( + ValueError, + match=r"Invalid provider key 'invalid_provider'. Supported providers: ", + ): + _ = Client(invalid_provider_configs) - # Expect ValueError when calling create with invalid model format - with self.assertRaises(ValueError) as context: - client.chat.completions.create(invalid_model, messages=messages) - # Verify the error message - self.assertIn( - "Invalid model format. Expected 'provider:model'", str(context.exception) - ) +def test_invalid_model_format_in_create(monkeypatch): + from aisuite.providers.openai_provider import OpenaiProvider + + monkeypatch.setattr( + target=OpenaiProvider, + name="chat_completions_create", + value=Mock(), + ) + # Valid provider configurations + provider_configs = { + "openai": {"api_key": "test_openai_api_key"}, + } -if __name__ == "__main__": - unittest.main() + # Initialize the client with valid provider + client = Client() + client.configure(provider_configs) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + + # Invalid model format + invalid_model = "invalidmodel" + + # Expect ValueError when calling create with invalid model format and verify message + with pytest.raises( + ValueError, match=r"Invalid model format. 
Expected 'provider:model'" + ): + client.chat.completions.create(invalid_model, messages=messages) From 46816a307e52325de027a3685a39169b126aebca Mon Sep 17 00:00:00 2001 From: chris-stokes Date: Sat, 30 Nov 2024 22:06:09 +0000 Subject: [PATCH 07/40] Rebuild lockfile --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 4f1e3c45..868b37ee 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6587,4 +6587,4 @@ openai = ["openai"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "aa504a39622519d5c93f85040a1c7a7708442abf7895f3a76fe842bd1d016072" +content-hash = "c061727cb901c1aa56070c006fe2345e36e6279a5d7f2a3d0cf916e11b381d5a" From 2a859456ed67b8a51052fa24992771c693e6f4c4 Mon Sep 17 00:00:00 2001 From: Neel Patel <38160394+neel6762@users.noreply.github.com> Date: Mon, 2 Dec 2024 07:45:06 -0500 Subject: [PATCH 08/40] Update google.md (#63) Updated the description on where to find the region information. --- guides/google.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guides/google.md b/guides/google.md index eb351bd0..66231dc7 100644 --- a/guides/google.md +++ b/guides/google.md @@ -22,7 +22,7 @@ Set the `GOOGLE_PROJECT_ID` environment variable to the ID of your project. You ### Set your preferred region in an environment variable. -Set the `GOOGLE_REGION` environment variable to the ID of your project. You can find the Project ID by visiting the project dashboard in the "Project Info" section toward the top of the page. +Set the `GOOGLE_REGION` environment variable. You can find the region by going to Project Dashboard under VertexAI side navigation menu, and then scrolling to the bottom of the page. ## Create a Service Account For API Access @@ -89,4 +89,4 @@ response = client.chat.completions.create( print(response.choices[0].message.content) ``` -Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). \ No newline at end of file +Happy coding! If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). From 4da4b70b77daab0a7c87e86392239d3004b089ad Mon Sep 17 00:00:00 2001 From: Kevin Bazira Date: Mon, 2 Dec 2024 15:45:49 +0300 Subject: [PATCH 09/40] Fix dead link on PyPI (#52) The README.md file for aisuite uses a relative URL for the Contributing Guide. While this works well on GitHub, it returns a 404 error when visited on PyPI. This change updates the README.md to include an absolute URL for the Contributing Guide that works well on both GitHub and PyPI. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bd7df95e..20efd447 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ aisuite is released under the MIT License. You are free to use, modify, and dist ## Contributing -If you would like to contribute, please read our [Contributing Guide](CONTRIBUTING.md) and join our [Discord](https://discord.gg/T6Nvn8ExSb) server! +If you would like to contribute, please read our [Contributing Guide](https://github.com/andrewyng/aisuite/blob/main/CONTRIBUTING.md) and join our [Discord](https://discord.gg/T6Nvn8ExSb) server! ## Adding support for a provider We have made easy for a provider or volunteer to add support for a new platform. 
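[Editor's note: the README hunk above ends on the claim that adding support for a new platform is easy. The provider patches later in this series (SambaNova, Nebius, xAI, Cohere) all follow one pattern: subclass `Provider`, resolve the API key from config or an environment variable, and normalize the service's payload into the OpenAI-style `ChatCompletionResponse`. A minimal sketch of that pattern follows — `ExampleProvider`, `EXAMPLE_API_KEY`, and the stubbed payload are hypothetical placeholders, not part of any patch.]

```python
import os

from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse


class ExampleProvider(Provider):
    def __init__(self, **config):
        # Prefer an explicit config value, then fall back to the environment.
        self.api_key = config.get("api_key", os.getenv("EXAMPLE_API_KEY"))
        if not self.api_key:
            raise ValueError(
                "Example API key is missing. Provide it in the config or set EXAMPLE_API_KEY."
            )

    def chat_completions_create(self, model, messages, **kwargs):
        # A real provider would call its service here; this stub stands in
        # for that request so the normalization step below is runnable.
        response_data = {"choices": [{"message": {"content": "stub reply"}}]}
        return self._normalize_response(response_data)

    def _normalize_response(self, response_data):
        # Map the provider payload onto the OpenAI-style response object.
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["choices"][0][
            "message"
        ]["content"]
        return normalized_response
```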
From 6ef9108325d0256d126a92ad6d93c442902d7017 Mon Sep 17 00:00:00 2001 From: Kevin Solorio <103829+ksolo@users.noreply.github.com> Date: Tue, 3 Dec 2024 05:29:57 -0600 Subject: [PATCH 10/40] add explicit dependency for httpx httpx version 0.28.0 removed a deprecated keyword based argument `proxies` that is used in some of the client libraries. This PR pins our version of httpx to 0.27.x to prevent this error. Closes #68 Closes #110 --- poetry.lock | 11 ++++++----- pyproject.toml | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 27aa1cc1..a8854de4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1845,13 +1845,13 @@ test = ["Cython (>=0.29.24,<0.30.0)"] [[package]] name = "httpx" -version = "0.27.0" +version = "0.27.2" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, - {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, ] [package.dependencies] @@ -1866,6 +1866,7 @@ brotli = ["brotli", "brotlicffi"] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] [[package]] name = "httpx-sse" @@ -6492,4 +6493,4 @@ openai = ["openai"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "e455ddfd148ffae77e7b9ac196792b38c7cc545c0866cca2aaae227f4a4201df" +content-hash = "bcc204dd665cacb69c2fc9e1b4055ada421c73ef4ac4320044fbc0a2495e235d" diff --git a/pyproject.toml b/pyproject.toml index 8ae9295b..264f308b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ mistralai = { version = "^1.0.3", optional = true } openai = { version = "^1.35.8", optional = true } # Optional dependencies for different providers +httpx = "~0.27.0" [tool.poetry.extras] anthropic = ["anthropic"] aws = ["boto3"] From 7eecd6d4a433c8ae361e8899aa125e80548847bf Mon Sep 17 00:00:00 2001 From: Zoltan Csaki <137233191+snova-zoltanc@users.noreply.github.com> Date: Thu, 5 Dec 2024 10:09:28 -0800 Subject: [PATCH 11/40] Add SambaNova Provider (#54) * add sambanova provider * add documentation * update comments in sambanova provider for review * fix client constructor to pass in the entire config * fix linting error --- .env.sample | 3 ++ aisuite/providers/sambanova_provider.py | 30 ++++++++++++++ guides/README.md | 1 + guides/sambanova.md | 44 +++++++++++++++++++++ tests/providers/test_sambanova_provider.py | 46 ++++++++++++++++++++++ 5 files changed, 124 insertions(+) create mode 100644 aisuite/providers/sambanova_provider.py create mode 100644 guides/sambanova.md create mode 100644 tests/providers/test_sambanova_provider.py diff --git a/.env.sample b/.env.sample index 00826f4b..c753d0e2 100644 --- a/.env.sample +++ b/.env.sample @@ -25,3 +25,6 @@ FIREWORKS_API_KEY= # Together AI TOGETHER_API_KEY= + +# Sambanova +SAMBANOVA_API_KEY= diff --git 
a/aisuite/providers/sambanova_provider.py b/aisuite/providers/sambanova_provider.py new file mode 100644 index 00000000..75a97311 --- /dev/null +++ b/aisuite/providers/sambanova_provider.py @@ -0,0 +1,30 @@ +import os +from aisuite.provider import Provider +from openai import OpenAI + + +class SambanovaProvider(Provider): + def __init__(self, **config): + """ + Initialize the SambaNova provider with the given configuration. + Pass the entire configuration dictionary to the OpenAI client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("SAMBANOVA_API_KEY")) + if not config["api_key"]: + raise ValueError( + "Sambanova API key is missing. Please provide it in the config or set the SAMBANOVA_API_KEY environment variable." + ) + + config["base_url"] = "https://api.sambanova.ai/v1/" + # Pass the entire config to the OpenAI client constructor + self.client = OpenAI(**config) + + def chat_completions_create(self, model, messages, **kwargs): + # Any exception raised by Sambanova will be returned to the caller. + # Maybe we should catch them and raise a custom LLMError. + return self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the Sambanova API + ) diff --git a/guides/README.md b/guides/README.md index 3079c29c..62bf072c 100644 --- a/guides/README.md +++ b/guides/README.md @@ -9,6 +9,7 @@ Here're the instructions for: - [Google](google.md) - [Hugging Face](huggingface.md) - [OpenAI](openai.md) +- [SambaNova](sambanova.md) Unless otherwise stated, these guides have not been endorsed by the providers. diff --git a/guides/sambanova.md b/guides/sambanova.md new file mode 100644 index 00000000..6b331c2f --- /dev/null +++ b/guides/sambanova.md @@ -0,0 +1,44 @@ +# Sambanova + +To use Sambanova with `aisuite`, you’ll need a [Sambanova Cloud](https://cloud.sambanova.ai/) account. After logging in, go to the [API](https://cloud.sambanova.ai/apis) section and generate a new key. Once you have your key, add it to your environment as follows: + +```shell +export SAMBANOVA_API_KEY="your-sambanova-api-key" +``` + +## Create a Chat Completion + +Install the `openai` Python client: + +Example with pip: +```shell +pip install openai +``` + +Example with poetry: +```shell +poetry add openai +``` + +In your code: +```python +import aisuite as ai +client = ai.Client() + +provider = "sambanova" +model_id = "Meta-Llama-3.1-405B-Instruct" + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What’s the weather like in San Francisco?"}, +] + +response = client.chat.completions.create( + model=f"{provider}:{model_id}", + messages=messages, +) + +print(response.choices[0].message.content) +``` + +Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). 
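[Editor's note: the guide above reads the key from `SAMBANOVA_API_KEY`. The key can equally be passed through provider config, using the same `Client.configure` pattern exercised by the tests that follow; a short sketch, with a placeholder key value.]

```python
import aisuite as ai

client = ai.Client()
# Equivalent to exporting SAMBANOVA_API_KEY before running.
client.configure({"sambanova": {"api_key": "your-sambanova-api-key"}})

response = client.chat.completions.create(
    model="sambanova:Meta-Llama-3.1-405B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```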
diff --git a/tests/providers/test_sambanova_provider.py b/tests/providers/test_sambanova_provider.py new file mode 100644 index 00000000..b5c649ec --- /dev/null +++ b/tests/providers/test_sambanova_provider.py @@ -0,0 +1,46 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from aisuite.providers.sambanova_provider import SambanovaProvider + + +@pytest.fixture(autouse=True) +def set_api_key_env_var(monkeypatch): + """Fixture to set environment variables for tests.""" + monkeypatch.setenv("SAMBANOVA_API_KEY", "test-api-key") + + +def test_sambanova_provider(): + """High-level test that the provider is initialized and chat completions are requested successfully.""" + + user_greeting = "Hello!" + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "our-favorite-model" + chosen_temperature = 0.75 + response_text_content = "mocked-text-response-from-model" + + provider = SambanovaProvider() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.content = response_text_content + + with patch.object( + provider.client.chat.completions, + "create", + return_value=mock_response, + ) as mock_create: + response = provider.chat_completions_create( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + mock_create.assert_called_with( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + assert response.choices[0].message.content == response_text_content From d8cbbe8a80d2e18ca65772052f948dc6a2f3baf1 Mon Sep 17 00:00:00 2001 From: Akim Tsvigun Date: Fri, 6 Dec 2024 16:54:32 +0100 Subject: [PATCH 12/40] Integration with Nebius AI Studio added --- aisuite/providers/nebius_provider.py | 65 +++++++++++++++++++++++++ examples/QnA_with_pdf.ipynb | 20 +++++++- examples/client.ipynb | 13 +++++ tests/client/test_client.py | 16 ++++++ tests/providers/test_nebius_provider.py | 19 ++++++++ 5 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 aisuite/providers/nebius_provider.py create mode 100644 tests/providers/test_nebius_provider.py diff --git a/aisuite/providers/nebius_provider.py b/aisuite/providers/nebius_provider.py new file mode 100644 index 00000000..a90e8ef4 --- /dev/null +++ b/aisuite/providers/nebius_provider.py @@ -0,0 +1,65 @@ +import os +import httpx +from aisuite.provider import Provider, LLMError +from aisuite.framework import ChatCompletionResponse + + +class NebiusProvider(Provider): + """ + Nebius AI Studio Provider using httpx for direct API calls. + """ + + BASE_URL = "https://api.studio.nebius.ai/v1/chat/completions" + + def __init__(self, **config): + """ + Initialize the Nebius AI Studio provider with the given configuration. + The API key is fetched from the config or environment variables. + """ + self.api_key = config.get("api_key", os.getenv("NEBIUS_API_KEY")) + if not self.api_key: + raise ValueError( + "Nebius AI Studio API key is missing. Please provide it in the config or set the NEBIUS_API_KEY environment variable. You can get your API key at https://studio.nebius.ai/settings/api-keys" + ) + + # Optionally set a custom timeout (default to 30s) + self.timeout = config.get("timeout", 30) + + def chat_completions_create(self, model, messages, **kwargs): + """ + Makes a request to the Nebius AI Studio chat completions endpoint using httpx. 
+ """ + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + data = { + "model": model, + "messages": messages, + **kwargs, # Pass any additional arguments to the API + } + + try: + # Make the request to the Nebius AI Studio endpoint. + response = httpx.post( + self.BASE_URL, json=data, headers=headers, timeout=self.timeout + ) + response.raise_for_status() + except httpx.HTTPStatusError as http_err: + raise LLMError(f"Nebius AI Studio request failed: {http_err}") + except Exception as e: + raise LLMError(f"An error occurred: {e}") + + # Return the normalized response + return self._normalize_response(response.json()) + + def _normalize_response(self, response_data): + """ + Normalize the response to a common format (ChatCompletionResponse). + """ + normalized_response = ChatCompletionResponse() + normalized_response.choices[0].message.content = response_data["choices"][0][ + "message" + ]["content"] + return normalized_response diff --git a/examples/QnA_with_pdf.ipynb b/examples/QnA_with_pdf.ipynb index 4fbf0ba0..bfcb8b78 100644 --- a/examples/QnA_with_pdf.ipynb +++ b/examples/QnA_with_pdf.ipynb @@ -102,7 +102,6 @@ "metadata": {}, "outputs": [], "source": [ - "import aisuite as ai\n", "client = ai.Client()\n", "messages = [\n", " {\"role\": \"system\", \"content\": \"You are a helpful assistant. Answer the question only based on the below text.\"},\n", @@ -180,6 +179,25 @@ "print(response.choices[0].message.content)" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Yes, the price of organic avocados is higher than non-organic avocados. According to the text, the average price of organic avocados is generally 35-40% higher than conventional avocados.\n" + ] + } + ], + "source": [ + "nebius_model = \"nebius:meta-llama/Meta-Llama-3.1-8B-Instruct-fast\"\n", + "response = client.chat.completions.create(model=nebius_model, messages=messages, top_p=0.01)\n", + "print(response.choices[0].message.content)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/client.ipynb b/examples/client.ipynb index e99f2f50..a26f3e67 100644 --- a/examples/client.ipynb +++ b/examples/client.ipynb @@ -61,6 +61,7 @@ " 'AWS_ACCESS_KEY_ID': 'xxx',\n", " 'AWS_SECRET_ACCESS_KEY': 'xxx',\n", " 'ANTHROPIC_API_KEY': 'xxx',\n", + " 'NEBIUS_API_KEY': 'xxx',\n", "}\n", "\n", "# Configure environment\n", @@ -208,6 +209,18 @@ "print(response.choices[0].message.content)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f38d033a-a580-4239-9176-27f3d53e7fe1", + "metadata": {}, + "outputs": [], + "source": [ + "nebius_model = \"nebius:meta-llama/Meta-Llama-3.1-8B-Instruct\"\n", + "response = client.chat.completions.create(model=nebius_model, messages=messages, top_p=0.01)\n", + "print(response.choices[0].message.content)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 2e1949ac..9f76ee51 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -16,8 +16,12 @@ class TestClient(unittest.TestCase): @patch( "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create" ) + @patch( + "aisuite.providers.nebius_provider.NebiusProvider.chat_completions_create" + ) def test_client_chat_completions( self, + mock_nebius, mock_fireworks, mock_google, mock_anthropic, @@ -36,6 +40,7 @@ def test_client_chat_completions( 
mock_mistral.return_value = "Mistral Response" mock_google.return_value = "Google Response" mock_fireworks.return_value = "Fireworks Response" + mock_nebius.return_value = "Nebius Response" # Provider configurations provider_configs = { @@ -64,6 +69,9 @@ def test_client_chat_completions( "fireworks": { "api_key": "fireworks-api-key", }, + "nebius": { + "api_key": "nebius-api-key", + }, } # Initialize the client @@ -134,6 +142,14 @@ def test_client_chat_completions( self.assertEqual(fireworks_response, "Fireworks Response") mock_fireworks.assert_called_once() + # Test Nebius model + nebius_model = "nebius" + ":" + "nebius-model" + nebius_response = client.chat.completions.create( + nebius_model, messages=messages + ) + self.assertEqual(nebius_response, "Nebius Response") + mock_nebius.assert_called_once() + # Test that new instances of Completion are not created each time we make an inference call. compl_instance = client.chat.completions next_compl_instance = client.chat.completions diff --git a/tests/providers/test_nebius_provider.py b/tests/providers/test_nebius_provider.py new file mode 100644 index 00000000..b581439d --- /dev/null +++ b/tests/providers/test_nebius_provider.py @@ -0,0 +1,19 @@ +import pytest +from unittest.mock import patch, MagicMock + +from aisuite.providers.nebius_provider import NebiusProvider + +def test_nebius_provider(): + """High-level test that the provider is initialized and chat completions are requested successfully.""" + + user_greeting = "We are testing you. Please say 'One two three' and nothing more." + message_history = [{"role": "user", "content": user_greeting}] + selected_model = "Qwen/Qwen2.5-32B-Instruct-fast" + chosen_top_p = 0.01 + response_text_content = "One two three" + + provider = NebiusProvider() + print(provider.api_key) + response = provider.chat_completions_create(model=selected_model, messages=message_history, top_p=chosen_top_p) + + assert response.choices[0].message.content == response_text_content From 7f5d9bfb37594625db497b1999a0581a1ec1a4a2 Mon Sep 17 00:00:00 2001 From: Akim Tsvigun Date: Fri, 6 Dec 2024 16:59:16 +0100 Subject: [PATCH 13/40] test client conflict resolved --- tests/client/test_client.py | 326 +++++++++++++++--------------------- 1 file changed, 138 insertions(+), 188 deletions(-) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 9f76ee51..ed8551c2 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,80 +1,100 @@ -import unittest -from unittest.mock import patch +from unittest.mock import Mock, patch + +import pytest + from aisuite import Client -class TestClient(unittest.TestCase): - @patch("aisuite.providers.mistral_provider.MistralProvider.chat_completions_create") - @patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create") - @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create") - @patch("aisuite.providers.aws_provider.AwsProvider.chat_completions_create") - @patch("aisuite.providers.azure_provider.AzureProvider.chat_completions_create") - @patch( - "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create" - ) - @patch("aisuite.providers.google_provider.GoogleProvider.chat_completions_create") - @patch( - "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create" - ) - @patch( - "aisuite.providers.nebius_provider.NebiusProvider.chat_completions_create" - ) - def test_client_chat_completions( - self, - mock_nebius, - mock_fireworks, - mock_google, - mock_anthropic, - 
mock_azure, - mock_bedrock, - mock_openai, - mock_groq, - mock_mistral, - ): - # Mock responses from providers - mock_openai.return_value = "OpenAI Response" - mock_bedrock.return_value = "AWS Bedrock Response" - mock_azure.return_value = "Azure Response" - mock_anthropic.return_value = "Anthropic Response" - mock_groq.return_value = "Groq Response" - mock_mistral.return_value = "Mistral Response" - mock_google.return_value = "Google Response" - mock_fireworks.return_value = "Fireworks Response" - mock_nebius.return_value = "Nebius Response" - - # Provider configurations - provider_configs = { - "openai": {"api_key": "test_openai_api_key"}, - "aws": { - "aws_access_key": "test_aws_access_key", - "aws_secret_key": "test_aws_secret_key", - "aws_session_token": "test_aws_session_token", - "aws_region": "us-west-2", - }, - "azure": { - "api_key": "azure-api-key", - "base_url": "https://model.ai.azure.com", - }, - "groq": { - "api_key": "groq-api-key", - }, - "mistral": { - "api_key": "mistral-api-key", - }, - "google": { - "project_id": "test_google_project_id", - "region": "us-west4", - "application_credentials": "test_google_application_credentials", - }, - "fireworks": { - "api_key": "fireworks-api-key", - }, - "nebius": { - "api_key": "nebius-api-key", - }, - } - - # Initialize the client +@pytest.fixture(scope="module") +def provider_configs(): + return { + "openai": {"api_key": "test_openai_api_key"}, + "aws": { + "aws_access_key": "test_aws_access_key", + "aws_secret_key": "test_aws_secret_key", + "aws_session_token": "test_aws_session_token", + "aws_region": "us-west-2", + }, + "azure": { + "api_key": "azure-api-key", + "base_url": "https://model.ai.azure.com", + }, + "groq": { + "api_key": "groq-api-key", + }, + "mistral": { + "api_key": "mistral-api-key", + }, + "google": { + "project_id": "test_google_project_id", + "region": "us-west4", + "application_credentials": "test_google_application_credentials", + }, + "fireworks": { + "api_key": "fireworks-api-key", + }, + "nebius": { + "api_key": "nebius-api-key", + }, + } + + +@pytest.mark.parametrize( + argnames=("patch_target", "provider", "model"), + argvalues=[ + ( + "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create", + "openai", + "gpt-4o", + ), + ( + "aisuite.providers.mistral_provider.MistralProvider.chat_completions_create", + "mistral", + "mistral-model", + ), + ( + "aisuite.providers.groq_provider.GroqProvider.chat_completions_create", + "groq", + "groq-model", + ), + ( + "aisuite.providers.aws_provider.AwsProvider.chat_completions_create", + "aws", + "claude-v3", + ), + ( + "aisuite.providers.azure_provider.AzureProvider.chat_completions_create", + "azure", + "azure-model", + ), + ( + "aisuite.providers.anthropic_provider.AnthropicProvider.chat_completions_create", + "anthropic", + "anthropic-model", + ), + ( + "aisuite.providers.google_provider.GoogleProvider.chat_completions_create", + "google", + "google-model", + ), + ( + "aisuite.providers.fireworks_provider.FireworksProvider.chat_completions_create", + "fireworks", + "fireworks-model", + ), + ( + "aisuite.providers.nebius_provider.NebiusProvider.chat_completions_create", + "nebius", + "nebius-model", + ), + ], +) +def test_client_chat_completions( + provider_configs: dict, patch_target: str, provider: str, model: str +): + expected_response = f"{patch_target}_{provider}_{model}" + with patch(patch_target) as mock_provider: + mock_provider.return_value = expected_response client = Client() client.configure(provider_configs) messages = [ @@ -82,123 
+102,53 @@ def test_client_chat_completions( {"role": "user", "content": "Who won the world series in 2020?"}, ] - # Test OpenAI model - open_ai_model = "openai" + ":" + "gpt-4o" - openai_response = client.chat.completions.create( - open_ai_model, messages=messages - ) - self.assertEqual(openai_response, "OpenAI Response") - mock_openai.assert_called_once() - - # Test AWS Bedrock model - bedrock_model = "aws" + ":" + "claude-v3" - bedrock_response = client.chat.completions.create( - bedrock_model, messages=messages - ) - self.assertEqual(bedrock_response, "AWS Bedrock Response") - mock_bedrock.assert_called_once() - - # Test Azure model - azure_model = "azure" + ":" + "azure-model" - azure_response = client.chat.completions.create(azure_model, messages=messages) - self.assertEqual(azure_response, "Azure Response") - mock_azure.assert_called_once() - - # Test Anthropic model - anthropic_model = "anthropic" + ":" + "anthropic-model" - anthropic_response = client.chat.completions.create( - anthropic_model, messages=messages - ) - self.assertEqual(anthropic_response, "Anthropic Response") - mock_anthropic.assert_called_once() - - # Test Groq model - groq_model = "groq" + ":" + "groq-model" - groq_response = client.chat.completions.create(groq_model, messages=messages) - self.assertEqual(groq_response, "Groq Response") - mock_groq.assert_called_once() - - # Test Mistral model - mistral_model = "mistral" + ":" + "mistral-model" - mistral_response = client.chat.completions.create( - mistral_model, messages=messages - ) - self.assertEqual(mistral_response, "Mistral Response") - mock_mistral.assert_called_once() - - # Test Google model - google_model = "google" + ":" + "google-model" - google_response = client.chat.completions.create( - google_model, messages=messages - ) - self.assertEqual(google_response, "Google Response") - mock_google.assert_called_once() - - # Test Fireworks model - fireworks_model = "fireworks" + ":" + "fireworks-model" - fireworks_response = client.chat.completions.create( - fireworks_model, messages=messages - ) - self.assertEqual(fireworks_response, "Fireworks Response") - mock_fireworks.assert_called_once() - - # Test Nebius model - nebius_model = "nebius" + ":" + "nebius-model" - nebius_response = client.chat.completions.create( - nebius_model, messages=messages - ) - self.assertEqual(nebius_response, "Nebius Response") - mock_nebius.assert_called_once() - - # Test that new instances of Completion are not created each time we make an inference call. - compl_instance = client.chat.completions - next_compl_instance = client.chat.completions - assert compl_instance is next_compl_instance - - def test_invalid_provider_in_client_config(self): - # Testing an invalid provider name in the configuration - invalid_provider_configs = { - "invalid_provider": {"api_key": "invalid_api_key"}, - } - - # Expect ValueError when initializing Client with invalid provider - with self.assertRaises(ValueError) as context: - client = Client(invalid_provider_configs) - - # Verify the error message - self.assertIn( - "Invalid provider key 'invalid_provider'. 
Supported providers: ", - str(context.exception), - ) - - @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create") - def test_invalid_model_format_in_create(self, mock_openai): - # Valid provider configurations - provider_configs = { - "openai": {"api_key": "test_openai_api_key"}, - } - - # Initialize the client with valid provider - client = Client() - client.configure(provider_configs) + model_str = f"{provider}:{model}" + model_response = client.chat.completions.create(model_str, messages=messages) + assert model_response == expected_response - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a joke."}, - ] - # Invalid model format - invalid_model = "invalidmodel" +def test_invalid_provider_in_client_config(): + # Testing an invalid provider name in the configuration + invalid_provider_configs = { + "invalid_provider": {"api_key": "invalid_api_key"}, + } + + # Expect ValueError when initializing Client with invalid provider and verify message + with pytest.raises( + ValueError, + match=r"Invalid provider key 'invalid_provider'. Supported providers: ", + ): + _ = Client(invalid_provider_configs) - # Expect ValueError when calling create with invalid model format - with self.assertRaises(ValueError) as context: - client.chat.completions.create(invalid_model, messages=messages) - # Verify the error message - self.assertIn( - "Invalid model format. Expected 'provider:model'", str(context.exception) - ) +def test_invalid_model_format_in_create(monkeypatch): + from aisuite.providers.openai_provider import OpenaiProvider + + monkeypatch.setattr( + target=OpenaiProvider, + name="chat_completions_create", + value=Mock(), + ) + # Valid provider configurations + provider_configs = { + "openai": {"api_key": "test_openai_api_key"}, + } -if __name__ == "__main__": - unittest.main() + # Initialize the client with valid provider + client = Client() + client.configure(provider_configs) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + + # Invalid model format + invalid_model = "invalidmodel" + + # Expect ValueError when calling create with invalid model format and verify message + with pytest.raises( + ValueError, match=r"Invalid model format. Expected 'provider:model'" + ): + client.chat.completions.create(invalid_model, messages=messages) \ No newline at end of file From d85357b90ed6f6fffb422197bc3b83a9fcc75d8e Mon Sep 17 00:00:00 2001 From: Zoltan Csaki Date: Fri, 6 Dec 2024 13:26:19 -0800 Subject: [PATCH 14/40] add sambanova as valid provider in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bd7df95e..8784fb6b 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Simple, unified interface to multiple Generative AI providers. `aisuite` makes it easy for developers to use multiple LLM through a standardized interface. Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focussed on chat completions. We will expand it cover more use cases in near future. Currently supported providers are - -OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace and Ollama. 
+OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace, Ollama and Sambanova.
 To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider.
 
 ## Installation

From ad28abeb48703d969ac5a99f7318f5ba81123031 Mon Sep 17 00:00:00 2001
From: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com>
Date: Sat, 7 Dec 2024 07:46:13 +0530
Subject: [PATCH 15/40] Add xAI support (#53)

* add: xAI API support

* Create xai.md

* Update Guides README.md
---
 .env.sample                       |  3 ++
 aisuite/providers/xai_provider.py | 65 +++++++++++++++++++++++++++++++
 guides/README.md                  |  1 +
 guides/xai.md                     | 33 ++++++++++++++++
 4 files changed, 102 insertions(+)
 create mode 100644 aisuite/providers/xai_provider.py
 create mode 100644 guides/xai.md

diff --git a/.env.sample b/.env.sample
index c753d0e2..a195a847 100644
--- a/.env.sample
+++ b/.env.sample
@@ -26,5 +26,8 @@ FIREWORKS_API_KEY=
 # Together AI
 TOGETHER_API_KEY=
 
+# xAI
+XAI_API_KEY=
+
 # Sambanova
 SAMBANOVA_API_KEY=

diff --git a/aisuite/providers/xai_provider.py b/aisuite/providers/xai_provider.py
new file mode 100644
index 00000000..53e8d831
--- /dev/null
+++ b/aisuite/providers/xai_provider.py
@@ -0,0 +1,65 @@
+import os
+import httpx
+from aisuite.provider import Provider, LLMError
+from aisuite.framework import ChatCompletionResponse
+
+
+class XaiProvider(Provider):
+    """
+    xAI Provider using httpx for direct API calls.
+    """
+
+    BASE_URL = "https://api.x.ai/v1/chat/completions"
+
+    def __init__(self, **config):
+        """
+        Initialize the xAI provider with the given configuration.
+        The API key is fetched from the config or environment variables.
+        """
+        self.api_key = config.get("api_key", os.getenv("XAI_API_KEY"))
+        if not self.api_key:
+            raise ValueError(
+                "xAI API key is missing. Please provide it in the config or set the XAI_API_KEY environment variable."
+            )
+
+        # Optionally set a custom timeout (default to 30s)
+        self.timeout = config.get("timeout", 30)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        """
+        Makes a request to the xAI chat completions endpoint using httpx.
+        """
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "model": model,
+            "messages": messages,
+            **kwargs,  # Pass any additional arguments to the API
+        }
+
+        try:
+            # Make the request to the xAI endpoint.
+            response = httpx.post(
+                self.BASE_URL, json=data, headers=headers, timeout=self.timeout
+            )
+            response.raise_for_status()
+        except httpx.HTTPStatusError as http_err:
+            raise LLMError(f"xAI request failed: {http_err}")
+        except Exception as e:
+            raise LLMError(f"An error occurred: {e}")
+
+        # Return the normalized response
+        return self._normalize_response(response.json())
+
+    def _normalize_response(self, response_data):
+        """
+        Normalize the response to a common format (ChatCompletionResponse).
+        """
+        normalized_response = ChatCompletionResponse()
+        normalized_response.choices[0].message.content = response_data["choices"][0][
+            "message"
+        ]["content"]
+        return normalized_response

diff --git a/guides/README.md b/guides/README.md
index 62bf072c..89f4bd2c 100644
--- a/guides/README.md
+++ b/guides/README.md
@@ -10,6 +10,7 @@ Here're the instructions for:
 - [Hugging Face](huggingface.md)
 - [OpenAI](openai.md)
 - [SambaNova](sambanova.md)
+- [xAI](xai.md)
 
 Unless otherwise stated, these guides have not been endorsed by the providers.
diff --git a/guides/xai.md b/guides/xai.md new file mode 100644 index 00000000..7129dd99 --- /dev/null +++ b/guides/xai.md @@ -0,0 +1,33 @@ +# xAI + +To use xAI with `aisuite`, you’ll need an [API key](https://console.x.ai/). Generate a new key and once you have your key, add it to your environment as follows: + +```shell +export XAI_API_KEY="your-xai-api-key" +``` + +## Create a Chat Completion + +Sample code: +```python +import aisuite as ai +client = ai.Client() + +models = ["xai:grok-beta"] + +messages = [ + {"role": "system", "content": "Respond in Pirate English."}, + {"role": "user", "content": "Tell me a joke."}, +] + +for model in models: + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=0.75 + ) + print(response.choices[0].message.content) + +``` + +Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). From 8c358d61cbf8156933025bb787ec2359b22f9f1b Mon Sep 17 00:00:00 2001 From: Lucain Date: Sat, 7 Dec 2024 03:37:12 +0100 Subject: [PATCH 16/40] Support `HF_TOKEN` environment variable in huggingface_provider.py (#59) * Support `HF_TOKEN` environment variable in huggingface_provider.py * Update huggingface.md with new env variable for token. * Update .env.sample for HF token env var. --- .env.sample | 2 +- aisuite/providers/huggingface_provider.py | 4 ++-- guides/huggingface.md | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.env.sample b/.env.sample index a195a847..ccc308f3 100644 --- a/.env.sample +++ b/.env.sample @@ -18,7 +18,7 @@ GOOGLE_REGION= GOOGLE_PROJECT_ID= # Hugging Face token -HUGGINGFACE_TOKEN= +HF_TOKEN= # Fireworks FIREWORKS_API_KEY= diff --git a/aisuite/providers/huggingface_provider.py b/aisuite/providers/huggingface_provider.py index 5a1bb902..de989f7d 100644 --- a/aisuite/providers/huggingface_provider.py +++ b/aisuite/providers/huggingface_provider.py @@ -19,10 +19,10 @@ def __init__(self, **config): The token is fetched from the config or environment variables. """ # Ensure API key is provided either in config or via environment variable - self.token = config.get("token") or os.getenv("HUGGINGFACE_TOKEN") + self.token = config.get("token") or os.getenv("HF_TOKEN") if not self.token: raise ValueError( - "Hugging Face token is missing. Please provide it in the config or set the HUGGINGFACE_TOKEN environment variable." + "Hugging Face token is missing. Please provide it in the config or set the HF_TOKEN environment variable." 
             )
 
         # Optionally set a custom timeout (default to 30s)

diff --git a/guides/huggingface.md b/guides/huggingface.md
index 11bd9297..840c3925 100644
--- a/guides/huggingface.md
+++ b/guides/huggingface.md
@@ -18,7 +18,7 @@ After setting up your model, you'll need to gather the following information:
 Set the following environment variables to make authentication and requests easy:
 
 ```shell
-export HUGGINGFACE_TOKEN="your-api-token"
+export HF_TOKEN="your-api-token"
 ```
 
 ## Create a Chat Completion

From d7335a97facec22c2751f82ebcbf858876d47936 Mon Sep 17 00:00:00 2001
From: Aditya Rana <42575044+ranaaditya@users.noreply.github.com>
Date: Sun, 8 Dec 2024 00:52:31 +0530
Subject: [PATCH 17/40] Fix a typo in message.py (#60)

---
 aisuite/framework/message.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aisuite/framework/message.py b/aisuite/framework/message.py
index 5aa7f822..bf07688d 100644
--- a/aisuite/framework/message.py
+++ b/aisuite/framework/message.py
@@ -1,4 +1,4 @@
-"""Interface to hold contents of api responses when they do not conform to the OpenAI style response"""
+"""Interface to hold contents of api responses when they do not confirm to the OpenAI style response"""
 
 
 class Message:

From a9fb16bc1623620fdbd342bef184abc336c5dd17 Mon Sep 17 00:00:00 2001
From: Hatice Ozen <139392640+hozen-groq@users.noreply.github.com>
Date: Sat, 7 Dec 2024 14:51:10 -0500
Subject: [PATCH 18/40] Add groq guide (#74)

---
 guides/groq.md | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 guides/groq.md

diff --git a/guides/groq.md b/guides/groq.md
new file mode 100644
index 00000000..96b50f1e
--- /dev/null
+++ b/guides/groq.md
@@ -0,0 +1,39 @@
+# Groq
+
+To use Groq with `aisuite`, you’ll need a free [Groq account](https://console.groq.com/). After logging in, go to the [API Keys](https://console.groq.com/keys) section in your account settings and generate a new Groq API key. Once you have your key, add it to your environment as follows:
+
+```shell
+export GROQ_API_KEY="your-groq-api-key"
+```
+
+## Create a Python Chat Completion
+
+1. First, install the `groq` Python client library:
+
+```shell
+pip install groq
+```
+
+2. Now you can simply create your first chat completion with the following example code or customize by swapping out the `model_id` with any of the other available [models powered by Groq](https://console.groq.com/docs/models) and the `messages` array with whatever you'd like:
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "groq"
+model_id = "llama-3.2-3b-preview"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "What’s the weather like in San Francisco?"},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).

From 148a21ec0226ffc125d131f828af4e4b46b33321 Mon Sep 17 00:00:00 2001
From: Ikko Eltociear Ashimine
Date: Sun, 8 Dec 2024 05:05:31 +0900
Subject: [PATCH 19/40] Update in client.ipynb.
Fix typos (#75) --- examples/client.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/client.ipynb b/examples/client.ipynb index e99f2f50..54b43223 100644 --- a/examples/client.ipynb +++ b/examples/client.ipynb @@ -122,7 +122,7 @@ "source": [ "# IMP NOTE: Azure expects model endpoint to be passed in the format of \"azure:\".\n", "# The model name is the deployment name in Project/Deployments.\n", - "# In the exmaple below, the model is \"mistral-large-2407\", but the name given to the\n", + "# In the example below, the model is \"mistral-large-2407\", but the name given to the\n", "# deployment is \"aisuite-mistral-large-2407\" under the deployments section in Azure.\n", "client.configure({\"azure\" : {\n", " \"api_key\": os.environ[\"AZURE_API_KEY\"],\n", @@ -142,7 +142,7 @@ "source": [ "# HuggingFace expects the model to be passed in the format of \"huggingface:\".\n", "# The model name is the full name of the model in HuggingFace.\n", - "# In the exmaple below, the model is \"mistralai/Mistral-7B-Instruct-v0.3\".\n", + "# In the example below, the model is \"mistralai/Mistral-7B-Instruct-v0.3\".\n", "# The model is deployed as serverless inference endpoint in HuggingFace.\n", "hf_model = \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\"\n", "response = client.chat.completions.create(model=hf_model, messages=messages)\n", @@ -159,7 +159,7 @@ "\n", "# Groq expects the model to be passed in the format of \"groq:\".\n", "# The model name is the full name of the model in Groq.\n", - "# In the exmaple below, the model is \"llama3-8b-8192\".\n", + "# In the example below, the model is \"llama3-8b-8192\".\n", "groq_llama3_8b = \"groq:llama3-8b-8192\"\n", "# groq_llama3_70b = \"groq:llama3-70b-8192\"\n", "response = client.chat.completions.create(model=groq_llama3_8b, messages=messages)\n", From f24d9e558741c39668d3d4feb9333de54d0bda01 Mon Sep 17 00:00:00 2001 From: Akim Tsvigun Date: Tue, 10 Dec 2024 10:34:42 +0100 Subject: [PATCH 20/40] Nebius client standardized with OpenAI client --- aisuite/providers/nebius_provider.py | 66 ++++++------------------- examples/client.ipynb | 4 +- tests/providers/test_nebius_provider.py | 40 ++++++++++++--- 3 files changed, 51 insertions(+), 59 deletions(-) diff --git a/aisuite/providers/nebius_provider.py b/aisuite/providers/nebius_provider.py index a90e8ef4..c558a9ce 100644 --- a/aisuite/providers/nebius_provider.py +++ b/aisuite/providers/nebius_provider.py @@ -1,65 +1,31 @@ import os -import httpx -from aisuite.provider import Provider, LLMError -from aisuite.framework import ChatCompletionResponse +from aisuite.provider import Provider +from openai import Client -class NebiusProvider(Provider): - """ - Nebius AI Studio Provider using httpx for direct API calls. - """ +BASE_URL = "https://api.studio.nebius.ai/v1" - BASE_URL = "https://api.studio.nebius.ai/v1/chat/completions" +class NebiusProvider(Provider): def __init__(self, **config): """ Initialize the Nebius AI Studio provider with the given configuration. - The API key is fetched from the config or environment variables. + Pass the entire configuration dictionary to the OpenAI client constructor. """ - self.api_key = config.get("api_key", os.getenv("NEBIUS_API_KEY")) - if not self.api_key: + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("NEBIUS_API_KEY")) + if not config["api_key"]: raise ValueError( "Nebius AI Studio API key is missing. 
Please provide it in the config or set the NEBIUS_API_KEY environment variable. You can get your API key at https://studio.nebius.ai/settings/api-keys" ) - # Optionally set a custom timeout (default to 30s) - self.timeout = config.get("timeout", 30) + config["base_url"] = BASE_URL + # Pass the entire config to the OpenAI client constructor + self.client = Client(**config) def chat_completions_create(self, model, messages, **kwargs): - """ - Makes a request to the Nebius AI Studio chat completions endpoint using httpx. - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - } - - data = { - "model": model, - "messages": messages, - **kwargs, # Pass any additional arguments to the API - } - - try: - # Make the request to the Nebius AI Studio endpoint. - response = httpx.post( - self.BASE_URL, json=data, headers=headers, timeout=self.timeout - ) - response.raise_for_status() - except httpx.HTTPStatusError as http_err: - raise LLMError(f"Nebius AI Studio request failed: {http_err}") - except Exception as e: - raise LLMError(f"An error occurred: {e}") - - # Return the normalized response - return self._normalize_response(response.json()) - - def _normalize_response(self, response_data): - """ - Normalize the response to a common format (ChatCompletionResponse). - """ - normalized_response = ChatCompletionResponse() - normalized_response.choices[0].message.content = response_data["choices"][0][ - "message" - ]["content"] - return normalized_response + return self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the Nebius API + ) diff --git a/examples/client.ipynb b/examples/client.ipynb index a26f3e67..2ad93789 100644 --- a/examples/client.ipynb +++ b/examples/client.ipynb @@ -216,7 +216,7 @@ "metadata": {}, "outputs": [], "source": [ - "nebius_model = \"nebius:meta-llama/Meta-Llama-3.1-8B-Instruct\"\n", + "nebius_model = \"nebius:Qwen/Qwen2.5-1.5B-Instruct\"\n", "response = client.chat.completions.create(model=nebius_model, messages=messages, top_p=0.01)\n", "print(response.choices[0].message.content)" ] @@ -267,4 +267,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/tests/providers/test_nebius_provider.py b/tests/providers/test_nebius_provider.py index b581439d..8e969ea5 100644 --- a/tests/providers/test_nebius_provider.py +++ b/tests/providers/test_nebius_provider.py @@ -3,17 +3,43 @@ from aisuite.providers.nebius_provider import NebiusProvider + +@pytest.fixture(autouse=True) +def set_api_key_env_var(monkeypatch): + """Fixture to set environment variables for tests.""" + monkeypatch.setenv("NEBIUS_API_KEY", "test-api-key") + + def test_nebius_provider(): """High-level test that the provider is initialized and chat completions are requested successfully.""" - user_greeting = "We are testing you. Please say 'One two three' and nothing more." + user_greeting = "Hello!" 
message_history = [{"role": "user", "content": user_greeting}] - selected_model = "Qwen/Qwen2.5-32B-Instruct-fast" - chosen_top_p = 0.01 - response_text_content = "One two three" + selected_model = "our-favorite-model" + chosen_temperature = 0.75 + response_text_content = "mocked-text-response-from-model" provider = NebiusProvider() - print(provider.api_key) - response = provider.chat_completions_create(model=selected_model, messages=message_history, top_p=chosen_top_p) + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.content = response_text_content + + with patch.object( + provider.client.chat.completions, + "create", + return_value=mock_response, + ) as mock_create: + response = provider.chat_completions_create( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) + + mock_create.assert_called_with( + messages=message_history, + model=selected_model, + temperature=chosen_temperature, + ) - assert response.choices[0].message.content == response_text_content + assert response.choices[0].message.content == response_text_content From 32f93420f37bca6b6aaf348436d5e790b8db86c7 Mon Sep 17 00:00:00 2001 From: Dax Date: Tue, 10 Dec 2024 16:04:36 -0500 Subject: [PATCH 21/40] Normalize response, update tests and readme --- aisuite/providers/watsonx_provider.py | 22 ++++++++++++++-------- guides/watsonx.md | 18 +++++++++--------- tests/providers/test_watsonx_provider.py | 16 ++++++---------- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/aisuite/providers/watsonx_provider.py b/aisuite/providers/watsonx_provider.py index 0900479b..5a4be042 100644 --- a/aisuite/providers/watsonx_provider.py +++ b/aisuite/providers/watsonx_provider.py @@ -2,9 +2,7 @@ import os from ibm_watsonx_ai import Credentials from ibm_watsonx_ai.foundation_models import ModelInference -from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams - -DEFAULT_TEMPERATURE = 0.7 +from aisuite.framework import ChatCompletionResponse class WatsonxProvider(Provider): @@ -23,11 +21,19 @@ def __init__(self, **config): def chat_completions_create(self, model, messages, **kwargs): model = ModelInference( model_id=model, - params={ - GenParams.TEMPERATURE: kwargs.get("temperature", DEFAULT_TEMPERATURE), - }, - credentials=Credentials(api_key=self.api_key, url=self.service_url), + credentials=Credentials( + api_key=self.api_key, + url=self.service_url, + ), project_id=self.project_id, ) - return model.chat(prompt=messages, **kwargs) + res = model.chat(messages=messages, params=kwargs) + return self.normalize_response(res) + + def normalize_response(self, response): + openai_response = ChatCompletionResponse() + openai_response.choices[0].message.content = response["choices"][0]["message"][ + "content" + ] + return openai_response diff --git a/guides/watsonx.md b/guides/watsonx.md index 3353666d..c2d4121b 100644 --- a/guides/watsonx.md +++ b/guides/watsonx.md @@ -8,25 +8,25 @@ A a step-by-step guide to set up Watsonx with the `aisuite` library, enabling yo 1. Visit [IBM Watsonx](https://www.ibm.com/watsonx). 2. Sign up for a new account or log in with your existing IBM credentials. -3. Once logged in, navigate to the **Watsonx Dashboard**. +3. Once logged in, navigate to the **Watsonx Dashboard** () --- ### Step 2: Obtain API Credentials 1. **Generate an API Key**: - - Go to the **API Keys** section in your Watsonx account settings. - - Click on **Create API Key**. 
- - Provide a name for your API key (e.g., `MyWatsonxKey`). - - Click **Generate**, then download or copy the API key. **Keep this key secure!** + - Go to IAM > API keys and create a new API key () + - Copy the API key. This is your `WATSONX_API_KEY`. 2. **Locate the Service URL**: - - Go to the **Endpoints** section in the Watsonx dashboard. - - Find the URL corresponding to your service and note it. This is your `WATSONX_SERVICE_URL`. + - Your service URL is based on the region where your service is hosted. + - Pick one from the list here + - Copy the service URL. This is your `WATSONX_SERVICE_URL`. 3. **Get the Project ID**: - - Navigate to the **Projects** tab in the dashboard. - - Select the project you want to use. + - Go to the **Watsonx Dashboard** () + - Under the **Projects** section, If you don't have a sandbox project, create a new project. + - Navigate to the **Manage** tab and find the **Project ID**. - Copy the **Project ID**. This will serve as your `WATSONX_PROJECT_ID`. --- diff --git a/tests/providers/test_watsonx_provider.py b/tests/providers/test_watsonx_provider.py index 4fc22555..8e7123a7 100644 --- a/tests/providers/test_watsonx_provider.py +++ b/tests/providers/test_watsonx_provider.py @@ -1,8 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from ibm_watsonx_ai import Credentials -from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams +from ibm_watsonx_ai.metanames import GenChatParamsMetaNames as GenChatParams from aisuite.providers.watsonx_provider import WatsonxProvider @@ -25,10 +24,7 @@ def test_watsonx_provider(): response_text_content = "mocked-text-response-from-model" provider = WatsonxProvider() - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message = MagicMock() - mock_response.choices[0].message.content = response_text_content + mock_response = {"choices": [{"message": {"content": response_text_content}}]} with patch( "aisuite.providers.watsonx_provider.ModelInference" @@ -47,17 +43,17 @@ def test_watsonx_provider(): mock_model_inference.assert_called_once() args, kwargs = mock_model_inference.call_args assert kwargs["model_id"] == selected_model - assert kwargs["params"] == {GenParams.TEMPERATURE: chosen_temperature} + assert kwargs["project_id"] == provider.project_id # Assert that the credentials have the correct API key and service URL. credentials = kwargs["credentials"] assert credentials.api_key == provider.api_key assert credentials.url == provider.service_url - # Assert that chat was called with correct history and temperature. 
+ # Assert that chat was called with correct history and params mock_model.chat.assert_called_once_with( - prompt=message_history, - temperature=chosen_temperature, + messages=message_history, + params={GenChatParams.TEMPERATURE: chosen_temperature}, ) assert response.choices[0].message.content == response_text_content From 763996153e3c6068465bd73a367a20f6ca79cefc Mon Sep 17 00:00:00 2001 From: Rohit Prasad Date: Tue, 10 Dec 2024 22:09:54 -0800 Subject: [PATCH 22/40] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index de006365..b6fbfe09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "aisuite" version = "0.1.6" description = "Uniform access layer for LLMs" -authors = ["Andrew Ng"] +authors = ["Andrew Ng, Rohit P"] readme = "README.md" [tool.poetry.dependencies] From 74f1c0417c040cf1ab9e9dd2c3e45f5c7be004aa Mon Sep 17 00:00:00 2001 From: Bilal Date: Sat, 30 Nov 2024 02:31:06 +0200 Subject: [PATCH 23/40] Add Cohere provider --- .gitignore | 5 +++- aisuite/providers/cohere_provider.py | 25 +++++++++++++++++ guides/README.md | 3 +- guides/cohere.md | 42 ++++++++++++++++++++++++++++ pyproject.toml | 4 ++- 5 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 aisuite/providers/cohere_provider.py create mode 100644 guides/cohere.md diff --git a/.gitignore b/.gitignore index 718eebba..a0974550 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,7 @@ env/ .coverage # pyenv -.python-version \ No newline at end of file +.python-version + +.DS_Store +**/.DS_Store diff --git a/aisuite/providers/cohere_provider.py b/aisuite/providers/cohere_provider.py new file mode 100644 index 00000000..64fbf2a0 --- /dev/null +++ b/aisuite/providers/cohere_provider.py @@ -0,0 +1,25 @@ +import os +import cohere +from aisuite.provider import Provider + + +class CohereProvider(Provider): + def __init__(self, **config): + """ + Initialize the Cohere provider with the given configuration. + Pass the entire configuration dictionary to the Cohere client constructor. + """ + # Ensure API key is provided either in config or via environment variable + config.setdefault("api_key", os.getenv("CO_API_KEY")) + if not config["api_key"]: + raise ValueError( + " API key is missing. Please provide it in the config or set the CO_API_KEY environment variable." + ) + self.client = cohere.ClientV2(**config) + + def chat_completions_create(self, model, messages, **kwargs): + return self.client.chat( + model=model, + messages=messages, + **kwargs # Pass any additional arguments to the Cohere API + ) diff --git a/guides/README.md b/guides/README.md index 89f4bd2c..7c2c0622 100644 --- a/guides/README.md +++ b/guides/README.md @@ -2,10 +2,11 @@ These guides give directions for obtaining API keys from different providers. -Here're the instructions for: +Here are the instructions for: - [Anthropic](anthropic.md) - [AWS](aws.md) - [Azure](azure.md) +- [Cohere](cohere.md) - [Google](google.md) - [Hugging Face](huggingface.md) - [OpenAI](openai.md) diff --git a/guides/cohere.md b/guides/cohere.md new file mode 100644 index 00000000..edb115c8 --- /dev/null +++ b/guides/cohere.md @@ -0,0 +1,42 @@ +# Cohere + +To use Cohere with `aisuite`, you’ll need an [Cohere account](https://cohere.com/). After logging in, go to the [API Keys](https://dashboard.cohere.com/api-keyss) section in your account settings, agree to the terms of service, connect your card, and generate a new key. 
+
+```shell
+export CO_API_KEY="your-cohere-api-key"
+```
+
+## Create a Chat Completion
+
+Install the `openai` Python client:
+
+Example with pip:
+```shell
+pip install cohere
+```
+
+Example with poetry:
+```shell
+poetry add cohere
+```
+
+In your code:
+```python
+import aisuite as ai
+client = ai.Client()
+
+model_id = "command-r-plus-08-2024"
+
+messages=[
+    {"role": "user", "content": "Hi, how are you?"}
+    ]
+
+response = client.chat(
+    model=f"{model_id}",
+    messages=messages,
+)
+
+print(response)
+```
+
+Happy coding! If you'd like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).

diff --git a/pyproject.toml b/pyproject.toml
index b6fbfe09..059b39f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ readme = "README.md"
 python = "^3.10"
 anthropic = { version = "^0.30.1", optional = true }
 boto3 = { version = "^1.34.144", optional = true }
+cohere = { version = "^5.12.0", optional = true }
 vertexai = { version = "^1.63.0", optional = true }
 groq = { version = "^0.9.0", optional = true }
 mistralai = { version = "^1.0.3", optional = true }
@@ -21,6 +22,7 @@ httpx = "~0.27.0"
 anthropic = ["anthropic"]
 aws = ["boto3"]
 azure = []
+cohere = ["cohere"]
 google = ["vertexai"]
 groq = ["groq"]
 huggingface = []
@@ -28,7 +30,7 @@ mistral = ["mistralai"]
 ollama = []
 openai = ["openai"]
 watsonx = ["ibm-watsonx-ai"]
-all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "watsonx"] # To install all providers
+all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers

 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.7.1"

From 760cf03c101c11fed2d943a15ac4adc77d8f4edb Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 22:03:41 -0600
Subject: [PATCH 24/40] Update guides/cohere.md

---
 guides/cohere.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/guides/cohere.md b/guides/cohere.md
index edb115c8..6ed7bef8 100644
--- a/guides/cohere.md
+++ b/guides/cohere.md
@@ -1,6 +1,6 @@
 # Cohere

-To use Cohere with `aisuite`, you'll need a [Cohere account](https://cohere.com/). After logging in, go to the [API Keys](https://dashboard.cohere.com/api-keyss) section in your account settings, agree to the terms of service, connect your card, and generate a new key. Once you have your key, add it to your environment as follows:
+To use Cohere with `aisuite`, you'll need a [Cohere account](https://cohere.com/). After logging in, go to the [API Keys](https://dashboard.cohere.com/api-keys) section in your account settings, agree to the terms of service, connect your card, and generate a new key. Once you have your key, add it to your environment as follows:
From 7c5e7173840419a440877cc017cc3af6f43e72e7 Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 22:03:49 -0600
Subject: [PATCH 25/40] Update guides/cohere.md

---
 guides/cohere.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/guides/cohere.md b/guides/cohere.md
index 6ed7bef8..e2f89a15 100644
--- a/guides/cohere.md
+++ b/guides/cohere.md
@@ -8,7 +8,7 @@ export CO_API_KEY="your-cohere-api-key"

 ## Create a Chat Completion

-Install the `openai` Python client:
+Install the `cohere` Python client:

 Example with pip:

From 18394ed3985c8e2a80078f36bc0f0ef85ef1165c Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 22:03:59 -0600
Subject: [PATCH 26/40] Update guides/cohere.md

---
 guides/cohere.md | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/guides/cohere.md b/guides/cohere.md
index e2f89a15..4f7320cf 100644
--- a/guides/cohere.md
+++ b/guides/cohere.md
@@ -23,20 +23,22 @@ poetry add cohere
 In your code:
 ```python
 import aisuite as ai
+
 client = ai.Client()

+provider = "cohere"
 model_id = "command-r-plus-08-2024"

-messages=[
+messages = [
     {"role": "user", "content": "Hi, how are you?"}
-    ]
+]

-response = client.chat(
-    model=f"{model_id}",
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
     messages=messages,
 )

-print(response)
+print(response.choices[0].message.content)
 ```

 Happy coding! If you'd like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).

From 90642eec1c991a95285c2bd490cfbcb9be4776e9 Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 22:48:40 -0600
Subject: [PATCH 27/40] add normalize_response to match openai response format

---
 aisuite/providers/cohere_provider.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/aisuite/providers/cohere_provider.py b/aisuite/providers/cohere_provider.py
index 64fbf2a0..295dff1a 100644
--- a/aisuite/providers/cohere_provider.py
+++ b/aisuite/providers/cohere_provider.py
@@ -1,5 +1,7 @@
 import os
 import cohere
+
+from aisuite.framework import ChatCompletionResponse
 from aisuite.provider import Provider

@@ -18,8 +18,17 @@ def __init__(self, **config):
         self.client = cohere.ClientV2(**config)

     def chat_completions_create(self, model, messages, **kwargs):
-        return self.client.chat(
+        response = self.client.chat(
             model=model,
             messages=messages,
             **kwargs  # Pass any additional arguments to the Cohere API
         )
+
+        return self.normalize_response(response)
+
+
+    def normalize_response(self, response):
+        """Normalize the response from Cohere API to match OpenAI's response format."""
+        normalized_response = ChatCompletionResponse()
+        normalized_response.choices[0].message.content = response.message.content[0].text
+        return normalized_response
\ No newline at end of file

From e6175300f52674143328a40aa1f369f122ecd862 Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 22:49:32 -0600
Subject: [PATCH 28/40] update lock file with cohere

---
 poetry.lock | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 105 insertions(+), 3 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 744b0425..c55d3a27 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically
@generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohttp" @@ -838,6 +838,33 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cohere" +version = "5.13.3" +description = "" +optional = true +python-versions = "<4.0,>=3.9" +files = [ + {file = "cohere-5.13.3-py3-none-any.whl", hash = "sha256:076c88fdd3d670b6577eb8e813a9072bf18b59648d4092c6f0263af3c27bf81f"}, + {file = "cohere-5.13.3.tar.gz", hash = "sha256:70d87e0d5ce48aaee5ba70ead5efbade226cb2a4b11bfcfb676f6a2db3642819"}, +] + +[package.dependencies] +fastavro = ">=1.9.4,<2.0.0" +httpx = ">=0.21.2" +httpx-sse = "0.4.0" +numpy = ">=1.26,<2.0" +parameterized = ">=0.9.0,<0.10.0" +pydantic = ">=1.9.2" +pydantic-core = ">=2.18.2,<3.0.0" +requests = ">=2.0.0,<3.0.0" +tokenizers = ">=0.15,<1" +types-requests = ">=2.0.0,<3.0.0" +typing_extensions = ">=4.0.0" + +[package.extras] +aws = ["boto3 (>=1.34.0,<2.0.0)", "sagemaker (>=2.232.1,<3.0.0)"] + [[package]] name = "colorama" version = "0.4.6" @@ -1227,6 +1254,52 @@ typer = ">=0.12.3" [package.extras] standard = ["fastapi", "uvicorn[standard] (>=0.15.0)"] +[[package]] +name = "fastavro" +version = "1.9.7" +description = "Fast read/write of AVRO files" +optional = true +python-versions = ">=3.8" +files = [ + {file = "fastavro-1.9.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc811fb4f7b5ae95f969cda910241ceacf82e53014c7c7224df6f6e0ca97f52f"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb8749e419a85f251bf1ac87d463311874972554d25d4a0b19f6bdc56036d7cf"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b2f9bafa167cb4d1c3dd17565cb5bf3d8c0759e42620280d1760f1e778e07fc"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e87d04b235b29f7774d226b120da2ca4e60b9e6fdf6747daef7f13f218b3517a"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b525c363e267ed11810aaad8fbdbd1c3bd8837d05f7360977d72a65ab8c6e1fa"}, + {file = "fastavro-1.9.7-cp310-cp310-win_amd64.whl", hash = "sha256:6312fa99deecc319820216b5e1b1bd2d7ebb7d6f221373c74acfddaee64e8e60"}, + {file = "fastavro-1.9.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ec8499dc276c2d2ef0a68c0f1ad11782b2b956a921790a36bf4c18df2b8d4020"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d9d96f98052615ab465c63ba8b76ed59baf2e3341b7b169058db104cbe2aa0"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919f3549e07a8a8645a2146f23905955c35264ac809f6c2ac18142bc5b9b6022"}, + {file = "fastavro-1.9.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9de1fa832a4d9016724cd6facab8034dc90d820b71a5d57c7e9830ffe90f31e4"}, + {file = "fastavro-1.9.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1d09227d1f48f13281bd5ceac958650805aef9a4ef4f95810128c1f9be1df736"}, + {file = "fastavro-1.9.7-cp311-cp311-win_amd64.whl", hash = "sha256:2db993ae6cdc63e25eadf9f93c9e8036f9b097a3e61d19dca42536dcc5c4d8b3"}, + {file = "fastavro-1.9.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4e1289b731214a7315884c74b2ec058b6e84380ce9b18b8af5d387e64b18fc44"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:eac69666270a76a3a1d0444f39752061195e79e146271a568777048ffbd91a27"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9be089be8c00f68e343bbc64ca6d9a13e5e5b0ba8aa52bcb231a762484fb270e"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d576eccfd60a18ffa028259500df67d338b93562c6700e10ef68bbd88e499731"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ee9bf23c157bd7dcc91ea2c700fa3bd924d9ec198bb428ff0b47fa37fe160659"}, + {file = "fastavro-1.9.7-cp312-cp312-win_amd64.whl", hash = "sha256:b6b2ccdc78f6afc18c52e403ee68c00478da12142815c1bd8a00973138a166d0"}, + {file = "fastavro-1.9.7-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:7313def3aea3dacface0a8b83f6d66e49a311149aa925c89184a06c1ef99785d"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:536f5644737ad21d18af97d909dba099b9e7118c237be7e4bd087c7abde7e4f0"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2af559f30383b79cf7d020a6b644c42ffaed3595f775fe8f3d7f80b1c43dfdc5"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:edc28ab305e3c424de5ac5eb87b48d1e07eddb6aa08ef5948fcda33cc4d995ce"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ec2e96bdabd58427fe683329b3d79f42c7b4f4ff6b3644664a345a655ac2c0a1"}, + {file = "fastavro-1.9.7-cp38-cp38-win_amd64.whl", hash = "sha256:3b683693c8a85ede496ebebe115be5d7870c150986e34a0442a20d88d7771224"}, + {file = "fastavro-1.9.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:58f76a5c9a312fbd37b84e49d08eb23094d36e10d43bc5df5187bc04af463feb"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56304401d2f4f69f5b498bdd1552c13ef9a644d522d5de0dc1d789cf82f47f73"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fcce036c6aa06269fc6a0428050fcb6255189997f5e1a728fc461e8b9d3e26b"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:17de68aae8c2525f5631d80f2b447a53395cdc49134f51b0329a5497277fc2d2"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7c911366c625d0a997eafe0aa83ffbc6fd00d8fd4543cb39a97c6f3b8120ea87"}, + {file = "fastavro-1.9.7-cp39-cp39-win_amd64.whl", hash = "sha256:912283ed48578a103f523817fdf0c19b1755cea9b4a6387b73c79ecb8f8f84fc"}, + {file = "fastavro-1.9.7.tar.gz", hash = "sha256:13e11c6cb28626da85290933027cd419ce3f9ab8e45410ef24ce6b89d20a1f6c"}, +] + +[package.extras] +codecs = ["cramjam", "lz4", "zstandard"] +lz4 = ["lz4"] +snappy = ["cramjam"] +zstandard = ["zstandard"] + [[package]] name = "fastjsonschema" version = "2.20.0" @@ -3855,6 +3928,20 @@ files = [ {file = "pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e"}, ] +[[package]] +name = "parameterized" +version = "0.9.0" +description = "Parameterized testing with any Python test framework" +optional = true +python-versions = ">=3.7" +files = [ + {file = "parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b"}, + {file = "parameterized-0.9.0.tar.gz", hash = "sha256:7fc905272cefa4f364c1a3429cbbe9c0f98b793988efb5bf90aac80f08db09b1"}, +] + +[package.extras] +dev = ["jinja2"] + [[package]] name = "parso" version = "0.8.4" @@ -5940,6 +6027,20 @@ files = [ {file = 
"types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, ] +[[package]] +name = "types-requests" +version = "2.32.0.20241016" +description = "Typing stubs for requests" +optional = true +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, + {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "typing-extensions" version = "4.12.2" @@ -6671,10 +6772,11 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "groq", "openai"] +all = ["anthropic", "cohere", "groq", "openai"] anthropic = ["anthropic"] aws = ["boto3"] azure = [] +cohere = ["cohere"] google = ["vertexai"] groq = ["groq"] huggingface = [] @@ -6686,4 +6788,4 @@ watsonx = ["ibm-watsonx-ai"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "8a811ea383ced8df202319530f8c910fee564a8fdb11dc655348c49ad661ab04" +content-hash = "522cba517a2a3cc94bdee88c3af849d76bc00dc9d777d38a40cda860e7c108cb" From 9b2eb9d50c8cc89f083254ca9b519db517f4b873 Mon Sep 17 00:00:00 2001 From: Kevin Solorio <103829+ksolo@users.noreply.github.com> Date: Mon, 9 Dec 2024 23:03:33 -0600 Subject: [PATCH 29/40] add tests for cohere provider --- tests/providers/test_cohere_provider.py | 46 +++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/providers/test_cohere_provider.py diff --git a/tests/providers/test_cohere_provider.py b/tests/providers/test_cohere_provider.py new file mode 100644 index 00000000..d7e10486 --- /dev/null +++ b/tests/providers/test_cohere_provider.py @@ -0,0 +1,46 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from aisuite.providers.cohere_provider import CohereProvider + + +@pytest.fixture(autouse=True) +def set_api_key_env_var(monkeypatch): + """Fixture to set environment variables for tests.""" + monkeypatch.setenv("CO_API_KEY", "test-api-key") + + +def test_cohere_provider(): + """High-level test that the provider is initialized and chat completions are requested successfully.""" + + user_greeting = "Hello!" 
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = CohereProvider()
+    mock_response = MagicMock()
+    mock_response.message = MagicMock()
+    mock_response.message.content = [MagicMock()]
+    mock_response.message.content[0].text = response_text_content
+
+    with patch.object(
+        provider.client,
+        "chat",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content

From 770ba4039cbbe57038e0f1a8f32977716b35e790 Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 23:07:16 -0600
Subject: [PATCH 30/40] black formatting

---
 aisuite/providers/cohere_provider.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/aisuite/providers/cohere_provider.py b/aisuite/providers/cohere_provider.py
index 295dff1a..5886f24b 100644
--- a/aisuite/providers/cohere_provider.py
+++ b/aisuite/providers/cohere_provider.py
@@ -28,9 +28,10 @@ def chat_completions_create(self, model, messages, **kwargs):

         return self.normalize_response(response)

-
     def normalize_response(self, response):
         """Normalize the response from Cohere API to match OpenAI's response format."""
         normalized_response = ChatCompletionResponse()
-        normalized_response.choices[0].message.content = response.message.content[0].text
-        return normalized_response
\ No newline at end of file
+        normalized_response.choices[0].message.content = response.message.content[
+            0
+        ].text
+        return normalized_response

From cd32d5388dde8cf67ccfd6c3390a7a70227c2a79 Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 23:14:20 -0600
Subject: [PATCH 31/40] update gh action to install all extras

---
 .github/workflows/run_pytest.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/run_pytest.yml b/.github/workflows/run_pytest.yml
index 9630faf7..363d8ce7 100644
--- a/.github/workflows/run_pytest.yml
+++ b/.github/workflows/run_pytest.yml
@@ -18,7 +18,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install poetry
-        poetry install --with test
+        poetry install --all-extras

     - name: Test with pytest
       run: poetry run pytest

From c46942339dd6025eb79400410742968e0dfc89e6 Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Mon, 9 Dec 2024 23:18:27 -0600
Subject: [PATCH 32/40] update gh action to install test dependencies

---
 .github/workflows/run_pytest.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/run_pytest.yml b/.github/workflows/run_pytest.yml
index 363d8ce7..f8290c08 100644
--- a/.github/workflows/run_pytest.yml
+++ b/.github/workflows/run_pytest.yml
@@ -18,7 +18,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install poetry
-        poetry install --all-extras
+        poetry install --all-extras --with test

     - name: Test with pytest
       run: poetry run pytest

From 271af0df2966c26785e94c8a937d706f0b32c9f7 Mon Sep 17 00:00:00 2001
From: gautam-goudar <131703167+gautam-goudar@users.noreply.github.com>
Date: Fri, 13 Dec 2024 20:12:17 -0600
Subject: [PATCH 33/40] Update google.md (#149)

The prefix has to be "google" for this to work. This applies to the "aisuite" package, version 0.1.6.
---
 guides/google.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/guides/google.md b/guides/google.md
index 6dcaacf3..e357679e 100644
--- a/guides/google.md
+++ b/guides/google.md
@@ -74,7 +74,7 @@ In your code:
 import aisuite as ai
 client = ai.Client()

-model="vertex:gemini-1.5-pro-001"
+model="google:gemini-1.5-pro-001"

 messages = [
     {"role": "system", "content": "Respond in Pirate English."},

From 5b83ec00f9c5461fa676b35f54404b6b8af06e1a Mon Sep 17 00:00:00 2001
From: Rohit Prasad
Date: Thu, 26 Dec 2024 14:20:18 -0800
Subject: [PATCH 34/40] Test multiple providers. (#165)

This will be invoked by the GitHub workflow each time a release is cut.
This is part of the pre-work for creating release automation.
Integration tests are skipped by default using a pytest mark.
---
 .github/workflows/run_pytest.yml |  2 +-
 pyproject.toml                   |  5 +--
 tests/client/test_prerelease.py  | 74 ++++++++++++++++++++++++++++++++
 3 files changed, 77 insertions(+), 4 deletions(-)
 create mode 100644 tests/client/test_prerelease.py

diff --git a/.github/workflows/run_pytest.yml b/.github/workflows/run_pytest.yml
index f8290c08..d873172b 100644
--- a/.github/workflows/run_pytest.yml
+++ b/.github/workflows/run_pytest.yml
@@ -20,5 +20,5 @@ jobs:
         pip install poetry
         poetry install --all-extras --with test
     - name: Test with pytest
-      run: poetry run pytest
+      run: poetry run pytest -m "not integration"

diff --git a/pyproject.toml b/pyproject.toml
index 059b39f4..145af415 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,7 +63,6 @@ build-backend = "poetry.core.masonry.api"

 [tool.pytest.ini_options]
 testpaths="tests"
-addopts=[
-    "--cov=aisuite",
-    "--cov-report=term-missing"
+markers = [
+    "integration: marks tests as integration tests that interact with external services",
 ]

diff --git a/tests/client/test_prerelease.py b/tests/client/test_prerelease.py
new file mode 100644
index 00000000..bb5f3285
--- /dev/null
+++ b/tests/client/test_prerelease.py
@@ -0,0 +1,74 @@
+# Run this test before releasing a new version.
+# It will test all the models in the client.
+
+import pytest
+import aisuite as ai
+from typing import List, Dict
+from dotenv import load_dotenv, find_dotenv
+
+
+def setup_client() -> ai.Client:
+    """Initialize the AI client with environment variables."""
+    load_dotenv(find_dotenv())
+    return ai.Client()
+
+
+def get_test_models() -> List[str]:
+    """Return a list of model identifiers to test."""
+    return [
+        "anthropic:claude-3-5-sonnet-20240620",
+        "aws:meta.llama3-1-8b-instruct-v1:0",
+        "huggingface:mistralai/Mistral-7B-Instruct-v0.3",
+        "groq:llama3-8b-8192",
+        "mistral:open-mistral-7b",
+        "openai:gpt-3.5-turbo",
+        "cohere:command-r-plus-08-2024",
+    ]
+
+
+def get_test_messages() -> List[Dict[str, str]]:
+    """Return the test messages to send to each model."""
+    return [
+        {
+            "role": "system",
+            "content": "Respond in Pirate English. Always try to include the phrase - No rum No fun.",
+        },
+        {"role": "user", "content": "Tell me a joke about Captain Jack Sparrow"},
+    ]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("model_id", get_test_models())
+def test_model_pirate_response(model_id: str):
+    """
+    Test that each model responds appropriately to the pirate prompt.
+
+    Args:
+        model_id: The provider:model identifier to test
+    """
+    client = setup_client()
+    messages = get_test_messages()
+
+    try:
+        response = client.chat.completions.create(
+            model=model_id, messages=messages, temperature=0.75
+        )
+
+        content = response.choices[0].message.content.lower()
+
+        # Check if either version of the required phrase is present
+        assert any(
+            phrase in content for phrase in ["no rum no fun", "no rum, no fun"]
+        ), f"Model {model_id} did not include required phrase 'No rum No fun'"
+
+        assert len(content) > 0, f"Model {model_id} returned empty response"
+        assert isinstance(
+            content, str
+        ), f"Model {model_id} returned non-string response"
+
+    except Exception as e:
+        pytest.fail(f"Error testing model {model_id}: {str(e)}")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])

From 80890d901317c4b6cdbf166f2da299167f3f6caf Mon Sep 17 00:00:00 2001
From: Rohit Prasad
Date: Thu, 26 Dec 2024 14:35:33 -0800
Subject: [PATCH 35/40] Bumping version from 0.1.6 to 0.1.7 (#166)

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 145af415..e96e6216 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aisuite"
-version = "0.1.6"
+version = "0.1.7"
 description = "Uniform access layer for LLMs"
 authors = ["Andrew Ng, Rohit P"]
 readme = "README.md"

From 66afbdd5018eb1c567edb386e0e761e65cbf183c Mon Sep 17 00:00:00 2001
From: Akim Tsvigun
Date: Thu, 2 Jan 2025 11:10:36 +0100
Subject: [PATCH 36/40] Nebius AI Studio brief documentation added

---
 guides/nebius.md | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 guides/nebius.md

diff --git a/guides/nebius.md b/guides/nebius.md
new file mode 100644
index 00000000..2343d503
--- /dev/null
+++ b/guides/nebius.md
@@ -0,0 +1,44 @@
+# Nebius AI Studio
+
+To use Nebius AI Studio with `aisuite`, you need an AI Studio account. Go to [AI Studio](https://studio.nebius.ai/) and press "Log in to AI Studio" in the top right corner. After logging in, go to the [API Keys](https://studio.nebius.ai/settings/api-keys) section and generate a new key. Once you have a key, add it to your environment as follows:
+
+```shell
+export NEBIUS_API_KEY="your-nebius-api-key"
+```
+
+## Create a Chat Completion
+
+Install the `openai` Python client:
+
+Example with pip:
+```shell
+pip install openai
+```
+
+Example with poetry:
+```shell
+poetry add openai
+```
+
+In your code:
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "nebius"
+model_id = "meta-llama/Llama-3.3-70B-Instruct"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "How many times has Jurgen Klopp won the Champions League?"},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+Happy coding! If you'd like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).

From 6f114e19221f444f74e54c6ab177f30bd2a752f1 Mon Sep 17 00:00:00 2001
From: Riddhimaan-Senapati
Date: Thu, 9 Jan 2025 00:03:35 +0530
Subject: [PATCH 37/40] added basic code for deepseek provider in aisuite using the openai provider code as inspiration.
---
 aisuite/providers/deepseek_provider.py | 34 ++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)
 create mode 100644 aisuite/providers/deepseek_provider.py

diff --git a/aisuite/providers/deepseek_provider.py b/aisuite/providers/deepseek_provider.py
new file mode 100644
index 00000000..7bc7f2a6
--- /dev/null
+++ b/aisuite/providers/deepseek_provider.py
@@ -0,0 +1,34 @@
+import openai
+import os
+from aisuite.provider import Provider
+
+
+class OpenaiProvider(Provider):
+    def __init__(self, **config):
+        """
+        Initialize the DeepSeek provider with the given configuration.
+        Pass the entire configuration dictionary to the OpenAI client constructor.
+        """
+        # Ensure API key is provided either in config or via environment variable
+        config.setdefault("api_key", os.getenv("DEEPSEEK_API_KEY"))
+        if not config["api_key"]:
+            raise ValueError(
+                "DeepSeek API key is missing. Please provide it in the config or set the DEEPSEEK_API_KEY environment variable."
+            )
+        config["base_url"]="https://api.deepseek.com"
+
+        # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically
+        # infer certain values from the environment variables.
+        # Eg: OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID. Except for OPENAI_BASE_URL, which has to be the DeepSeek URL.
+
+        # Pass the entire config to the OpenAI client constructor
+        self.client = openai.OpenAI(**config)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        # Any exception raised by OpenAI will be returned to the caller.
+        # Maybe we should catch them and raise a custom LLMError.
+        return self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            **kwargs  # Pass any additional arguments to the OpenAI API
+        )

From f4f51d85897ffad55a3b5971a84bb4c22f5742ea Mon Sep 17 00:00:00 2001
From: Riddhimaan-Senapati
Date: Thu, 9 Jan 2025 00:07:06 +0530
Subject: [PATCH 38/40] added tests for deepseek provider as well.

---
 aisuite/providers/deepseek_provider.py    |  2 +-
 tests/providers/test_deepseek_provider.py | 46 +++++++++++++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)
 create mode 100644 tests/providers/test_deepseek_provider.py

diff --git a/aisuite/providers/deepseek_provider.py b/aisuite/providers/deepseek_provider.py
index 7bc7f2a6..8ba7c5b3 100644
--- a/aisuite/providers/deepseek_provider.py
+++ b/aisuite/providers/deepseek_provider.py
@@ -3,7 +3,7 @@
 from aisuite.provider import Provider


-class OpenaiProvider(Provider):
+class DeepseekProvider(Provider):
     def __init__(self, **config):
         """
         Initialize the DeepSeek provider with the given configuration.

diff --git a/tests/providers/test_deepseek_provider.py b/tests/providers/test_deepseek_provider.py
new file mode 100644
index 00000000..1ab6f1c1
--- /dev/null
+++ b/tests/providers/test_deepseek_provider.py
@@ -0,0 +1,46 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from aisuite.providers.deepseek_provider import DeepseekProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("DEEPSEEK_API_KEY", "test-api-key")
+
+
+def test_deepseek_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = DeepseekProvider()
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message = MagicMock()
+    mock_response.choices[0].message.content = response_text_content
+
+    with patch.object(
+        provider.client.chat.completions,
+        "create",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_create.assert_called_with(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        assert response.choices[0].message.content == response_text_content

From 4e429e07f187d9b8bf8783bb13cff9051202e7ea Mon Sep 17 00:00:00 2001
From: Riddhimaan-Senapati
Date: Fri, 10 Jan 2025 08:35:32 +0530
Subject: [PATCH 39/40] added a guide on how to use DeepSeek LLMs with aisuite

---
 aisuite/providers/deepseek_provider.py |  2 +-
 guides/README.md                       |  1 +
 guides/deepseek.md                     | 46 ++++++++++++++++++++++++++
 3 files changed, 48 insertions(+), 1 deletion(-)
 create mode 100644 guides/deepseek.md

diff --git a/aisuite/providers/deepseek_provider.py b/aisuite/providers/deepseek_provider.py
index 8ba7c5b3..40c6e451 100644
--- a/aisuite/providers/deepseek_provider.py
+++ b/aisuite/providers/deepseek_provider.py
@@ -1,6 +1,6 @@
 import openai
 import os
-from aisuite.provider import Provider
+from aisuite.provider import Provider, LLMError


 class DeepseekProvider(Provider):

diff --git a/guides/README.md b/guides/README.md
index 7c2c0622..50774586 100644
--- a/guides/README.md
+++ b/guides/README.md
@@ -12,6 +12,7 @@ Here are the instructions for:
 - [OpenAI](openai.md)
 - [SambaNova](sambanova.md)
 - [xAI](xai.md)
+- [DeepSeek](deepseek.md)

 Unless otherwise stated, these guides have not been endorsed by the providers.

diff --git a/guides/deepseek.md b/guides/deepseek.md
new file mode 100644
index 00000000..9985a11f
--- /dev/null
+++ b/guides/deepseek.md
@@ -0,0 +1,46 @@
+# DeepSeek
+
+To use DeepSeek with `aisuite`, you'll need a [DeepSeek account](https://platform.deepseek.com). After logging in, go to the [API Keys](https://platform.deepseek.com/api_keys) section in your account settings and generate a new key. Once you have your key, add it to your environment as follows:
+
+```shell
+export DEEPSEEK_API_KEY="your-deepseek-api-key"
+```
+
+## Create a Chat Completion
+
+(Note: DeepSeek's API format is consistent with OpenAI's, which is why we install the `openai` client; there is no DeepSeek library, at least not for now.)
+
+Install the `openai` Python client:
+
+Example with pip:
+```shell
+pip install openai
+```
+
+Example with poetry:
+```shell
+poetry add openai
+```
+
+In your code:
+```python
+import aisuite as ai
+client = ai.Client()
+
+provider = "deepseek"
+model_id = "deepseek-chat"
+
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "What’s the weather like in San Francisco?"},
+]
+
+response = client.chat.completions.create(
+    model=f"{provider}:{model_id}",
+    messages=messages,
+)
+
+print(response.choices[0].message.content)
+```
+
+Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
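As context for the DeepSeek patches above: the new provider is a thin wrapper that points the `openai` client at DeepSeek's OpenAI-compatible endpoint. Here is a minimal sketch of the equivalent direct call, not part of the patch series itself; the model id, base URL, and environment variable name are taken from the patches:

```python
import os
import openai

# What DeepseekProvider sets up internally: an OpenAI client whose base_url
# points at DeepSeek's OpenAI-compatible API instead of api.openai.com.
client = openai.OpenAI(
    api_key=os.getenv("DEEPSEEK_API_KEY"),
    base_url="https://api.deepseek.com",
)

response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[{"role": "user", "content": "Hi, how are you?"}],
)
print(response.choices[0].message.content)
```

Going through `aisuite` with `model="deepseek:deepseek-chat"` remains the intended usage; the sketch only illustrates what the wrapper delegates to.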
From abe10594f56aba3d4092cd409082020aa0692fe7 Mon Sep 17 00:00:00 2001
From: Riddhimaan-Senapati
Date: Sun, 19 Jan 2025 15:32:55 -0500
Subject: [PATCH 40/40] reformatted the code using black as requested.

---
 aisuite/providers/deepseek_provider.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aisuite/providers/deepseek_provider.py b/aisuite/providers/deepseek_provider.py
index 40c6e451..16327c57 100644
--- a/aisuite/providers/deepseek_provider.py
+++ b/aisuite/providers/deepseek_provider.py
@@ -15,7 +15,7 @@ def __init__(self, **config):
             raise ValueError(
                 "DeepSeek API key is missing. Please provide it in the config or set the DEEPSEEK_API_KEY environment variable."
             )
-        config["base_url"]="https://api.deepseek.com"
+        config["base_url"] = "https://api.deepseek.com"

         # NOTE: We could choose to remove above lines for api_key since OpenAI will automatically
         # infer certain values from the environment variables.
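To close, a short sketch of the environment-variable behavior the NOTE in this final patch refers to: the OpenAI SDK will read `OPENAI_API_KEY` (and related variables) from the environment on its own, but DeepSeek's key lives under a different variable name, so the provider has to resolve it explicitly before constructing the client. The variable names and base URL below come from the patches above; the snippet is illustrative only:

```python
import os
import openai

# The OpenAI SDK would pick up OPENAI_API_KEY automatically, but the DeepSeek
# key is stored under DEEPSEEK_API_KEY, so it must be looked up by hand.
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
    raise ValueError("DEEPSEEK_API_KEY is not set")

# base_url must still point at DeepSeek; only the key lookup could be inferred.
client = openai.OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
```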