Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit a148a16

Browse files
jhrozekblkt
authored andcommitted
Fixes to last_user_block (#1077)
* Improvements to get_last_user_message_block * Fix test_messages_block
1 parent 39efa62 commit a148a16

File tree

7 files changed

+148
-140
lines changed

7 files changed

+148
-140
lines changed

src/codegate/pipeline/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,12 @@ def get_last_user_message(
208208
@staticmethod
209209
def get_last_user_message_block(
210210
request: ChatCompletionRequest,
211-
client: ClientType = ClientType.GENERIC,
212211
) -> Optional[tuple[str, int]]:
213212
"""
214213
Get the last block of consecutive 'user' messages from the request.
215214
216215
Args:
217216
request (ChatCompletionRequest): The chat completion request to process
218-
client (ClientType): The client type to consider when processing the request
219217
220218
Returns:
221219
Optional[str, int]: A string containing all consecutive user messages in the

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def process( # noqa: C901
7575
Use RAG DB to add context to the user request
7676
"""
7777
# Get the latest user message
78-
last_message = self.get_last_user_message_block(request, context.client)
78+
last_message = self.get_last_user_message_block(request)
7979
if not last_message:
8080
return PipelineResult(request=request)
8181
user_message, last_user_idx = last_message

src/codegate/pipeline/secrets/secrets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ async def process(
295295
total_matches = []
296296

297297
# get last user message block to get index for the first relevant user message
298-
last_user_message = self.get_last_user_message_block(request, context.client)
298+
last_user_message = self.get_last_user_message_block(request)
299299
last_assistant_idx = last_user_message[1] - 1 if last_user_message else -1
300300

301301
# Process all messages

src/codegate/types/anthropic/_request_models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,6 @@ def last_user_block(self) -> Iterable[tuple[Message, int]]:
213213
for idx, msg in enumerate(reversed(self.messages)):
214214
if isinstance(msg, UserMessage):
215215
yield msg, len(self.messages) - 1 - idx
216-
break
217216

218217
def get_system_prompt(self) -> Iterable[str]:
219218
if isinstance(self.system, str):

src/codegate/types/ollama/_request_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ def last_user_block(self) -> Iterable[tuple[Message, int]]:
144144
for idx, msg in enumerate(reversed(self.messages)):
145145
if isinstance(msg, UserMessage):
146146
yield msg, len(self.messages) - 1 - idx
147+
elif isinstance(msg, (SystemMessage, ToolMessage)):
148+
# these can occur in the middle of a user block
149+
continue
150+
elif isinstance(msg, AssistantMessage):
151+
# these are LLM responses, end of user input, break on them
147152
break
148153

149154
def get_system_prompt(self) -> Iterable[str]:

src/codegate/types/openai/_request_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,11 @@ def last_user_block(self) -> Iterable[tuple[Message, int]]:
341341
for idx, msg in enumerate(reversed(self.messages)):
342342
if isinstance(msg, UserMessage):
343343
yield msg, len(self.messages) - 1 - idx
344+
elif isinstance(msg, (SystemMessage, DeveloperMessage, ToolMessage)):
345+
# these can occur in the middle of a user block
346+
continue
347+
elif isinstance(msg, (AssistantMessage, FunctionMessage)):
348+
# these are LLM responses, end of user input, break on them
344349
break
345350

346351
def get_system_prompt(self) -> Iterable[str]:

tests/pipeline/test_messages_block.py

Lines changed: 136 additions & 135 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)