
Fix FIM on Continue. Have specific formatters for chat and FIM #1015

Merged

2 commits merged on Feb 12, 2025

187 changes: 136 additions & 51 deletions src/codegate/muxing/adapter.py
@@ -1,13 +1,15 @@
import copy
import json
import uuid
from typing import Union
from abc import ABC, abstractmethod
from typing import Callable, Dict, Union
from urllib.parse import urljoin

import structlog
from fastapi.responses import JSONResponse, StreamingResponse
from litellm import ModelResponse
from litellm.types.utils import Delta, StreamingChoices
from ollama import ChatResponse
from ollama import ChatResponse, GenerateResponse

from codegate.db import models as db_models
from codegate.muxing import rulematcher
@@ -34,7 +36,7 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str:
db_models.ProviderType.openai,
db_models.ProviderType.openrouter,
]:
return f"{model_route.endpoint.endpoint}/v1"
return urljoin(model_route.endpoint.endpoint, "/v1")
return model_route.endpoint.endpoint

def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
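
A brief note on the urljoin change above: because "/v1" is an absolute path, urljoin normalizes trailing slashes and replaces any path already present on the endpoint. A minimal sketch (the endpoint value is hypothetical, not taken from this PR):

```python
from urllib.parse import urljoin

endpoint = "https://api.openai.com/"  # hypothetical endpoint value

# Old behaviour: naive concatenation keeps the trailing slash -> double slash.
print(f"{endpoint}/v1")          # https://api.openai.com//v1

# New behaviour: urljoin normalizes the result. Since "/v1" is absolute, it
# also replaces any existing path on the endpoint, e.g.
# urljoin("https://host/base", "/v1") -> "https://host/v1".
print(urljoin(endpoint, "/v1"))  # https://api.openai.com/v1
```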
@@ -45,15 +47,101 @@ def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict)
return new_data


class StreamChunkFormatter:
class OutputFormatter(ABC):

@property
@abstractmethod
def provider_format_funcs(self) -> Dict[str, Callable]:
"""
Return the provider-specific format functions. All providers' format functions should
return the chunk in OpenAI format.
"""
pass

@abstractmethod
def format(
self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
) -> Union[StreamingResponse, JSONResponse]:
"""Format the response to the client."""
pass


class StreamChunkFormatter(OutputFormatter):
"""
Format a single chunk from a stream to OpenAI format.
We need to configure the client to expect the OpenAI format.
In Continue this means setting "provider": "openai" in the config json file.
"""

def __init__(self):
self.provider_to_func = {
@property
@abstractmethod
def provider_format_funcs(self) -> Dict[str, Callable]:
"""
Return the provider-specific format functions. All providers' format functions should
return the chunk in OpenAI format.
"""
pass

def _format_openai(self, chunk: str) -> str:
"""
The chunk is already in OpenAI format. To standardize, remove the "data:" prefix.

This function is used by both the chat and FIM formatters.
"""
cleaned_chunk = chunk.split("data:")[1].strip()
return cleaned_chunk

def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
"""Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
chunk_to_send = f"data:{formatted_chunk}\n\n"
return chunk_to_send

async def _format_streaming_response(
self, response: StreamingResponse, dest_prov: db_models.ProviderType
):
"""Format the streaming response to OpenAI format."""
format_func = self.provider_format_funcs.get(dest_prov)
openai_chunk = None
try:
async for chunk in response.body_iterator:
openai_chunk = format_func(chunk)
# Sometimes we cannot extract content from an Anthropic chunk. Skip it.
if not openai_chunk:
continue
yield self._format_as_openai_chunk(openai_chunk)
except Exception as e:
logger.error(f"Error sending chunk in muxing: {e}")
yield self._format_as_openai_chunk(str(e))
finally:
# Make sure the last chunk is always [DONE]
if openai_chunk and "[DONE]" not in openai_chunk:
yield self._format_as_openai_chunk("[DONE]")

def format(
self, response: StreamingResponse, dest_prov: db_models.ProviderType
) -> StreamingResponse:
"""Format the response to the client."""
return StreamingResponse(
self._format_streaming_response(response, dest_prov),
status_code=response.status_code,
headers=response.headers,
background=response.background,
media_type=response.media_type,
)
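
For context, the base formatter above emits standard OpenAI server-sent-events framing: each chunk goes on a `data:` line followed by a blank line, and the stream is guaranteed to end with a `[DONE]` sentinel. A sketch of what a client receives (the chunk payloads are illustrative):

```python
# What _format_as_openai_chunk produces for a single formatted chunk
# (payload is made up for illustration):
formatted_chunk = '{"choices":[{"delta":{"content":"Hello"}}]}'
chunk_to_send = f"data:{formatted_chunk}\n\n"
# -> 'data:{"choices":[{"delta":{"content":"Hello"}}]}\n\n'

# The generator always closes the stream with the OpenAI sentinel:
done_marker = "data:[DONE]\n\n"
```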


class ChatStreamChunkFormatter(StreamChunkFormatter):
"""
Format a single chunk from a stream to OpenAI format given that the request was a chat.
"""

@property
def provider_format_funcs(self) -> Dict[str, Callable]:
"""
Return the provider-specific format functions. All providers' format functions should
return the chunk in OpenAI format.
"""
return {
db_models.ProviderType.ollama: self._format_ollama,
db_models.ProviderType.openai: self._format_openai,
db_models.ProviderType.anthropic: self._format_antropic,
@@ -68,21 +156,11 @@ def _format_ollama(self, chunk: str) -> str:
try:
chunk_dict = json.loads(chunk)
ollama_chunk = ChatResponse(**chunk_dict)
open_ai_chunk = OLlamaToModel.normalize_chunk(ollama_chunk)
open_ai_chunk = OLlamaToModel.normalize_chat_chunk(ollama_chunk)
return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
except Exception:
return chunk

def _format_openai(self, chunk: str) -> str:
"""The chunk is already in OpenAI format. To standarize remove the "data:" prefix."""
cleaned_chunk = chunk.split("data:")[1].strip()
try:
chunk_dict = json.loads(cleaned_chunk)
open_ai_chunk = ModelResponse(**chunk_dict)
return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
except Exception:
return cleaned_chunk

def _format_antropic(self, chunk: str) -> str:
"""Format the Anthropic chunk to OpenAI format."""
cleaned_chunk = chunk.split("data:")[1].strip()
@@ -119,46 +197,53 @@ def _format_antropic(self, chunk: str) -> str:
except Exception:
return cleaned_chunk.strip()

def format(self, chunk: str, dest_prov: db_models.ProviderType) -> ModelResponse:
"""Format the chunk to OpenAI format."""
# Get the format function
format_func = self.provider_to_func.get(dest_prov)
if format_func is None:
raise MuxingAdapterError(f"Provider {dest_prov} not supported.")
return format_func(chunk)

class FimStreamChunkFormatter(StreamChunkFormatter):

class ResponseAdapter:
@property
def provider_format_funcs(self) -> Dict[str, Callable]:
"""
Return the provider-specific format functions. All providers' format functions should
return the chunk in OpenAI format.
"""
return {
db_models.ProviderType.ollama: self._format_ollama,
db_models.ProviderType.openai: self._format_openai,
# Our llamacpp provider emits OpenAI chunks
db_models.ProviderType.llamacpp: self._format_openai,
# OpenRouter is a dialect of OpenAI
db_models.ProviderType.openrouter: self._format_openai,
}

def _format_ollama(self, chunk: str) -> str:
"""Format the Ollama chunk to OpenAI format."""
try:
chunk_dict = json.loads(chunk)
ollama_chunk = GenerateResponse(**chunk_dict)
open_ai_chunk = OLlamaToModel.normalize_fim_chunk(ollama_chunk)
return json.dumps(open_ai_chunk, separators=(",", ":"), indent=None)
except Exception:
return chunk

def __init__(self):
self.stream_formatter = StreamChunkFormatter()

def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
"""Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
return f"data:{formatted_chunk}\n\n"
class ResponseAdapter:

async def _format_streaming_response(
self, response: StreamingResponse, dest_prov: db_models.ProviderType
):
"""Format the streaming response to OpenAI format."""
async for chunk in response.body_iterator:
openai_chunk = self.stream_formatter.format(chunk, dest_prov)
# Sometimes for Anthropic we couldn't get content from the chunk. Skip it.
if not openai_chunk:
continue
yield self._format_as_openai_chunk(openai_chunk)
def _get_formatter(
self, response: Union[StreamingResponse, JSONResponse], is_fim_request: bool
) -> OutputFormatter:
"""Get the formatter based on the request type."""
if isinstance(response, StreamingResponse):
if is_fim_request:
return FimStreamChunkFormatter()
return ChatStreamChunkFormatter()
raise MuxingAdapterError("Only streaming responses are supported.")

def format_response_to_client(
self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
self,
response: Union[StreamingResponse, JSONResponse],
dest_prov: db_models.ProviderType,
is_fim_request: bool,
) -> Union[StreamingResponse, JSONResponse]:
"""Format the response to the client."""
if isinstance(response, StreamingResponse):
return StreamingResponse(
self._format_streaming_response(response, dest_prov),
status_code=response.status_code,
headers=response.headers,
background=response.background,
media_type=response.media_type,
)
else:
raise MuxingAdapterError("Only streaming responses are supported.")
stream_formatter = self._get_formatter(response, is_fim_request)
return stream_formatter.format(response, dest_prov)
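
Taken together, here is a minimal usage sketch of the reworked adapter. Only the API shape comes from this diff; the provider, model name, and raw chunks below are hypothetical. The adapter picks ChatStreamChunkFormatter or FimStreamChunkFormatter based on is_fim_request and re-streams everything in OpenAI format, which is why (per the docstring) clients such as Continue should be configured to treat the mux endpoint as an OpenAI provider.

```python
# Hypothetical usage of the new ResponseAdapter; the Ollama chunks and model
# name are made up for illustration.
from fastapi.responses import StreamingResponse

from codegate.db import models as db_models
from codegate.muxing.adapter import ResponseAdapter


async def fake_ollama_fim_stream():
    # Raw Ollama GenerateResponse chunks, as JSON strings.
    yield '{"model": "qwen2.5-coder", "response": "def add(a, b):", "done": false}'
    yield '{"model": "qwen2.5-coder", "response": "    return a + b", "done": true}'


adapter = ResponseAdapter()
upstream = StreamingResponse(fake_ollama_fim_stream())

# is_fim_request=True selects FimStreamChunkFormatter, which converts each
# Ollama chunk into an OpenAI-style "data:{...}" chunk and appends [DONE].
client_response = adapter.format_response_to_client(
    upstream,
    db_models.ProviderType.ollama,
    is_fim_request=True,
)
```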
3 changes: 2 additions & 1 deletion src/codegate/muxing/router.py
@@ -93,6 +93,7 @@ async def route_to_dest_provider(
model=model_route.model.name,
provider_type=model_route.endpoint.provider_type,
provider_name=model_route.endpoint.name,
is_fim_request=is_fim_request,
)

# 2. Map the request body to the destination provider format.
@@ -108,5 +109,5 @@

# 4. Transmit the response back to the client in OpenAI format.
return self._response_adapter.format_response_to_client(
response, model_route.endpoint.provider_type
response, model_route.endpoint.provider_type, is_fim_request=is_fim_request
)
11 changes: 6 additions & 5 deletions src/codegate/providers/litellmshim/litellmshim.py
@@ -3,11 +3,7 @@
import litellm
import structlog
from fastapi.responses import JSONResponse, StreamingResponse
from litellm import (
ChatCompletionRequest,
ModelResponse,
acompletion,
)
from litellm import ChatCompletionRequest, ModelResponse, acompletion, atext_completion

from codegate.clients.clients import ClientType
from codegate.providers.base import BaseCompletionHandler, StreamGenerator
@@ -52,6 +48,11 @@ async def execute_completion(
request["api_key"] = api_key
request["base_url"] = base_url
if is_fim_request:
# We need to force atext_completion if there is "prompt" in the request.
# The default function acompletion can only handle "messages" in the request.
if "prompt" in request:
logger.debug("Forcing atext_completion in FIM")
return await atext_completion(**request)
return await self._fim_completion_func(**request)
return await self._completion_func(**request)

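To make the new branch concrete, here is a sketch of the two request shapes involved. The model names and payloads are assumptions, and credentials are presumed to be configured elsewhere; the only grounded point is the one in the PR comment, namely that acompletion handles "messages" while FIM requests carrying a raw "prompt" need atext_completion.

```python
# Illustrative request shapes; model names and contents are assumptions.
from litellm import acompletion, atext_completion

chat_request = {
    "model": "openai/gpt-4o-mini",
    "messages": [{"role": "user", "content": "Hello"}],
}

fim_request = {
    "model": "openai/gpt-4o-mini",
    "prompt": "def add(a, b):",
}


async def demo():
    # Chat path: "messages" is present, so acompletion handles it.
    await acompletion(**chat_request)
    # FIM path: only "prompt" is present; acompletion cannot handle it,
    # so the shim forces atext_completion instead.
    await atext_completion(**fim_request)
```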
6 changes: 6 additions & 0 deletions src/codegate/providers/normalizer/completion.py
@@ -20,6 +20,12 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
# We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
data["had_prompt_before"] = True

# LiteLLM requires the stop list to contain at most 4 entries. Enforce that here.
stop_list = data.get("stop", [])
trimmed_stop_list = stop_list[:4]
data["stop"] = trimmed_stop_list

try:
normalized_data = ChatCompletionRequest(**data)
if normalized_data.get("stream", False):
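
Finally, a small before/after sketch of the normalization above, including the new stop-list trimming. All values are made up for illustration; only the transformation itself (prompt to messages, had_prompt_before, stop trimmed to 4) comes from the diff.

```python
# Hypothetical FIM payload before normalize(); values are illustrative.
data = {
    "model": "starcoder",
    "prompt": "def add(a, b):",
    "stop": ["<fim_prefix>", "<fim_suffix>", "<fim_middle>", "<file_sep>", "<|endoftext|>"],
}

# After normalize(), roughly:
# {
#     "model": "starcoder",
#     "messages": [{"content": "def add(a, b):", "role": "user"}],
#     "had_prompt_before": True,
#     "stop": ["<fim_prefix>", "<fim_suffix>", "<fim_middle>", "<file_sep>"],  # trimmed to 4
# }
```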