From 095bcbb7c263655b994d271764cad51d5125a5e9 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Tue, 11 Feb 2025 15:21:31 +0200 Subject: [PATCH 1/2] Fix FIM on Continue. Have specific formatters for chat and FIM Until now we had a general formatter on the way out of muxing. This is wrong since sometimes the pipelines respond with different format for chat or FIM. Such is the case for Ollama. This PR separates the formatters and declares them explicitly so that they're easier to adjust in the future. --- src/codegate/muxing/adapter.py | 189 +++++++++++++----- src/codegate/muxing/router.py | 3 +- .../providers/litellmshim/litellmshim.py | 11 +- .../providers/normalizer/completion.py | 6 + src/codegate/providers/ollama/adapter.py | 69 +++++-- 5 files changed, 204 insertions(+), 74 deletions(-) diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index f076df5e..d12674a9 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -1,13 +1,14 @@ import copy import json import uuid -from typing import Union +from abc import ABC, abstractmethod +from typing import Callable, Dict, Union import structlog from fastapi.responses import JSONResponse, StreamingResponse from litellm import ModelResponse from litellm.types.utils import Delta, StreamingChoices -from ollama import ChatResponse +from ollama import ChatResponse, GenerateResponse from codegate.db import models as db_models from codegate.muxing import rulematcher @@ -30,12 +31,13 @@ class BodyAdapter: def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str: """Get the provider formatted URL to use in base_url. 
Note this value comes from DB""" + base_endpoint = model_route.endpoint.endpoint.rstrip("/") if model_route.endpoint.provider_type in [ db_models.ProviderType.openai, db_models.ProviderType.openrouter, ]: - return f"{model_route.endpoint.endpoint}/v1" - return model_route.endpoint.endpoint + return f"{base_endpoint}/v1" + return base_endpoint def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict: """Set the destination provider info.""" @@ -45,15 +47,101 @@ def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) return new_data -class StreamChunkFormatter: +class OutputFormatter(ABC): + + @property + @abstractmethod + def provider_format_funcs(self) -> Dict[str, Callable]: + """ + Return the provider specific format functions. All providers format functions should + return the chunk in OpenAI format. + """ + pass + + @abstractmethod + def format( + self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType + ) -> Union[StreamingResponse, JSONResponse]: + """Format the response to the client.""" + pass + + +class StreamChunkFormatter(OutputFormatter): """ Format a single chunk from a stream to OpenAI format. We need to configure the client to expect the OpenAI format. In Continue this means setting "provider": "openai" in the config json file. """ - def __init__(self): - self.provider_to_func = { + @property + @abstractmethod + def provider_format_funcs(self) -> Dict[str, Callable]: + """ + Return the provider specific format functions. All providers format functions should + return the chunk in OpenAI format. + """ + pass + + def _format_openai(self, chunk: str) -> str: + """ + The chunk is already in OpenAI format. To standardize remove the "data:" prefix. 
+ + This function is used by both chat and FIM formatters + """ + cleaned_chunk = chunk.split("data:")[1].strip() + return cleaned_chunk + + def _format_as_openai_chunk(self, formatted_chunk: str) -> str: + """Format the chunk as OpenAI chunk. This is the format how the clients expect the data.""" + chunk_to_send = f"data:{formatted_chunk}\n\n" + return chunk_to_send + + async def _format_streaming_response( + self, response: StreamingResponse, dest_prov: db_models.ProviderType + ): + """Format the streaming response to OpenAI format.""" + format_func = self.provider_format_funcs.get(dest_prov) + openai_chunk = None + try: + async for chunk in response.body_iterator: + openai_chunk = format_func(chunk) + # Sometimes for Anthropic we couldn't get content from the chunk. Skip it. + if not openai_chunk: + continue + yield self._format_as_openai_chunk(openai_chunk) + except Exception as e: + logger.error(f"Error sending chunk in muxing: {e}") + yield self._format_as_openai_chunk(str(e)) + finally: + # Make sure the last chunk is always [DONE] + if openai_chunk and "[DONE]" not in openai_chunk: + yield self._format_as_openai_chunk("[DONE]") + + def format( + self, response: StreamingResponse, dest_prov: db_models.ProviderType + ) -> StreamingResponse: + """Format the response to the client.""" + return StreamingResponse( + self._format_streaming_response(response, dest_prov), + status_code=response.status_code, + headers=response.headers, + background=response.background, + media_type=response.media_type, + ) + + +class ChatStreamChunkFormatter(StreamChunkFormatter): + """ + Format a single chunk from a stream to OpenAI format given that the request was a chat. + """ + + @property + def provider_format_funcs(self) -> Dict[str, Callable]: + """ + Return the provider specific format functions. All providers format functions should + return the chunk in OpenAI format. 
+ """ + return { db_models.ProviderType.ollama: self._format_ollama, db_models.ProviderType.openai: self._format_openai, db_models.ProviderType.anthropic: self._format_antropic, @@ -68,21 +156,11 @@ def _format_ollama(self, chunk: str) -> str: try: chunk_dict = json.loads(chunk) ollama_chunk = ChatResponse(**chunk_dict) - open_ai_chunk = OLlamaToModel.normalize_chunk(ollama_chunk) + open_ai_chunk = OLlamaToModel.normalize_chat_chunk(ollama_chunk) return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True) except Exception: return chunk - def _format_openai(self, chunk: str) -> str: - """The chunk is already in OpenAI format. To standarize remove the "data:" prefix.""" - cleaned_chunk = chunk.split("data:")[1].strip() - try: - chunk_dict = json.loads(cleaned_chunk) - open_ai_chunk = ModelResponse(**chunk_dict) - return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True) - except Exception: - return cleaned_chunk - def _format_antropic(self, chunk: str) -> str: """Format the Anthropic chunk to OpenAI format.""" cleaned_chunk = chunk.split("data:")[1].strip() @@ -119,46 +197,53 @@ def _format_antropic(self, chunk: str) -> str: except Exception: return cleaned_chunk.strip() - def format(self, chunk: str, dest_prov: db_models.ProviderType) -> ModelResponse: - """Format the chunk to OpenAI format.""" - # Get the format function - format_func = self.provider_to_func.get(dest_prov) - if format_func is None: - raise MuxingAdapterError(f"Provider {dest_prov} not supported.") - return format_func(chunk) +class FimStreamChunkFormatter(StreamChunkFormatter): -class ResponseAdapter: + @property + def provider_format_funcs(self) -> Dict[str, Callable]: + """ + Return the provider specific format functions. All providers format functions should + return the chunk in OpenAI format. 
+ """ + return { + db_models.ProviderType.ollama: self._format_ollama, + db_models.ProviderType.openai: self._format_openai, + # Our Lllamacpp provider emits OpenAI chunks + db_models.ProviderType.llamacpp: self._format_openai, + # OpenRouter is a dialect of OpenAI + db_models.ProviderType.openrouter: self._format_openai, + } + + def _format_ollama(self, chunk: str) -> str: + """Format the Ollama chunk to OpenAI format.""" + try: + chunk_dict = json.loads(chunk) + ollama_chunk = GenerateResponse(**chunk_dict) + open_ai_chunk = OLlamaToModel.normalize_fim_chunk(ollama_chunk) + return json.dumps(open_ai_chunk, separators=(",", ":"), indent=None) + except Exception: + return chunk - def __init__(self): - self.stream_formatter = StreamChunkFormatter() - def _format_as_openai_chunk(self, formatted_chunk: str) -> str: - """Format the chunk as OpenAI chunk. This is the format how the clients expect the data.""" - return f"data:{formatted_chunk}\n\n" +class ResponseAdapter: - async def _format_streaming_response( - self, response: StreamingResponse, dest_prov: db_models.ProviderType - ): - """Format the streaming response to OpenAI format.""" - async for chunk in response.body_iterator: - openai_chunk = self.stream_formatter.format(chunk, dest_prov) - # Sometimes for Anthropic we couldn't get content from the chunk. Skip it. 
- if not openai_chunk: - continue - yield self._format_as_openai_chunk(openai_chunk) + def _get_formatter( + self, response: Union[StreamingResponse, JSONResponse], is_fim_request: bool + ) -> OutputFormatter: + """Get the formatter based on the request type.""" + if isinstance(response, StreamingResponse): + if is_fim_request: + return FimStreamChunkFormatter() + return ChatStreamChunkFormatter() + raise MuxingAdapterError("Only streaming responses are supported.") def format_response_to_client( - self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType + self, + response: Union[StreamingResponse, JSONResponse], + dest_prov: db_models.ProviderType, + is_fim_request: bool, ) -> Union[StreamingResponse, JSONResponse]: """Format the response to the client.""" - if isinstance(response, StreamingResponse): - return StreamingResponse( - self._format_streaming_response(response, dest_prov), - status_code=response.status_code, - headers=response.headers, - background=response.background, - media_type=response.media_type, - ) - else: - raise MuxingAdapterError("Only streaming responses are supported.") + stream_formatter = self._get_formatter(response, is_fim_request) + return stream_formatter.format(response, dest_prov) diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index df3a9d39..4231e8e7 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -93,6 +93,7 @@ async def route_to_dest_provider( model=model_route.model.name, provider_type=model_route.endpoint.provider_type, provider_name=model_route.endpoint.name, + is_fim_request=is_fim_request, ) # 2. Map the request body to the destination provider format. @@ -108,5 +109,5 @@ async def route_to_dest_provider( # 4. Transmit the response back to the client in OpenAI format. 
return self._response_adapter.format_response_to_client( - response, model_route.endpoint.provider_type + response, model_route.endpoint.provider_type, is_fim_request=is_fim_request ) diff --git a/src/codegate/providers/litellmshim/litellmshim.py b/src/codegate/providers/litellmshim/litellmshim.py index 37693f1d..eab6fc54 100644 --- a/src/codegate/providers/litellmshim/litellmshim.py +++ b/src/codegate/providers/litellmshim/litellmshim.py @@ -3,11 +3,7 @@ import litellm import structlog from fastapi.responses import JSONResponse, StreamingResponse -from litellm import ( - ChatCompletionRequest, - ModelResponse, - acompletion, -) +from litellm import ChatCompletionRequest, ModelResponse, acompletion, atext_completion from codegate.clients.clients import ClientType from codegate.providers.base import BaseCompletionHandler, StreamGenerator @@ -52,6 +48,11 @@ async def execute_completion( request["api_key"] = api_key request["base_url"] = base_url if is_fim_request: + # We need to force atext_completion if there is "prompt" in the request. + # The default function acompletion can only handle "messages" in the request. + if "prompt" in request: + logger.debug("Forcing atext_completion in FIM") + return await atext_completion(**request) return await self._fim_completion_func(**request) return await self._completion_func(**request) diff --git a/src/codegate/providers/normalizer/completion.py b/src/codegate/providers/normalizer/completion.py index c4cc6306..04227bbd 100644 --- a/src/codegate/providers/normalizer/completion.py +++ b/src/codegate/providers/normalizer/completion.py @@ -20,6 +20,12 @@ def normalize(self, data: Dict) -> ChatCompletionRequest: data["messages"] = [{"content": data.pop("prompt"), "role": "user"}] # We can add as many parameters as we like to data. ChatCompletionRequest is not strict. data["had_prompt_before"] = True + + # LiteLLM says we need to have at most a list of length 4 in stop. Forcing it. 
+ stop_list = data.get("stop", []) + trimmed_stop_list = stop_list[:4] + data["stop"] = trimmed_stop_list + try: normalized_data = ChatCompletionRequest(**data) if normalized_data.get("stream", False): diff --git a/src/codegate/providers/ollama/adapter.py b/src/codegate/providers/ollama/adapter.py index e64ec81b..07011960 100644 --- a/src/codegate/providers/ollama/adapter.py +++ b/src/codegate/providers/ollama/adapter.py @@ -1,10 +1,9 @@ -import uuid from datetime import datetime, timezone -from typing import Any, AsyncIterator, Dict, Union +from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union from litellm import ChatCompletionRequest, ModelResponse from litellm.types.utils import Delta, StreamingChoices -from ollama import ChatResponse, Message +from ollama import ChatResponse, GenerateResponse, Message from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer @@ -47,21 +46,36 @@ def __init__(self, ollama_response: AsyncIterator[ChatResponse]): self.ollama_response = ollama_response self._aiter = ollama_response.__aiter__() - @staticmethod - def normalize_chunk(chunk: ChatResponse) -> ModelResponse: - finish_reason = None - role = "assistant" - + @classmethod + def _transform_to_int_secs(cls, chunk_created_at) -> int: # Convert the datetime object to a timestamp in seconds - datetime_obj = datetime.fromisoformat(chunk.created_at) - timestamp_seconds = int(datetime_obj.timestamp()) + datetime_obj = datetime.fromisoformat(chunk_created_at) + return int(datetime_obj.timestamp()) - if chunk.done: + @classmethod + def _get_finish_reason_assistant(cls, is_chunk_done: bool) -> Tuple[str, Optional[str]]: + finish_reason = None + role = "assistant" + if is_chunk_done: finish_reason = "stop" role = None + return role, finish_reason + + @classmethod + def _get_chat_id_from_timestamp(cls, timestamp_seconds: int) -> str: + timestamp_str = str(timestamp_seconds) + return timestamp_str[:9] + + @classmethod + def 
normalize_chat_chunk(cls, chunk: ChatResponse) -> ModelResponse: + # Convert the datetime object to a timestamp in seconds + timestamp_seconds = cls._transform_to_int_secs(chunk.created_at) + # Get role and finish reason + role, finish_reason = cls._get_finish_reason_assistant(chunk.done) + chat_id = cls._get_chat_id_from_timestamp(timestamp_seconds) model_response = ModelResponse( - id=f"ollama-chat-{str(uuid.uuid4())}", + id=f"ollama-chat-{chat_id}", created=timestamp_seconds, model=chunk.model, object="chat.completion.chunk", @@ -76,16 +90,39 @@ def normalize_chunk(chunk: ChatResponse) -> ModelResponse: ) return model_response + @classmethod + def normalize_fim_chunk(cls, chunk: GenerateResponse) -> Dict: + """ + Transform an ollama generation chunk to an OpenAI one + """ + # Convert the datetime object to a timestamp in seconds + timestamp_seconds = cls._transform_to_int_secs(chunk.created_at) + # Get role and finish reason + _, finish_reason = cls._get_finish_reason_assistant(chunk.done) + chat_id = cls._get_chat_id_from_timestamp(timestamp_seconds) + + model_response = { + "id": f"chatcmpl-{chat_id}", + "object": "text_completion", + "created": timestamp_seconds, + "model": chunk.model, + "choices": [{"index": 0, "text": chunk.response}], + "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}, + } + if finish_reason: + model_response["choices"][0]["finish_reason"] = finish_reason + del model_response["choices"][0]["text"] + return model_response + def __aiter__(self): return self async def __anext__(self): try: chunk = await self._aiter.__anext__() - if not isinstance(chunk, ChatResponse): - return chunk - - return self.normalize_chunk(chunk) + if isinstance(chunk, ChatResponse): + return self.normalize_chat_chunk(chunk) + return chunk except StopAsyncIteration: raise StopAsyncIteration From 60acedc33141d2affe5242075377aa987c23805e Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Wed, 12 Feb 2025 12:38:38 +0200 Subject: [PATCH 2/2] 
Addressed comments from review --- src/codegate/muxing/adapter.py | 6 +++--- src/codegate/providers/ollama/adapter.py | 21 +++++++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index d12674a9..e4ac3dc2 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -3,6 +3,7 @@ import uuid from abc import ABC, abstractmethod from typing import Callable, Dict, Union +from urllib.parse import urljoin import structlog from fastapi.responses import JSONResponse, StreamingResponse @@ -31,13 +32,12 @@ class BodyAdapter: def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str: """Get the provider formatted URL to use in base_url. Note this value comes from DB""" - base_endpoint = model_route.endpoint.endpoint.rstrip("/") if model_route.endpoint.provider_type in [ db_models.ProviderType.openai, db_models.ProviderType.openrouter, ]: - return f"{base_endpoint}/v1" - return base_endpoint + return urljoin(model_route.endpoint.endpoint, "/v1") + return model_route.endpoint.endpoint def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict: """Set the destination provider info.""" diff --git a/src/codegate/providers/ollama/adapter.py b/src/codegate/providers/ollama/adapter.py index 07011960..46fc13d1 100644 --- a/src/codegate/providers/ollama/adapter.py +++ b/src/codegate/providers/ollama/adapter.py @@ -47,13 +47,18 @@ def __init__(self, ollama_response: AsyncIterator[ChatResponse]): self._aiter = ollama_response.__aiter__() @classmethod - def _transform_to_int_secs(cls, chunk_created_at) -> int: - # Convert the datetime object to a timestamp in seconds + def _transform_to_int_secs(cls, chunk_created_at: str) -> int: + """ + Convert the datetime to a timestamp in seconds. 
+ """ datetime_obj = datetime.fromisoformat(chunk_created_at) return int(datetime_obj.timestamp()) @classmethod def _get_finish_reason_assistant(cls, is_chunk_done: bool) -> Tuple[str, Optional[str]]: + """ + Get the role and finish reason for the assistant based on the chunk done status. + """ finish_reason = None role = "assistant" if is_chunk_done: @@ -63,14 +68,20 @@ def _get_finish_reason_assistant(cls, is_chunk_done: bool) -> Tuple[str, Optiona @classmethod def _get_chat_id_from_timestamp(cls, timestamp_seconds: int) -> str: + """ + Getting a string representation of the timestamp in seconds used as the chat id. + + This needs to be done so that all chunks of a chat have the same id. + """ timestamp_str = str(timestamp_seconds) return timestamp_str[:9] @classmethod def normalize_chat_chunk(cls, chunk: ChatResponse) -> ModelResponse: - # Convert the datetime object to a timestamp in seconds + """ + Transform an ollama chat chunk to an OpenAI one + """ timestamp_seconds = cls._transform_to_int_secs(chunk.created_at) - # Get role and finish reason role, finish_reason = cls._get_finish_reason_assistant(chunk.done) chat_id = cls._get_chat_id_from_timestamp(timestamp_seconds) @@ -95,9 +106,7 @@ def normalize_fim_chunk(cls, chunk: GenerateResponse) -> Dict: """ Transform an ollama generation chunk to an OpenAI one """ - # Convert the datetime object to a timestamp in seconds timestamp_seconds = cls._transform_to_int_secs(chunk.created_at) - # Get role and finish reason _, finish_reason = cls._get_finish_reason_assistant(chunk.done) chat_id = cls._get_chat_id_from_timestamp(timestamp_seconds)