Commit c5a57d6

Merge branch 'main' into rp_AzureDocs
2 parents: 3a6ac56 + 2e36536

File tree

7 files changed (+113, -126 lines)


aisuite/provider.py (+2)

@@ -24,6 +24,7 @@ class ProviderNames(str, Enum):
     GROQ = "groq"
     GOOGLE = "google"
     MISTRAL = "mistral"
+    OLLAMA = "ollama"
     OPENAI = "openai"

@@ -46,6 +47,7 @@ class ProviderFactory:
             "aisuite.providers.mistral_provider",
             "MistralProvider",
         ),
+        ProviderNames.OLLAMA: ("aisuite.providers.ollama_provider", "OllamaProvider"),
         ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
     }
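
Not part of the diff above: a minimal sketch of how this registry entry is exercised from the client side, following the "provider:model" string convention used in examples/client.ipynb. The ai.Client() constructor and the "llama3" model name are assumptions and do not appear in this commit.

import aisuite as ai

client = ai.Client()  # assumed aisuite entry point; not shown in this diff

messages = [{"role": "user", "content": "Howdy!"}]

# The "ollama:" prefix resolves via ProviderNames.OLLAMA to OllamaProvider.
# "llama3" is illustrative; the model must already be pulled into a running Ollama server.
response = client.chat.completions.create(model="ollama:llama3", messages=messages)
print(response.choices[0].message.content)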

aisuite/providers/__init__.py (-1)

@@ -2,6 +2,5 @@

 from .fireworks_interface import FireworksInterface
 from .octo_interface import OctoInterface
-from .ollama_interface import OllamaInterface
 from .replicate_interface import ReplicateInterface
 from .together_interface import TogetherInterface

aisuite/providers/ollama_interface.py (-54)

This file was deleted.

aisuite/providers/ollama_provider.py (+65)

@@ -0,0 +1,65 @@
+import os
+import httpx
+from aisuite.provider import Provider, LLMError
+from aisuite.framework import ChatCompletionResponse
+
+
+class OllamaProvider(Provider):
+    """
+    Ollama Provider that makes HTTP calls instead of using SDK.
+    It uses the /api/chat endpoint.
+    Read more here - https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
+    If OLLAMA_API_URL is not set and not passed in config, then it will default to "http://localhost:11434"
+    """
+
+    _CHAT_COMPLETION_ENDPOINT = "/api/chat"
+    _CONNECT_ERROR_MESSAGE = "Ollama is likely not running. Start Ollama by running `ollama serve` on your host."
+
+    def __init__(self, **config):
+        """
+        Initialize the Ollama provider with the given configuration.
+        """
+        self.url = config.get("api_url") or os.getenv(
+            "OLLAMA_API_URL", "http://localhost:11434"
+        )
+
+        # Optionally set a custom timeout (default to 30s)
+        self.timeout = config.get("timeout", 30)
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        """
+        Makes a request to the chat completions endpoint using httpx.
+        """
+        kwargs["stream"] = False
+        data = {
+            "model": model,
+            "messages": messages,
+            **kwargs,  # Pass any additional arguments to the API
+        }
+
+        try:
+            response = httpx.post(
+                self.url.rstrip("/") + self._CHAT_COMPLETION_ENDPOINT,
+                json=data,
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+        except httpx.ConnectError:  # Handle connection errors
+            raise LLMError(f"Connection failed: {self._CONNECT_ERROR_MESSAGE}")
+        except httpx.HTTPStatusError as http_err:
+            raise LLMError(f"Ollama request failed: {http_err}")
+        except Exception as e:
+            raise LLMError(f"An error occurred: {e}")
+
+        # Return the normalized response
+        return self._normalize_response(response.json())
+
+    def _normalize_response(self, response_data):
+        """
+        Normalize the API response to a common format (ChatCompletionResponse).
+        """
+        normalized_response = ChatCompletionResponse()
+        normalized_response.choices[0].message.content = response_data["message"][
+            "content"
+        ]
+        return normalized_response
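
Not part of the commit: a minimal sketch of calling the new provider directly, assuming a local Ollama server started with `ollama serve` and an already-pulled model (the "llama3" name below is illustrative).

from aisuite.providers.ollama_provider import OllamaProvider

# Config keys match the constructor above; both fall back to defaults if omitted.
provider = OllamaProvider(api_url="http://localhost:11434", timeout=10)

response = provider.chat_completions_create(
    model="llama3",  # illustrative; must already exist on the Ollama server
    messages=[{"role": "user", "content": "Howdy!"}],
    temperature=0.77,
)
print(response.choices[0].message.content)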

examples/client.ipynb (+1, -16)

@@ -135,21 +135,6 @@
     "print(response.choices[0].message.content)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "4b3e6c41-070d-4041-9ed9-c8977790fe18",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "together_llama3_8b = \"together:meta-llama/Llama-3-8b-chat-hf\"\n",
-    "#together_llama3_70b = \"together:meta-llama/Llama-3-70b-chat-hf\"\n",
-    "\n",
-    "response = client.chat.completions.create(model=together_llama3_8b, messages=messages)\n",
-    "\n",
-    "print(response.choices[0].message.content)"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -168,7 +153,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "19cdb1ab",
+   "id": "6819ac17",
    "metadata": {},
    "outputs": [],
    "source": [

tests/providers/test_ollama_interface.py (-55)

This file was deleted.

New test file (+45)

@@ -0,0 +1,45 @@
+import pytest
+from unittest.mock import patch, MagicMock
+from aisuite.providers.ollama_provider import OllamaProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_url_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("OLLAMA_API_URL", "http://localhost:11434")
+
+
+def test_completion():
+    """Test that completions request successfully."""
+
+    user_greeting = "Howdy!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "best-model-ever"
+    chosen_temperature = 0.77
+    response_text_content = "mocked-text-response-from-ollama-model"
+
+    ollama = OllamaProvider()
+    mock_response = {"message": {"content": response_text_content}}
+
+    with patch(
+        "httpx.post",
+        return_value=MagicMock(status_code=200, json=lambda: mock_response),
+    ) as mock_post:
+        response = ollama.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+        mock_post.assert_called_once_with(
+            "http://localhost:11434/api/chat",
+            json={
+                "model": selected_model,
+                "messages": message_history,
+                "stream": False,
+                "temperature": chosen_temperature,
+            },
+            timeout=30,
+        )
+
+        assert response.choices[0].message.content == response_text_content
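
Not part of the commit: a hedged sketch of a companion test for the connection-error path, reusing the same mocking style; it assumes `httpx` is imported in the test module and that `LLMError` is importable from `aisuite.provider`, as in the provider code above.

import httpx
from aisuite.provider import LLMError


def test_connection_error_raises_llm_error():
    """A refused connection should surface as aisuite's LLMError."""
    ollama = OllamaProvider()

    # Simulate Ollama not running by making httpx.post raise ConnectError.
    with patch("httpx.post", side_effect=httpx.ConnectError("connection refused")):
        with pytest.raises(LLMError):
            ollama.chat_completions_create(
                messages=[{"role": "user", "content": "Howdy!"}],
                model="best-model-ever",
            )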
