|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import json |
| 4 | +from dataclasses import dataclass |
| 5 | +from typing import AsyncIterable |
| 6 | + |
| 7 | +import httpx |
| 8 | +import httpx_sse |
| 9 | +from fastapi_poe import PoeBot |
| 10 | +from fastapi_poe.types import QueryRequest |
| 11 | +from sse_starlette.sse import ServerSentEvent |
| 12 | + |
| 13 | +BASE_URL = "https://api.together.xyz/inference" |
| 14 | +DEFAULT_SYSTEM_PROMPT = """\ |
| 15 | +You are the StarCoderChat bot. You help users with programming and code related questions. |
| 16 | +Wrap any code blocks in your response in backticks so that it can be rendered using Markdown.""" |
| 17 | + |
| 18 | + |
| 19 | +@dataclass |
| 20 | +class StarCoderChatBot(PoeBot): |
| 21 | + TOGETHER_API_KEY: str # Together.ai api key |
| 22 | + |
| 23 | + def construct_prompt(self, query: QueryRequest): |
| 24 | + prompt = "\n" |
| 25 | + prompt += f"<system>: {DEFAULT_SYSTEM_PROMPT}\n" |
| 26 | + for message in query.query: |
| 27 | + if message.role == "user": |
| 28 | + prompt += f"<human>: {message.content}\n" |
| 29 | + elif message.role == "bot": |
| 30 | + prompt += f"<bot>: {message.content}\n" |
| 31 | + elif message.role == "system": |
| 32 | + pass |
| 33 | + else: |
| 34 | + raise ValueError(f"unknown role {message.role}.") |
| 35 | + prompt += "<bot>:" |
| 36 | + return prompt |
| 37 | + |
| 38 | + async def query_together_ai(self, prompt) -> str: |
| 39 | + payload = { |
| 40 | + "model": "HuggingFaceH4/starchat-alpha", |
| 41 | + "prompt": prompt, |
| 42 | + "max_tokens": 1000, |
| 43 | + "stop": ["<|endoftext|>", "<|end|>", "<human>", "<bot>"], |
| 44 | + "stream_tokens": True, |
| 45 | + "temperature": 0.7, |
| 46 | + "top_p": 0.7, |
| 47 | + "top_k": 50, |
| 48 | + "repetition_penalty": 1, |
| 49 | + } |
| 50 | + headers = { |
| 51 | + "accept": "application/json", |
| 52 | + "content-type": "application/json", |
| 53 | + "Authorization": f"Bearer {self.TOGETHER_API_KEY}", |
| 54 | + } |
| 55 | + |
| 56 | + async with httpx.AsyncClient() as aclient: |
| 57 | + async with httpx_sse.aconnect_sse( |
| 58 | + aclient, "POST", BASE_URL, headers=headers, json=payload |
| 59 | + ) as event_source: |
| 60 | + async for event in event_source.aiter_sse(): |
| 61 | + if event.data != "[DONE]": |
| 62 | + token = json.loads(event.data)["choices"][0]["text"] |
| 63 | + yield token |
| 64 | + |
| 65 | + async def get_response(self, query: QueryRequest) -> AsyncIterable[ServerSentEvent]: |
| 66 | + prompt = self.construct_prompt(query) |
| 67 | + async for word in self.query_together_ai(prompt): |
| 68 | + yield self.text_event(word) |
0 commit comments