
Commit 8897318

Fix FIM on Continue. Have specific formatters for chat and FIM (#1015)
* Fix FIM on Continue. Have specific formatters for chat and FIM

  Until now we had a general formatter on the way out of muxing. This was wrong, since the pipelines sometimes respond with a different format for chat and FIM; such is the case for Ollama. This PR separates the formatters and declares them explicitly so that they are easier to adjust in the future.

* Addressed comments from review
1 parent a7e14f6 commit 8897318
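
Before the diff, here is a minimal, self-contained sketch of the pattern the commit introduces: an abstract output formatter that exposes per-provider format functions, with separate chat and FIM subclasses chosen by the request type. The class names echo the diff below, but the provider keys and format functions are made up for illustration; this is not code from the repository.

    from abc import ABC, abstractmethod
    from typing import Callable, Dict


    class OutputFormatter(ABC):
        @property
        @abstractmethod
        def provider_format_funcs(self) -> Dict[str, Callable[[str], str]]:
            """Per-provider functions that turn a raw chunk into an OpenAI-style chunk."""


    class ChatFormatter(OutputFormatter):
        @property
        def provider_format_funcs(self) -> Dict[str, Callable[[str], str]]:
            # Chat streams go through the chat normalization path.
            return {"ollama": lambda chunk: f"chat-normalized:{chunk}"}


    class FimFormatter(OutputFormatter):
        @property
        def provider_format_funcs(self) -> Dict[str, Callable[[str], str]]:
            # FIM streams need their own path (e.g. Ollama answers FIM with a generate-style payload).
            return {"ollama": lambda chunk: f"fim-normalized:{chunk}"}


    def get_formatter(is_fim_request: bool) -> OutputFormatter:
        # One formatter per request type instead of a single general formatter.
        return FimFormatter() if is_fim_request else ChatFormatter()


    if __name__ == "__main__":
        formatter = get_formatter(is_fim_request=True)
        print(formatter.provider_format_funcs["ollama"]("raw-chunk"))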

File tree

5 files changed: +212 −73 lines


src/codegate/muxing/adapter.py (+136 −51)
@@ -1,13 +1,15 @@
 import copy
 import json
 import uuid
-from typing import Union
+from abc import ABC, abstractmethod
+from typing import Callable, Dict, Union
+from urllib.parse import urljoin

 import structlog
 from fastapi.responses import JSONResponse, StreamingResponse
 from litellm import ModelResponse
 from litellm.types.utils import Delta, StreamingChoices
-from ollama import ChatResponse
+from ollama import ChatResponse, GenerateResponse

 from codegate.db import models as db_models
 from codegate.muxing import rulematcher
@@ -34,7 +36,7 @@ def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str:
             db_models.ProviderType.openai,
             db_models.ProviderType.openrouter,
         ]:
-            return f"{model_route.endpoint.endpoint}/v1"
+            return urljoin(model_route.endpoint.endpoint, "/v1")
         return model_route.endpoint.endpoint

     def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
@@ -45,15 +47,101 @@ def set_destination_info(self, model_route: rulematcher.ModelRoute, data: dict)
         return new_data


-class StreamChunkFormatter:
+class OutputFormatter(ABC):
+
+    @property
+    @abstractmethod
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        pass
+
+    @abstractmethod
+    def format(
+        self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
+    ) -> Union[StreamingResponse, JSONResponse]:
+        """Format the response to the client."""
+        pass
+
+
+class StreamChunkFormatter(OutputFormatter):
     """
     Format a single chunk from a stream to OpenAI format.
     We need to configure the client to expect the OpenAI format.
     In Continue this means setting "provider": "openai" in the config json file.
     """

-    def __init__(self):
-        self.provider_to_func = {
+    @property
+    @abstractmethod
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        pass
+
+    def _format_openai(self, chunk: str) -> str:
+        """
+        The chunk is already in OpenAI format. To standarize remove the "data:" prefix.
+
+        This function is used by both chat and FIM formatters
+        """
+        cleaned_chunk = chunk.split("data:")[1].strip()
+        return cleaned_chunk
+
+    def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
+        """Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
+        chunk_to_send = f"data:{formatted_chunk}\n\n"
+        return chunk_to_send
+
+    async def _format_streaming_response(
+        self, response: StreamingResponse, dest_prov: db_models.ProviderType
+    ):
+        """Format the streaming response to OpenAI format."""
+        format_func = self.provider_format_funcs.get(dest_prov)
+        openai_chunk = None
+        try:
+            async for chunk in response.body_iterator:
+                openai_chunk = format_func(chunk)
+                # Sometimes for Anthropic we couldn't get content from the chunk. Skip it.
+                if not openai_chunk:
+                    continue
+                yield self._format_as_openai_chunk(openai_chunk)
+        except Exception as e:
+            logger.error(f"Error sending chunk in muxing: {e}")
+            yield self._format_as_openai_chunk(str(e))
+        finally:
+            # Make sure the last chunk is always [DONE]
+            if openai_chunk and "[DONE]" not in openai_chunk:
+                yield self._format_as_openai_chunk("[DONE]")
+
+    def format(
+        self, response: StreamingResponse, dest_prov: db_models.ProviderType
+    ) -> StreamingResponse:
+        """Format the response to the client."""
+        return StreamingResponse(
+            self._format_streaming_response(response, dest_prov),
+            status_code=response.status_code,
+            headers=response.headers,
+            background=response.background,
+            media_type=response.media_type,
+        )
+
+
+class ChatStreamChunkFormatter(StreamChunkFormatter):
+    """
+    Format a single chunk from a stream to OpenAI format given that the request was a chat.
+    """
+
+    @property
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        return {
             db_models.ProviderType.ollama: self._format_ollama,
             db_models.ProviderType.openai: self._format_openai,
             db_models.ProviderType.anthropic: self._format_antropic,
@@ -68,21 +156,11 @@ def _format_ollama(self, chunk: str) -> str:
         try:
             chunk_dict = json.loads(chunk)
             ollama_chunk = ChatResponse(**chunk_dict)
-            open_ai_chunk = OLlamaToModel.normalize_chunk(ollama_chunk)
+            open_ai_chunk = OLlamaToModel.normalize_chat_chunk(ollama_chunk)
             return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
         except Exception:
             return chunk

-    def _format_openai(self, chunk: str) -> str:
-        """The chunk is already in OpenAI format. To standarize remove the "data:" prefix."""
-        cleaned_chunk = chunk.split("data:")[1].strip()
-        try:
-            chunk_dict = json.loads(cleaned_chunk)
-            open_ai_chunk = ModelResponse(**chunk_dict)
-            return open_ai_chunk.model_dump_json(exclude_none=True, exclude_unset=True)
-        except Exception:
-            return cleaned_chunk
-
     def _format_antropic(self, chunk: str) -> str:
         """Format the Anthropic chunk to OpenAI format."""
         cleaned_chunk = chunk.split("data:")[1].strip()
@@ -119,46 +197,53 @@ def _format_antropic(self, chunk: str) -> str:
         except Exception:
             return cleaned_chunk.strip()

-    def format(self, chunk: str, dest_prov: db_models.ProviderType) -> ModelResponse:
-        """Format the chunk to OpenAI format."""
-        # Get the format function
-        format_func = self.provider_to_func.get(dest_prov)
-        if format_func is None:
-            raise MuxingAdapterError(f"Provider {dest_prov} not supported.")
-        return format_func(chunk)

+class FimStreamChunkFormatter(StreamChunkFormatter):

-class ResponseAdapter:
+    @property
+    def provider_format_funcs(self) -> Dict[str, Callable]:
+        """
+        Return the provider specific format functions. All providers format functions should
+        return the chunk in OpenAI format.
+        """
+        return {
+            db_models.ProviderType.ollama: self._format_ollama,
+            db_models.ProviderType.openai: self._format_openai,
+            # Our Lllamacpp provider emits OpenAI chunks
+            db_models.ProviderType.llamacpp: self._format_openai,
+            # OpenRouter is a dialect of OpenAI
+            db_models.ProviderType.openrouter: self._format_openai,
+        }
+
+    def _format_ollama(self, chunk: str) -> str:
+        """Format the Ollama chunk to OpenAI format."""
+        try:
+            chunk_dict = json.loads(chunk)
+            ollama_chunk = GenerateResponse(**chunk_dict)
+            open_ai_chunk = OLlamaToModel.normalize_fim_chunk(ollama_chunk)
+            return json.dumps(open_ai_chunk, separators=(",", ":"), indent=None)
+        except Exception:
+            return chunk

-    def __init__(self):
-        self.stream_formatter = StreamChunkFormatter()

-    def _format_as_openai_chunk(self, formatted_chunk: str) -> str:
-        """Format the chunk as OpenAI chunk. This is the format how the clients expect the data."""
-        return f"data:{formatted_chunk}\n\n"
+class ResponseAdapter:

-    async def _format_streaming_response(
-        self, response: StreamingResponse, dest_prov: db_models.ProviderType
-    ):
-        """Format the streaming response to OpenAI format."""
-        async for chunk in response.body_iterator:
-            openai_chunk = self.stream_formatter.format(chunk, dest_prov)
-            # Sometimes for Anthropic we couldn't get content from the chunk. Skip it.
-            if not openai_chunk:
-                continue
-            yield self._format_as_openai_chunk(openai_chunk)
+    def _get_formatter(
+        self, response: Union[StreamingResponse, JSONResponse], is_fim_request: bool
+    ) -> OutputFormatter:
+        """Get the formatter based on the request type."""
+        if isinstance(response, StreamingResponse):
+            if is_fim_request:
+                return FimStreamChunkFormatter()
+            return ChatStreamChunkFormatter()
+        raise MuxingAdapterError("Only streaming responses are supported.")

     def format_response_to_client(
-        self, response: Union[StreamingResponse, JSONResponse], dest_prov: db_models.ProviderType
+        self,
+        response: Union[StreamingResponse, JSONResponse],
+        dest_prov: db_models.ProviderType,
+        is_fim_request: bool,
     ) -> Union[StreamingResponse, JSONResponse]:
         """Format the response to the client."""
-        if isinstance(response, StreamingResponse):
-            return StreamingResponse(
-                self._format_streaming_response(response, dest_prov),
-                status_code=response.status_code,
-                headers=response.headers,
-                background=response.background,
-                media_type=response.media_type,
-            )
-        else:
-            raise MuxingAdapterError("Only streaming responses are supported.")
+        stream_formatter = self._get_formatter(response, is_fim_request)
+        return stream_formatter.format(response, dest_prov)
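
A quick note on the framing the formatter emits, since chat and FIM now share it: each normalized chunk is re-wrapped as a server-sent-events line with a "data:" prefix and a trailing blank line, and the stream always ends with a "[DONE]" marker. A small, self-contained sketch of that round-trip (the payload is made up; this is not repository code):

    def strip_data_prefix(chunk: str) -> str:
        # Mirrors _format_openai: drop the "data:" prefix so only the payload remains.
        return chunk.split("data:")[1].strip()


    def wrap_as_openai_chunk(payload: str) -> str:
        # Mirrors _format_as_openai_chunk: frame the payload the way SSE clients expect it.
        return f"data:{payload}\n\n"


    raw = 'data: {"choices": [{"delta": {"content": "hello"}}]}'
    print(wrap_as_openai_chunk(strip_data_prefix(raw)), end="")
    print(wrap_as_openai_chunk("[DONE]"), end="")  # terminate the stream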

src/codegate/muxing/router.py (+2 −1)
@@ -93,6 +93,7 @@ async def route_to_dest_provider(
             model=model_route.model.name,
             provider_type=model_route.endpoint.provider_type,
             provider_name=model_route.endpoint.name,
+            is_fim_request=is_fim_request,
         )

         # 2. Map the request body to the destination provider format.
@@ -108,5 +109,5 @@

         # 4. Transmit the response back to the client in OpenAI format.
         return self._response_adapter.format_response_to_client(
-            response, model_route.endpoint.provider_type
+            response, model_route.endpoint.provider_type, is_fim_request=is_fim_request
         )

src/codegate/providers/litellmshim/litellmshim.py (+6 −5)
@@ -3,11 +3,7 @@
 import litellm
 import structlog
 from fastapi.responses import JSONResponse, StreamingResponse
-from litellm import (
-    ChatCompletionRequest,
-    ModelResponse,
-    acompletion,
-)
+from litellm import ChatCompletionRequest, ModelResponse, acompletion, atext_completion

 from codegate.clients.clients import ClientType
 from codegate.providers.base import BaseCompletionHandler, StreamGenerator
@@ -52,6 +48,11 @@ async def execute_completion(
         request["api_key"] = api_key
         request["base_url"] = base_url
         if is_fim_request:
+            # We need to force atext_completion if there is "prompt" in the request.
+            # The default function acompletion can only handle "messages" in the request.
+            if "prompt" in request:
+                logger.debug("Forcing atext_completion in FIM")
+                return await atext_completion(**request)
             return await self._fim_completion_func(**request)
         return await self._completion_func(**request)
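
The litellmshim change leans on the fact that litellm exposes two async entry points: acompletion for "messages"-style chat payloads and atext_completion for completion-style payloads that carry a "prompt". A rough, self-contained sketch of that branching, assuming litellm is installed and the chosen model is actually reachable (the model name and request bodies here are made up for illustration; this is not repository code):

    import asyncio

    from litellm import acompletion, atext_completion


    async def execute(request: dict, is_fim_request: bool):
        # FIM requests built around a raw "prompt" must use the text-completion entry point;
        # acompletion only understands "messages".
        if is_fim_request and "prompt" in request:
            return await atext_completion(**request)
        return await acompletion(**request)


    async def main():
        fim_request = {"model": "ollama/qwen2.5-coder", "prompt": "def add(a, b):"}
        chat_request = {"model": "ollama/qwen2.5-coder", "messages": [{"role": "user", "content": "hi"}]}
        print(await execute(fim_request, is_fim_request=True))
        print(await execute(chat_request, is_fim_request=False))


    if __name__ == "__main__":
        asyncio.run(main())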

src/codegate/providers/normalizer/completion.py (+6 −0)
@@ -20,6 +20,12 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
             data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
         # We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
         data["had_prompt_before"] = True
+
+        # Litelllm says the we need to have max a list of length 4 in stop. Forcing it.
+        stop_list = data.get("stop", [])
+        trimmed_stop_list = stop_list[:4]
+        data["stop"] = trimmed_stop_list
+
         try:
             normalized_data = ChatCompletionRequest(**data)
             if normalized_data.get("stream", False):
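
The normalizer tweak is easy to miss: per the comment in the diff, litellm accepts at most four stop sequences, so any longer list is trimmed before building the ChatCompletionRequest. A tiny, self-contained illustration of the trimming (sample data only, not repository code):

    data = {"prompt": "def add(a, b):", "stop": ["\n\n", "def ", "class ", "#", "//"]}

    # Keep at most four stop sequences, as the normalizer now does with stop_list[:4].
    data["stop"] = data.get("stop", [])[:4]

    print(data["stop"])  # ['\n\n', 'def ', 'class ', '#']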
