Skip to content

Commit 92a7698

Browse files
openai fim fixed
1 parent bcd011d commit 92a7698

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

src/codegate/providers/litellmshim/litellmshim.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@
33
import litellm
44
import structlog
55
from fastapi.responses import JSONResponse, StreamingResponse
6-
from litellm import (
7-
ChatCompletionRequest,
8-
ModelResponse,
9-
acompletion,
10-
)
6+
from litellm import ChatCompletionRequest, ModelResponse, acompletion, atext_completion
117

128
from codegate.clients.clients import ClientType
139
from codegate.providers.base import BaseCompletionHandler, StreamGenerator
@@ -52,6 +48,11 @@ async def execute_completion(
5248
request["api_key"] = api_key
5349
request["base_url"] = base_url
5450
if is_fim_request:
51+
# We need to force atext_completion if there is "prompt" in the request.
52+
# The default function acompletion can only handle "messages" in the request.
53+
if "prompt" in request:
54+
logger.debug("Forcing atext_completion in FIM")
55+
return await atext_completion(**request)
5556
return await self._fim_completion_func(**request)
5657
return await self._completion_func(**request)
5758

src/codegate/providers/normalizer/completion.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,9 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
1919
if "prompt" in data:
2020
data["messages"] = [{"content": data.pop("prompt"), "role": "user"}]
2121
# We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
22+
data["had_prompt_before"] = True
2223

23-
# NOTE: Not adding the flag. This will also skip the denormalize step.
24-
# LiteLLM seems to not support anymore having "prompt" as key and only "message" is
25-
# supported
26-
# data["had_prompt_before"] = True
27-
28-
# Litelllm says the we need to have max a list of length 4 in stop.
24+
# LiteLLM says that the "stop" list can have at most 4 entries. Forcing it.
2925
stop_list = data.get("stop", [])
3026
trimmed_stop_list = stop_list[:4]
3127
data["stop"] = trimmed_stop_list

0 commit comments

Comments
 (0)