From 0c3f648c97f638a39d0d7ef6e15b32e76789d87d Mon Sep 17 00:00:00 2001 From: kwanUm Date: Tue, 8 Oct 2024 18:07:26 +0000 Subject: [PATCH] nits --- ldp/graph/modules/llm_call.py | 14 +++++++++++++- ldp/graph/modules/react.py | 4 ++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ldp/graph/modules/llm_call.py b/ldp/graph/modules/llm_call.py index 43a786fc..3c4c41db 100644 --- a/ldp/graph/modules/llm_call.py +++ b/ldp/graph/modules/llm_call.py @@ -17,6 +17,7 @@ def __init__( self, llm_model: dict[str, Any], parser: Callable[..., TParsedMessage] ): self.config_op = ConfigOp[dict](config=llm_model) + self.llm_model = llm_model self.llm_call_op = LLMCallOp() self.parse_msg_op = FxnOp(parser) @@ -24,7 +25,18 @@ def __init__( async def __call__( self, messages: Iterable[Message], *parse_args, **parse_kwargs ) -> tuple[OpResult[TParsedMessage], Message]: - raw_result = await self.llm_call_op(await self.config_op(), msgs=messages) + if "LocalLLMCallOp" in self.llm_call_op.__class__.__name__: + print(f"STARTING A CALL TO MODEL ======================= with {len(messages.value)} messages") + for i, message in enumerate(messages.value[1:], start=1): + print(f"message{i}: {message.content}") + print(f"END OF MESSAGES =======================") + raw_result = await self.llm_call_op( + xi=messages, + temperature=self.llm_model["temperature"], + max_new_tokens=self.llm_model["max_new_tokens"], + ) + else: + raw_result = await self.llm_call_op(await self.config_op(), msgs=messages) return await self.parse_msg_op( raw_result, *parse_args, **parse_kwargs ), raw_result.value diff --git a/ldp/graph/modules/react.py b/ldp/graph/modules/react.py index bd7debcc..9369ae02 100644 --- a/ldp/graph/modules/react.py +++ b/ldp/graph/modules/react.py @@ -82,6 +82,8 @@ def parse_message(m: Message, tools: list[Tool]) -> ToolRequestMessage: # noqa: message_content = message_content[: loc if loc > 0 else None] # we need to override the message too - don't want the model to hallucinate m.content = message_content + + print("Received back!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n" + m.content) action_args: tuple[Any, ...] = () # https://regex101.com/r/qmqZ7Z/1 @@ -141,6 +143,8 @@ def is_number(s: str) -> bool: action = re.search(r"Action:[ \t]*(\S*)", m.content) if not action: raise MalformedMessageError("Action not emitted.") + if "Observation:" in m.content: + raise MalformedMessageError("Observation found in message content and not expected.") tool_name = action.group(1).strip() # have to match up name to tool to line up args in order try: