Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added providers fireworks and replicate with test code and outputs #7

Merged
merged 3 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
ANTHROPIC_API_KEY=""
FIREWORKS_API_KEY=""
GROQ_API_KEY=""
MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
4 changes: 4 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .chat import Chat
from ..providers import (
AnthropicInterface,
FireworksInterface,
GroqInterface,
MistralInterface,
OllamaInterface,
OpenAIInterface,
ReplicateInterface,
)


Expand Down Expand Up @@ -34,10 +36,12 @@ def __init__(self):
self.all_interfaces = {}
self.all_factories = {
"anthropic": AnthropicInterface,
"fireworks": FireworksInterface,
"groq": GroqInterface,
"mistral": MistralInterface,
"ollama": OllamaInterface,
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
}

def get_provider_interface(self, model):
Expand Down
2 changes: 2 additions & 0 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Provides the individual provider interfaces for each FM provider."""

from .anthropic_interface import AnthropicInterface
from .fireworks_interface import FireworksInterface
from .groq_interface import GroqInterface
from .mistral_interface import MistralInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
35 changes: 35 additions & 0 deletions aimodels/providers/fireworks_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""The interface to the Fireworks API."""

import os

from ..framework.provider_interface import ProviderInterface


class FireworksInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Fireworks's APIs."""

def __init__(self):
"""Set up the Fireworks client using the API key obtained from the user's environment."""
from fireworks.client import Fireworks

self.fireworks_client = Fireworks(api_key=os.getenv("FIREWORKS_API_KEY"))

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Fireworks API.

Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.

Returns:
-------
The API response with the completion result.

"""
return self.fireworks_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
40 changes: 40 additions & 0 deletions aimodels/providers/replicate_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""The interface to the Replicate API."""

import os

from ..framework.provider_interface import ProviderInterface

_REPLICATE_BASE_URL = "https://openai-proxy.replicate.com/v1"


class ReplicateInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Replicate's APIs."""

def __init__(self):
"""Set up the Replicate client using the API key obtained from the user's environment."""
from openai import OpenAI

self.replicate_client = OpenAI(
api_key=os.getenv("REPLICATE_API_KEY"),
base_url=_REPLICATE_BASE_URL,
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Replicate API.

Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.

Returns:
-------
The API response with the completion result.

"""
return self.replicate_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
Loading
Loading