Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
409 changes: 409 additions & 0 deletions lagent/actions/mcp_client.py

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions lagent/actions/web_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import asyncio
import json
import re
import traceback
import warnings
from typing import Any, List

from transformers import AutoTokenizer

from lagent.actions import AsyncActionMixin, BaseAction
from lagent.schema import ActionStatusCode, ActionValidCode, AgentMessage
from lagent.utils import create_object


def extract_last_json(text: str) -> dict | None:
"""
Extracts the last valid JSON object from a string.
Handles Markdown code blocks (```json ... ```) and raw JSON strings.
"""
try:
# 1. Try to find JSON within Markdown code blocks first
# Look for ```json ... ``` or just ``` ... ```
code_block_pattern = re.compile(r'```(?:json)?\s*(\{.*?\})\s*```', re.DOTALL)
matches = code_block_pattern.findall(text)
if matches:
return json.loads(matches[-1])

# 2. If no code blocks, try to find the last outermost pair of braces
# This regex looks for { ... } lazily but we want the last one.
# A simple approach for nested JSON is tricky with regex,
# so we scan from right to left for the last '}' and find its matching '{'.

stack, end_idx = 0, -1
# Reverse search to find the last valid JSON structure
for i in range(len(text) - 1, -1, -1):
char = text[i]
if char == '}':
if stack == 0:
end_idx = i
stack += 1
elif char == '{':
if stack > 0:
stack -= 1
if stack == 0 and end_idx != -1:
# Found a potential outermost JSON object
candidate = text[i : end_idx + 1]
try:
return json.loads(candidate)
except json.JSONDecodeError:
# If this chunk isn't valid, reset and keep searching backwards
# (or you might decide to stop here depending on strictness)
stack, end_idx = 0, -1
return None
except Exception:
return None


