Skip to content

Commit 3d31443

Browse files
authored
Merge pull request #7 from andrewyng/fireworks_replicate
added providers fireworks and replicate with test code and outputs
2 parents 2f89367 + 3bdd7fe commit 3d31443

File tree

6 files changed

+211
-52
lines changed

6 files changed

+211
-52
lines changed

.env.sample

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
ANTHROPIC_API_KEY=""
2+
FIREWORKS_API_KEY=""
23
GROQ_API_KEY=""
34
MISTRAL_API_KEY=""
45
OPENAI_API_KEY=""
56
OLLAMA_API_URL="http://localhost:11434"
7+
REPLICATE_API_KEY=""

aimodels/client/multi_fm_client.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from .chat import Chat
44
from ..providers import (
55
AnthropicInterface,
6+
FireworksInterface,
67
GroqInterface,
78
MistralInterface,
89
OllamaInterface,
910
OpenAIInterface,
11+
ReplicateInterface,
1012
)
1113

1214

@@ -34,10 +36,12 @@ def __init__(self):
3436
self.all_interfaces = {}
3537
self.all_factories = {
3638
"anthropic": AnthropicInterface,
39+
"fireworks": FireworksInterface,
3740
"groq": GroqInterface,
3841
"mistral": MistralInterface,
3942
"ollama": OllamaInterface,
4043
"openai": OpenAIInterface,
44+
"replicate": ReplicateInterface,
4145
}
4246

4347
def get_provider_interface(self, model):

aimodels/providers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Provides the individual provider interfaces for each FM provider."""
22

33
from .anthropic_interface import AnthropicInterface
4+
from .fireworks_interface import FireworksInterface
45
from .groq_interface import GroqInterface
56
from .mistral_interface import MistralInterface
67
from .ollama_interface import OllamaInterface
78
from .openai_interface import OpenAIInterface
9+
from .replicate_interface import ReplicateInterface
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""The interface to the Fireworks API."""
2+
3+
import os
4+
5+
from ..framework.provider_interface import ProviderInterface
6+
7+
8+
class FireworksInterface(ProviderInterface):
9+
"""Implements the ProviderInterface for interacting with Fireworks's APIs."""
10+
11+
def __init__(self):
12+
"""Set up the Fireworks client using the API key obtained from the user's environment."""
13+
from fireworks.client import Fireworks
14+
15+
self.fireworks_client = Fireworks(api_key=os.getenv("FIREWORKS_API_KEY"))
16+
17+
def chat_completion_create(self, messages=None, model=None, temperature=0):
18+
"""Request chat completions from the Fireworks API.
19+
20+
Args:
21+
----
22+
model (str): Identifies the specific provider/model to use.
23+
messages (list of dict): A list of message objects in chat history.
24+
temperature (float): The temperature to use in the completion.
25+
26+
Returns:
27+
-------
28+
The API response with the completion result.
29+
30+
"""
31+
return self.fireworks_client.chat.completions.create(
32+
model=model,
33+
messages=messages,
34+
temperature=temperature,
35+
)
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""The interface to the Replicate API."""
2+
3+
import os
4+
5+
from ..framework.provider_interface import ProviderInterface
6+
7+
_REPLICATE_BASE_URL = "https://openai-proxy.replicate.com/v1"
8+
9+
10+
class ReplicateInterface(ProviderInterface):
11+
"""Implements the ProviderInterface for interacting with Replicate's APIs."""
12+
13+
def __init__(self):
14+
"""Set up the Replicate client using the API key obtained from the user's environment."""
15+
from openai import OpenAI
16+
17+
self.replicate_client = OpenAI(
18+
api_key=os.getenv("REPLICATE_API_KEY"),
19+
base_url=_REPLICATE_BASE_URL,
20+
)
21+
22+
def chat_completion_create(self, messages=None, model=None, temperature=0):
23+
"""Request chat completions from the Replicate API.
24+
25+
Args:
26+
----
27+
model (str): Identifies the specific provider/model to use.
28+
messages (list of dict): A list of message objects in chat history.
29+
temperature (float): The temperature to use in the completion.
30+
31+
Returns:
32+
-------
33+
The API response with the completion result.
34+
35+
"""
36+
return self.replicate_client.chat.completions.create(
37+
model=model,
38+
messages=messages,
39+
temperature=temperature,
40+
)

0 commit comments

Comments
 (0)