Skip to content

Commit b23effd

Browse files
authored
fix: only record db content if it is the last chunk in stream (#1091)
For Copilot chat, we are actually receiving multiple streams, and we were recording entries and alerts for each one, producing duplicates. So detect whether we are in the last chunk and propagate that to the pipeline, so that we record only in this case, avoiding dupes. For other providers we only receive one request, so we always force saving it. Closes: #936
1 parent 8d738a4 commit b23effd

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

src/codegate/db/connection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
129129
active_workspace = await DbReader().get_active_workspace()
130130
workspace_id = active_workspace.id if active_workspace else "1"
131131
prompt_params.workspace_id = workspace_id
132+
132133
sql = text(
133134
"""
134135
INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id)
@@ -302,7 +303,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
302303
await self.record_outputs(context.output_responses, initial_id)
303304
await self.record_alerts(context.alerts_raised, initial_id)
304305
logger.info(
305-
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
306+
f"Updated context in DB. Output chunks: {len(context.output_responses)}. "
306307
f"Alerts: {len(context.alerts_raised)}."
307308
)
308309
except Exception as e:

src/codegate/pipeline/output.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ def _record_to_db(self) -> None:
127127
loop.create_task(self._db_recorder.record_context(self._input_context))
128128

129129
async def process_stream(
130-
self, stream: AsyncIterator[ModelResponse], cleanup_sensitive: bool = True
130+
self,
131+
stream: AsyncIterator[ModelResponse],
132+
cleanup_sensitive: bool = True,
133+
finish_stream: bool = True,
131134
) -> AsyncIterator[ModelResponse]:
132135
"""
133136
Process a stream through all pipeline steps
@@ -167,7 +170,7 @@ async def process_stream(
167170
finally:
168171
# NOTE: Don't use await in finally block, it will break the stream
169172
# Don't flush the buffer if we assume we'll call the pipeline again
170-
if cleanup_sensitive is False:
173+
if cleanup_sensitive is False and finish_stream:
171174
self._record_to_db()
172175
return
173176

@@ -194,7 +197,8 @@ async def process_stream(
194197
yield chunk
195198
self._context.buffer.clear()
196199

197-
self._record_to_db()
200+
if finish_stream:
201+
self._record_to_db()
198202
# Cleanup sensitive data through the input context
199203
if cleanup_sensitive and self._input_context and self._input_context.sensitive:
200204
self._input_context.sensitive.secure_cleanup()

src/codegate/providers/copilot/provider.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -905,8 +905,16 @@ async def stream_iterator():
905905
)
906906
yield mr
907907

908+
# needs to be set as the flag gets reset on finish_data
909+
finish_stream_flag = any(
910+
choice.get("finish_reason") == "stop"
911+
for record in list(self.stream_queue._queue)
912+
for choice in record.get("content", {}).get("choices", [])
913+
)
908914
async for record in self.output_pipeline_instance.process_stream(
909-
stream_iterator(), cleanup_sensitive=False
915+
stream_iterator(),
916+
cleanup_sensitive=False,
917+
finish_stream=finish_stream_flag,
910918
):
911919
chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
912920
sse_data = f"data: {chunk}\n\n".encode("utf-8")

0 commit comments

Comments (0)