class WebVisitor(AsyncActionMixin, BaseAction):
    """Visit webpage(s) and summarize their content with an LLM.

    For each URL the action fetches the page through ``browse_tool`` (with
    retries), optionally truncates the raw text, then asks ``llm`` to extract
    goal-relevant information as a JSON object with 'rational', 'evidence'
    and 'summary' keys. Per-URL failures are converted to inline error
    strings, so one bad URL does not abort a batch.
    """

    # Prompt template for the extraction call. NOTE(review): the spellings
    # 'rational' (not 'rationale') and 'feilds' are left as-is deliberately —
    # `_read_webpage` validates the LLM reply against the literal key
    # 'rational', so the prompt and the key check must stay in sync.
    EXTRACTION_PROMPT = """Please process the following webpage content and user goal to extract relevant information:

## **Webpage Content**
{webpage_content}

## **User Goal**
{goal}

## **Task Guidelines**
1. **Content Scanning for Rationale**: Locate the **specific sections/data** directly related to the user's goal within the webpage content
2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content, you never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.
3. **Summary Output for Summary**: Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal.

**Final Output Format using JSON format has "rational", "evidence", "summary" feilds**
"""

    def __init__(
        self,
        browse_tool: BaseAction | dict,
        llm: Any,
        max_browse_attempts: int = 3,
        max_extract_attempts: int = 3,
        sleep_interval: int = 3,
        truncate_browse_response_length: int | None = None,
        tokenizer_path: str | None = None,
        name: str = 'visit',
    ):
        """Initialize the visitor.

        Args:
            browse_tool (BaseAction | dict): A single-tool action (or its
                config dict) whose only required argument is 'url'.
            llm (Any): Chat model instance or config used for extraction.
            max_browse_attempts (int): Retries when fetching a page.
            max_extract_attempts (int): Retries for the LLM extraction step.
            sleep_interval (int): Seconds slept between retries.
            truncate_browse_response_length (int | None): If set, cap the raw
                page content to this many tokens (or characters when no
                tokenizer is configured).
            tokenizer_path (str | None): HuggingFace tokenizer path used for
                token-accurate truncation.
            name (str): Tool name exposed to the agent. Default 'visit'.
        """
        super().__init__(
            description={
                'name': name,
                'description': 'Visit webpage(s) and return the summary of the content.',
                'parameters': [
                    {
                        'name': 'url',
                        'type': ['STRING', 'ARRAY'],
                        "items": {"type": "string"},
                        "minItems": 1,
                        'description': 'The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs.',
                    },
                    {'name': 'goal', 'type': 'STRING', 'description': 'The goal of the visit for webpage(s).'},
                ],
                'required': ['url', 'goal'],
            }
        )
        # Materialize the tool from a config dict if one was given.
        browse_tool = create_object(browse_tool)
        # The tool is invoked below as `browse_tool({'url': url})`, so it must
        # be a single tool taking only 'url' as a required argument.
        assert not browse_tool.is_toolkit and browse_tool.description['required'] == [
            'url'
        ], "browse_tool must be a single-tool action with only 'url' as required argument."
        self.browse_tool = browse_tool
        self.llm = create_object(llm)
        self.max_browse_attempts = max_browse_attempts
        self.max_extract_attempts = max_extract_attempts
        self.sleep_interval = sleep_interval
        self.truncate_browse_response_length = truncate_browse_response_length
        # Tokenizer is optional; without it truncation degrades to a plain
        # character slice (warned about below).
        self.tokenizer = (
            AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) if tokenizer_path else None
        )
        if self.truncate_browse_response_length is not None and self.tokenizer is None:
            warnings.warn(
                'truncate_browse_response_length is set but tokenizer_path is not provided. '
                'The raw webpage content will be truncated by characters instead of tokens.'
            )

    async def run(self, url: str | List[str], goal: str) -> str:
        """Visit one or many URLs concurrently and join their summaries.

        Args:
            url (str | List[str]): A single URL or a list of URLs.
            goal (str): What information to look for on the page(s).

        Returns:
            str: Per-URL summaries joined by a '=======' separator; a failed
                URL contributes an inline error message instead of raising.
        """
        if isinstance(url, str):
            url = [url]

        async def _inner_call(single_url: str) -> str:
            # Shield the batch: any exception becomes an inline error string
            # so asyncio.gather never propagates a single URL's failure.
            try:
                return await self._read_webpage(single_url, goal)
            except Exception as e:
                return f"Error fetching {single_url}: {str(e)}"

        response = await asyncio.gather(*[_inner_call(single_url) for single_url in url])
        return "\n=======\n".join(response).strip()

    async def _read_webpage(self, url: str, goal: str) -> str:
        """Fetch a single page and distill goal-relevant evidence/summary.

        Always returns a filled ``return_template`` string: fixed fallback
        text is substituted when browsing or extraction fails.
        """
        tool_response = compressed = None
        return_template = (
            f"The useful information in {url} for user goal {goal} as follows: \n\n"
            f"Evidence in page: \n{{evidence}}\n\nSummary: \n{{summary}}\n\n"
        )
        # Fetch with retries; the for/else branch runs only when every
        # attempt failed (no break).
        for _ in range(self.max_browse_attempts):
            resp = await self.browse_tool({'url': url})
            if resp.valid == ActionValidCode.OPEN and resp.state == ActionStatusCode.SUCCESS:
                tool_response = resp.format_result()
                break
            await asyncio.sleep(self.sleep_interval)
        else:
            return return_template.format(
                evidence="The provided webpage content could not be accessed. Please check the URL or file format.",
                summary="The webpage content could not be processed, and therefore, no information is available.",
            )

        # Cap the raw content: token-accurate when a tokenizer is configured,
        # otherwise a plain character slice.
        if self.truncate_browse_response_length is not None:
            tool_response = (
                self.tokenizer.decode(
                    self.tokenizer.encode(
                        tool_response,
                        max_length=self.truncate_browse_response_length,
                        truncation=True,
                        add_special_tokens=False,
                    )
                )
                if self.tokenizer is not None
                else tool_response[: self.truncate_browse_response_length]
            )

        # Extraction with retries; a (near-)empty LLM reply is treated as a
        # context-length problem, so the content is shrunk by 30% and retried.
        for _ in range(self.max_extract_attempts):
            try:
                prompt = self.EXTRACTION_PROMPT.format(webpage_content=tool_response, goal=goal)
                llm_response = await self.llm.chat([{'role': 'user', 'content': prompt}])
                # Normalize non-str replies: AgentMessage carries .content,
                # otherwise assume an OpenAI-style response object.
                if llm_response and not isinstance(llm_response, str):
                    llm_response = (
                        llm_response.content
                        if isinstance(llm_response, AgentMessage)
                        else llm_response.choices[0].message.content
                    )
                if not llm_response or len(llm_response) < 10:
                    tool_response = tool_response[: int(len(tool_response) * 0.7)]
                    continue
                compressed = extract_last_json(llm_response)
                # Required keys mirror the EXTRACTION_PROMPT's output spec
                # (including the 'rational' spelling).
                if isinstance(compressed, dict) and all(
                    key in compressed for key in ['rational', 'evidence', 'summary']
                ):
                    break
            except Exception:
                print(f"Error in extracting information: {traceback.format_exc()}")
            await asyncio.sleep(self.sleep_interval)
        else:
            return return_template.format(
                evidence="Failed to extract relevant information from the webpage content.",
                summary="The webpage content could not be processed, and therefore, no information is available.",
            )
        return return_template.format(evidence=compressed['evidence'], summary=compressed['summary'])
