
Commit 579229d

Minor fixes to openai generators.
1 parent ad45cf6 commit 579229d

File tree

2 files changed: +13, -3 lines


src/codegate/providers/openai/provider.py (+7, -1)
@@ -18,6 +18,9 @@
 )
 
 
+logger = structlog.get_logger("codegate")
+
+
 class OpenAIProvider(BaseProvider):
     def __init__(
         self,
@@ -76,7 +79,6 @@ async def process_request(
         except Exception as e:
             #  check if we have an status code there
             if hasattr(e, "status_code"):
-                logger = structlog.get_logger("codegate")
                 logger.error("Error in OpenAIProvider completion", error=str(e))
 
                 raise HTTPException(status_code=e.status_code, detail=str(e))  # type: ignore
@@ -108,6 +110,10 @@ async def create_completion(
         req = ChatCompletionRequest.model_validate_json(body)
         is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, req)
 
+        if not req.stream:
+            logger.warn("We got a non-streaming request, forcing to a streaming one")
+            req.stream = True
+
         return await self.process_request(
             req,
             api_key,
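
Taken together, the provider change hoists the logger to module scope and normalizes every incoming completion request to streaming mode before it is dispatched. Below is a minimal, self-contained sketch of that normalization; SimpleRequest and force_streaming are hypothetical stand-ins for the project's ChatCompletionRequest and create_completion, not code from the repository.

# Hypothetical sketch; assumes structlog is installed.
from dataclasses import dataclass

import structlog

logger = structlog.get_logger("codegate")


@dataclass
class SimpleRequest:
    model: str
    stream: bool = False


def force_streaming(req: SimpleRequest) -> SimpleRequest:
    # Same idea as the diff: warn, then flip the flag so the rest of the
    # pipeline only ever has to handle streamed responses.
    if not req.stream:
        logger.warning("non-streaming request received, forcing it to streaming")
        req.stream = True
    return req


req = force_streaming(SimpleRequest(model="gpt-4o-mini"))
assert req.stream is True

Keeping a single module-level logger also avoids re-creating it inside the exception handler, which is exactly what the line removed from process_request used to do.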

src/codegate/types/openai/_generators.py (+6, -2)
@@ -7,6 +7,7 @@
 import structlog
 
 from ._response_models import (
+    ChatCompletion,
     ErrorDetails,
     MessageError,
     StreamingChatCompletion,
@@ -62,7 +63,6 @@ async def streaming(request, api_key, url):
     payload = request.json(exclude_defaults=True)
     if os.getenv("CODEGATE_DEBUG_OPENAI") is not None:
         print(payload)
-        print(headers)
 
     client = httpx.AsyncClient()
     async with client.stream(
@@ -74,6 +74,11 @@ async def streaming(request, api_key, url):
         # TODO figure out how to best return failures
         match resp.status_code:
             case 200:
+                if not request.stream:
+                    body = await resp.aread()
+                    yield ChatCompletion.model_validate_json(body)
+                    return
+
                 async for message in message_wrapper(resp.aiter_lines()):
                     yield message
             case 400 | 401 | 403 | 404 | 413 | 429:
@@ -115,7 +120,6 @@ async def message_wrapper(lines):
             item = StreamingChatCompletion.model_validate_json(payload)
             yield item
         except Exception as e:
-            print(f"WAAAGH {payload}")
             logger.warn("HTTP error while consuming SSE stream", payload=payload, exc_info=e)
             err = MessageError(
                 error=ErrorDetails(
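
On the generator side, streaming() can now also serve a request that was originally non-streaming: for a 200 response it drains the upstream body once and validates it into a single ChatCompletion object instead of iterating SSE lines. A hedged sketch of that branch follows; the fetch function, FullResponse model, and parameters are illustrative stand-ins, not the project's own types.

# Hypothetical sketch; assumes httpx and pydantic (v2) are installed.
import httpx
from pydantic import BaseModel


class FullResponse(BaseModel):
    # Illustrative stand-in for the ChatCompletion response model.
    id: str
    model: str


async def fetch(url: str, payload: str, wants_stream: bool):
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, content=payload) as resp:
            if resp.status_code == 200:
                if not wants_stream:
                    # Same shape as the diff: read the whole body once and
                    # yield a single validated object instead of SSE chunks.
                    body = await resp.aread()
                    yield FullResponse.model_validate_json(body)
                    return
                async for line in resp.aiter_lines():
                    yield line

Yielding the single object from the same async generator keeps the call site uniform: consumers iterate the generator the same way whether the upstream answer was streamed or not.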
