Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added support for reka provider #180

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ XAI_API_KEY=

# Sambanova
SAMBANOVA_API_KEY=

# DeepSeek
DeepSeek_API_KEY=

# Reka
Reka_API_KEY=
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ env/

# pyenv
.python-version
.venv

.DS_Store
**/.DS_Store
29 changes: 29 additions & 0 deletions aisuite/providers/reka_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from reka.client import Reka
import os
from aisuite.provider import Provider, LLMError


class RekaProvider(Provider):
def __init__(self, **config):
"""
Initialize the Reka provider with the given configuration.
Pass the entire configuration dictionary to the Reka client constructor.
"""
# Ensure API key is provided either in config or via environment variable
config.setdefault("api_key", os.getenv("REKA_API_KEY"))
if not config["api_key"]:
raise ValueError(
"Reka API key is missing. Please provide it in the config or set the REKA_API_KEY environment variable."
)

# Pass the entire config to the Reka client constructor
self.client = Reka(**config)

def chat_completions_create(self, model, messages, **kwargs):
# Any exception raised by Reka will be returned to the caller.
# Maybe we should catch them and raise a custom LLMError.
return self.client.chat.create(
model=model,
messages=messages,
**kwargs # Pass any additional arguments to the Reka API
)
1 change: 1 addition & 0 deletions guides/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Here are the instructions for:
- [SambaNova](sambanova.md)
- [xAI](xai.md)
- [DeepSeek](deepseek.md)
- [Reka](reka.md)

Unless otherwise stated, these guides have not been endorsed by the providers.

Expand Down
44 changes: 44 additions & 0 deletions guides/reka.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Reka

To use Reka with `aisuite`, you’ll need an [Reka account](https://platform.reka.ai/). After logging in, go to the [API Keys](https://platform.reka.ai/apikeys) section in your account settings and generate a new key. Once you have your key, add it to your environment as follows:

```shell
export REKA_API_KEY="your-reka-api-key"
```

## Create a Chat Completion

Install the `reka-api` Python client:

Example with pip:
```shell
pip install reka-api
```

Example with poetry:
```shell
poetry add reka-api
```

In your code:
```python
import aisuite as ai
client = ai.Client()

provider = "reka"
model_id = "reka-core"

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What’s the weather like in San Francisco?"},
]

response = client.chat.completions.create(
model=f"{provider}:{model_id}",
messages=messages,
)

print(response.choices[0].message.content)
```

Happy coding! If you’d like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
560 changes: 553 additions & 7 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ groq = { version = "^0.9.0", optional = true }
mistralai = { version = "^1.0.3", optional = true }
openai = { version = "^1.35.8", optional = true }
ibm-watsonx-ai = { version = "^1.1.16", optional = true }
reka = { version = "^3.2.0", optional = true }

# Optional dependencies for different providers
httpx = "~0.27.0"
Expand All @@ -30,7 +31,8 @@ mistral = ["mistralai"]
ollama = []
openai = ["openai"]
watsonx = ["ibm-watsonx-ai"]
all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers
reka=["reka-api"]
all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "cohere", "watsonx","reka"] # To install all providers

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.7.1"
Expand All @@ -49,6 +51,7 @@ sentence-transformers = "^3.0.1"
datasets = "^2.20.0"
vertexai = "^1.63.0"
ibm-watsonx-ai = "^1.1.16"
reka-api = "^3.2.0"

[tool.poetry.group.test]
optional = true
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/test_deepseek_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def set_api_key_env_var(monkeypatch):
monkeypatch.setenv("DEEPSEEK_API_KEY", "test-api-key")


def test_groq_provider():
def test_deepseek_provider():
"""High-level test that the provider is initialized and chat completions are requested successfully."""

user_greeting = "Hello!"
Expand Down
46 changes: 46 additions & 0 deletions tests/providers/test_reka_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from unittest.mock import MagicMock, patch

import pytest

from aisuite.providers.reka_provider import RekaProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
"""Fixture to set environment variables for tests."""
monkeypatch.setenv("REKA_API_KEY", "test-api-key")


def test_groq_provider():
"""High-level test that the provider is initialized and chat completions are requested successfully."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "reka-core" # just for an example
chosen_temperature = 0.75
response_text_content = "mocked-text-response-from-model"

provider = RekaProvider()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = response_text_content

with patch.object(
provider.client.chat.completions,
"create",
return_value=mock_response,
) as mock_create:
response = provider.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

mock_create.assert_called_with(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

assert response.choices[0].message.content == response_text_content