32 changes: 30 additions & 2 deletions lagent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessa
self.update_memory(message, session_id=session_id)
response_message = self.forward(*message, session_id=session_id, **kwargs)
if not isinstance(response_message, AgentMessage):
response_message = AgentMessage(sender=self.name, content=response_message)
if isinstance(response_message, str):
response_message = AgentMessage(sender=self.name, content=response_message)
else:
response_message = AgentMessage.from_model_response(response_message, self.name)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
Expand Down Expand Up @@ -158,6 +161,28 @@ def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = F
for agent in getattr(self, '_agents', {}).values():
agent.reset(session_id, recursive=True)

    def get_messages(self, session_id=0, keypath: Optional[str] = None) -> List[dict]:
        """Get OpenAI-format messages from memory.

        Args:
            session_id (int): The session id of the memory. Default is 0.
            keypath (Optional[str]): Dot-separated path of nested sub-agent
                names (e.g. ``'planner.searcher'``) to fetch messages from.
                ``None`` (default) reads this agent's own memory.

        Returns:
            List[dict]: The messages from the memory including the
                (sub-)agent's system prompt.

        Raises:
            KeyError: If a name in ``keypath`` is not a registered sub-agent.
            ValueError: If the target agent has no aggregator.
        """
        if keypath:
            # Walk the dot-separated path down the nested sub-agent tree and
            # delegate to the target agent.
            keys, agent = keypath.split('.'), self
            for key in keys:
                agents = getattr(agent, '_agents', {})
                if key not in agents:
                    raise KeyError(f'No sub-agent named {key} in {agent}')
                agent = agents[key]
            return agent.get_messages(session_id=session_id)
        if self.aggregator:
            # The aggregator renders raw memory into OpenAI-format dicts,
            # prepending this agent's system instruction (self.template).
            return self.aggregator.aggregate(self.memory.get(session_id), self.name, self.output_format, self.template)
        raise ValueError(f'{self.name} has no aggregator to get messages')

def __repr__(self):

def _rcsv_repr(agent, n_indent=1):
Expand Down Expand Up @@ -186,7 +211,10 @@ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Agen
self.update_memory(message, session_id=session_id)
response_message = await self.forward(*message, session_id=session_id, **kwargs)
if not isinstance(response_message, AgentMessage):
response_message = AgentMessage(sender=self.name, content=response_message)
if isinstance(response_message, str):
response_message = AgentMessage(sender=self.name, content=response_message)
else:
response_message = AgentMessage.from_model_response(response_message, self.name)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
Expand Down
42 changes: 26 additions & 16 deletions lagent/agents/aggregator/default_aggregator.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
from typing import Dict, List
from typing import List

from lagent.memory import Memory
from lagent.prompts import StrParser
from lagent.schema import ActionReturn


class DefaultAggregator:

def aggregate(self,
messages: Memory,
name: str,
parser: StrParser = None,
system_instruction: str = None) -> List[Dict[str, str]]:
def aggregate(self, messages: Memory, name: str, parser: StrParser = None, system_instruction=None) -> List[dict]:
_message = []
messages = messages.get_memory()
if system_instruction:
_message.extend(
self.aggregate_system_intruction(system_instruction))
_message.extend(self.aggregate_system_intruction(system_instruction))
for message in messages:
if message.sender == name:
_message.append(
dict(role='assistant', content=str(message.content)))
_message.append(message.to_model_request())
else:
user_message = message.content
if len(_message) > 0 and _message[-1]['role'] == 'user':
_message[-1]['content'] += user_message
user_message, extra_info = message.content, message.extra_info
if isinstance(user_message, list):
for m in user_message:
if isinstance(m, dict):
m = ActionReturn(**m)
assert isinstance(m, ActionReturn), f"Expected m to be ActionReturn, but got {type(m)}"
_message.append(
dict(
role='tool',
tool_call_id=m.tool_call_id,
content=m.format_result(),
name=m.type,
extra_info=extra_info,
)
)
else:
_message.append(dict(role='user', content=user_message))
if len(_message) > 0 and _message[-1]['role'] == 'user':
_message[-1]['content'] += user_message
_message[-1]['extra_info'] = extra_info
else:
_message.append(dict(role='user', content=user_message, extra_info=extra_info))
return _message

@staticmethod
Expand All @@ -39,6 +50,5 @@ def aggregate_system_intruction(system_intruction) -> List[dict]:
if not isinstance(msg, dict):
raise TypeError(f'Unsupported message type: {type(msg)}')
if not ('role' in msg and 'content' in msg):
raise KeyError(
f"Missing required key 'role' or 'content': {msg}")
raise KeyError(f"Missing required key 'role' or 'content': {msg}")
return system_intruction
Loading