-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathreplicate_interface.py
40 lines (29 loc) · 1.24 KB
/
replicate_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""The interface to the Replicate API."""
import os
from ..framework.provider_interface import ProviderInterface
# Base URL passed to the OpenAI SDK client below so that requests are sent to
# Replicate's "openai-proxy" host instead of the SDK's default endpoint.
_REPLICATE_BASE_URL = "https://openai-proxy.replicate.com/v1"
class ReplicateInterface(ProviderInterface):
    """Implements the ProviderInterface for interacting with Replicate's APIs.

    Requests go through an OpenAI SDK client whose base URL is set to
    ``_REPLICATE_BASE_URL``.
    """

    def __init__(self):
        """Set up the Replicate client using the API key obtained from the user's environment.

        Raises
        ------
            ImportError: If the ``openai`` package is not installed.
            openai.OpenAIError: If ``REPLICATE_API_KEY`` is unset or empty.

        """
        # Function-scope import, as in the original; presumably so `openai` is
        # only required when this provider is actually instantiated.
        from openai import OpenAI, OpenAIError

        api_key = os.getenv("REPLICATE_API_KEY")
        if not api_key:
            # Passing api_key=None makes the OpenAI client fall back to the
            # OPENAI_API_KEY environment variable, which would send an OpenAI
            # credential to the Replicate endpoint. Fail fast instead, using
            # the same exception type the SDK raises for a missing key.
            raise OpenAIError(
                "The REPLICATE_API_KEY environment variable must be set to use "
                "the Replicate provider."
            )

        self.replicate_client = OpenAI(
            api_key=api_key,
            base_url=_REPLICATE_BASE_URL,
        )

    def chat_completion_create(self, messages=None, model=None, temperature=0):
        """Request chat completions from the Replicate API.

        Args:
        ----
            messages (list of dict): A list of message objects in chat history.
            model (str): Identifies the specific provider/model to use.
            temperature (float): The temperature to use in the completion.

        Returns:
        -------
            The API response with the completion result.

        """
        return self.replicate_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )