diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py
new file mode 100644
index 0000000..dfd1e56
--- /dev/null
+++ b/lagent/actions/mcp_client.py
@@ -0,0 +1,409 @@
+import asyncio
+import logging
+import random
+import threading
+import time
+from collections import deque
+from contextlib import AsyncExitStack, nullcontext
+from typing import Deque, Literal, Optional, TypeAlias
+
+from lagent.actions.base_action import AsyncActionMixin, BaseAction
+from lagent.actions.parser import JsonParser, ParseError
+from lagent.schema import ActionReturn, ActionStatusCode
+
+ServerType: TypeAlias = Literal["stdio", "sse", "http"]
+
+logger = logging.getLogger(__name__)
+_loop = None
+
+
+def _get_event_loop():
+ try:
+ event_loop = asyncio.get_event_loop()
+ except Exception:
+        logger.warning('Could not find an event loop in the current thread; creating a new one.')
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+
+ if event_loop.is_running():
+ global _loop
+ if _loop:
+ return _loop
+
+ from threading import Thread
+
+ def _start_loop(loop):
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+ event_loop = asyncio.new_event_loop()
+ Thread(target=_start_loop, args=(event_loop,), daemon=True).start()
+ _loop = event_loop
+ return event_loop
+
+
+class TokenBucket:
+ def __init__(self, rate_limit: float):
+ self.rate_limit = rate_limit # tokens per second
+ self.tokens = rate_limit
+ self.last_update = time.time()
+ self.lock = threading.Lock()
+
+ def acquire(self) -> bool:
+ with self.lock:
+ now = time.time()
+ # Add new tokens based on time elapsed
+ new_tokens = (now - self.last_update) * self.rate_limit
+ self.tokens = min(self.rate_limit, self.tokens + new_tokens)
+ self.last_update = now
+
+ if self.tokens >= 1:
+ self.tokens -= 1
+ return True
+ return False
+
+
+class AsyncTokenBucket:
+ def __init__(self, rate_limit: float):
+ self.rate_limit = rate_limit
+ self.capacity = rate_limit
+ self.tokens = rate_limit
+ self.last_update = time.monotonic()
+ self._lock = asyncio.Lock()
+
+ def _refill(self):
+ now = time.monotonic()
+ elapsed = now - self.last_update
+ if elapsed <= 0:
+ return
+ self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_limit)
+ self.last_update = now
+
+ async def acquire(self):
+ while True:
+ async with self._lock:
+ self._refill()
+ if self.tokens >= 1:
+ self.tokens -= 1
+ return
+ missing = 1 - self.tokens
+ wait_time = missing / self.rate_limit
+ await asyncio.sleep(wait_time)
+
+
+class FairAsyncTokenBucket:
+ def __init__(self, rate_limit: float, capacity: Optional[float] = None):
+ """
+ rate_limit: 每秒生成多少个 token
+ capacity: 桶容量(最大可累积多少 token),默认和 rate_limit 一样
+ """
+ self.rate_limit = float(rate_limit)
+ self.capacity = float(capacity) if capacity is not None else float(rate_limit)
+
+ self.tokens = self.capacity
+ self.last_update = time.monotonic()
+
+ self._lock = asyncio.Lock()
+ self._waiters: Deque[asyncio.Future] = deque()
+        self._drainer_running = False  # whether a background drainer coroutine is already handing out tokens
+
+    # ---------- Internal helpers ----------
+
+ def _refill_unlocked(self) -> None:
+ """
+ 在不持锁的前提下不要调用。
+ 根据时间流逝计算当前 token 数。
+ """
+ now = time.monotonic()
+ elapsed = now - self.last_update
+ if elapsed <= 0:
+ return
+ self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_limit)
+ self.last_update = now
+
+ async def _drain_waiters(self) -> None:
+ """
+ 后台协程:按 FIFO 顺序给排队的协程发 token。
+ - 没 token 时,就 sleep 到下一个 token 产生的时间点。
+ - 有 token 且有排队,就唤醒队头的一个,再继续循环。
+ """
+ try:
+ while True:
+ fut_to_wake: Optional[asyncio.Future] = None
+ sleep_time: Optional[float] = None
+
+                async with self._lock:
+                    self._refill_unlocked()
+
+                    # Queue is empty: nothing left to do, so exit the drainer.
+                    if not self._waiters:
+                        self._drainer_running = False
+                        return
+
+                    if self.tokens >= 1:
+                        # A token is available: wake one queued waiter (FIFO).
+                        self.tokens -= 1
+                        fut_to_wake = self._waiters.popleft()
+                        sleep_time = 0.0
+                    else:
+                        # No token yet: compute the wait until the next token is produced.
+                        missing = 1.0 - self.tokens  # tokens still needed before the next one can be issued
+                        sleep_time = max(0.0, missing / self.rate_limit)
+
+                # Wake the waiter after releasing the lock, to avoid running user code/callbacks while holding it.
+                if fut_to_wake is not None and not fut_to_wake.done():
+                    fut_to_wake.set_result(None)
+
+                # If we just woke a waiter, loop again immediately to see whether more tokens can be handed out.
+                if sleep_time == 0.0:
+                    continue
+
+                # No token available: wait until one is produced.
+                await asyncio.sleep(sleep_time)
+        finally:
+            # Safety net: ensure _drainer_running does not stay True after an exception, which would prevent a restart.
+            async with self._lock:
+                self._drainer_running = False
+
+    # ---------- Public API ----------
+
+ async def acquire(self) -> None:
+ """
+ 获取一个 token(公平:排队 FIFO)
+ """
+ loop = asyncio.get_running_loop()
+
+ # 先尝试直接拿 token(快速路径)
+ async with self._lock:
+ self._refill_unlocked()
+
+ # 如果有 token 且没有历史排队的协程,直接拿走返回
+ if self.tokens >= 1 and not self._waiters:
+ self.tokens -= 1
+ return
+
+ # 否则需要排队
+ fut = loop.create_future()
+ self._waiters.append(fut)
+
+ # 启动 drainer(只要一个就够了)
+ if not self._drainer_running:
+ self._drainer_running = True
+ asyncio.create_task(self._drain_waiters())
+
+ # 等待被 drainer 唤醒,唤醒后说明自己拿到了 token
+ await fut
+
+
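+# A usage sketch (illustrative, not part of the patch's API surface): gate a
+# burst of coroutines at ~2 acquisitions per second, served strictly FIFO.
+#
+#   bucket = FairAsyncTokenBucket(rate_limit=2.0)
+#
+#   async def call_tool(i: int) -> None:
+#       await bucket.acquire()   # earlier waiters always wake first
+#       ...                      # issue the rate-limited request here
+#
+#   # await asyncio.gather(*(call_tool(i) for i in range(10)))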
+
+
+class AsyncMCPClient(AsyncActionMixin, BaseAction):
+ """
+ Standard Lagent Action that wraps a SINGLE tool from an MCP Server.
+
+ Robustness Fix:
+ Creates a new connection for every request and closes it immediately after.
+ This prevents connection leaks and 'ConnectTimeout' in high-concurrency RL environments.
+ """
+
+ is_stateful = False
+
+ def __init__(
+ self,
+ server_type: ServerType,
+        rate_limit: Optional[float] = None,
+        max_concurrency: Optional[int] = None,
+        # Note: `name` is mainly used for Lagent registration; the tool's actual metadata comes from the MCP server.
+ name: Optional[str] = None,
+ **server_params,
+ ):
+ self._is_toolkit = False
+ self.server_type = server_type
+ self.server_params = server_params
+
+        # Concurrency-control components
+ self.rate_limiter = FairAsyncTokenBucket(rate_limit) if rate_limit is not None else None
+ self._sem = asyncio.Semaphore(max_concurrency) if max_concurrency is not None else nullcontext()
+
+        # 1. Fetch the tool metadata over a temporary connection.
+        #    This must happen in __init__ because Lagent needs self.description.
+ loop = _get_event_loop()
+ if loop.is_running():
+ fut = asyncio.run_coroutine_threadsafe(self._fetch_tool_metadata(), loop)
+ tools = fut.result()
+ else:
+ tools = loop.run_until_complete(self._fetch_tool_metadata())
+
+        # Single-action constraint: one Action instance wraps exactly one MCP tool.
+        if len(tools) != 1:
+            logger.warning(
+                f"MCP Server returned {len(tools)} tools, but AsyncMCPClient is designed to wrap a single tool. "
+                f"Using the first one: {tools[0].name}"
+            )
+
+ self.tool_info = tools[0]
+ tool_name = self.tool_info.name
+
+        # 2. Initialize the BaseAction parent class.
+ super().__init__(
+ description={
+ 'name': tool_name,
+ 'description': self.tool_info.description,
+ 'parameters': [
+ {'name': k, 'type': v['type'].upper(), 'description': v.get('description', '')}
+ for k, v in self.tool_info.inputSchema['properties'].items()
+ ],
+ 'required': self.tool_info.inputSchema.get('required', []),
+ },
+ parser=JsonParser,
+ )
+ self._is_toolkit = False
+
+ async def _connect(self, stack: AsyncExitStack):
+ """
+ 内部辅助:建立连接并注册关闭回调。
+ 所有网络资源都注册到 `stack` 中,确保自动释放。
+ """
+ from mcp import ClientSession, StdioServerParameters
+
+ # --- Transport Layer ---
+ if self.server_type == "stdio":
+ from mcp.client.stdio import stdio_client
+
+ logger.info(
+ f"Connecting to stdio MCP server with command: {self.server_params['command']} "
+ f"{self.server_params.get('args', [])}"
+ )
+ client_kwargs = {"command": self.server_params["command"]}
+ for key in ["args", "env", "cwd"]:
+ if self.server_params.get(key) is not None:
+ client_kwargs[key] = self.server_params[key]
+
+ server_params_obj = StdioServerParameters(**client_kwargs)
+ read, write = await stack.enter_async_context(stdio_client(server_params_obj))
+
+ elif self.server_type == "sse":
+ from mcp.client.sse import sse_client
+
+ logger.info(f"Connecting to SSE MCP server at: {self.server_params['url']}")
+
+ url = self.server_params["url"]
+ target_url = random.choice(url) if isinstance(url, list) else url
+
+ client_kwargs = {"url": target_url}
+ for key in ["headers", "timeout", "sse_read_timeout"]:
+ if self.server_params.get(key) is not None:
+ client_kwargs[key] = self.server_params[key]
+
+ read, write = await stack.enter_async_context(sse_client(**client_kwargs))
+
+ elif self.server_type == "http":
+ from mcp.client.streamable_http import streamablehttp_client
+
+ logger.info(f"Connecting to StreamableHTTP MCP server at: {self.server_params['url']}")
+
+ url = self.server_params["url"]
+ target_url = random.choice(url) if isinstance(url, list) else url
+
+ client_kwargs = {"url": target_url}
+ for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]:
+ if self.server_params.get(key) is not None:
+ client_kwargs[key] = self.server_params[key]
+
+ read, write, _ = await stack.enter_async_context(streamablehttp_client(**client_kwargs))
+
+ else:
+ raise ValueError(f"Unsupported server type: {self.server_type}")
+
+ # --- Protocol Layer ---
+ session = await stack.enter_async_context(ClientSession(read, write))
+ await session.initialize()
+ return session
+
+ async def _fetch_tool_metadata(self):
+ """在 init 阶段使用一次性连接获取工具定义"""
+ async with AsyncExitStack() as stack:
+ session = await self._connect(stack)
+ result = await session.list_tools()
+ return result.tools
+
+ async def run(self, **kwargs) -> ActionReturn:
+ """
+ Standard Lagent Action Entrypoint.
+ """
+ fallback_args = kwargs.copy()
+
+ try:
+            # 1. Concurrency / rate control.
+            async with self._sem:
+                if self.rate_limiter is not None:
+                    await self.rate_limiter.acquire()
+
+                # 2. Execution logic (critical resource scope).
+                #    AsyncExitStack guarantees the HTTP connection / process pipes are fully closed once this request ends.
+                async with AsyncExitStack() as stack:
+                    session = await self._connect(stack)
+
+                    # Call the MCP tool.
+                    # Note: Lagent passes a kwargs dict, which is exactly what MCP's call_tool accepts.
+                    outputs_obj = await session.call_tool(self.tool_info.name, kwargs)
+
+                    # Extract the text result.
+ if outputs_obj.content and hasattr(outputs_obj.content[0], 'text'):
+ outputs = outputs_obj.content[0].text
+ else:
+ outputs = str(outputs_obj)
+
+ except ParseError as exc:
+ return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR)
+ except Exception as exc:
+            # Log the failure to aid debugging of errors during RL rollouts.
+ logger.warning(f"MCP Action {self.name} failed: {exc}")
+ return ActionReturn(fallback_args, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR)
+
+        # 3. Wrap the result.
+ if isinstance(outputs, ActionReturn):
+ action_return = outputs
+ if not action_return.args:
+ action_return.args = kwargs
+ if not action_return.type:
+ action_return.type = self.name
+ else:
+            # Try to parse the output with JsonParser (in case the MCP server returned a JSON string);
+            # otherwise fall back to the plain string.
+            try:
+                result = self._parser.parse_outputs(outputs)
+            except Exception:
+                result = str(outputs)
+
+ action_return = ActionReturn(fallback_args, type=self.name, result=result)
+
+ return action_return
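+
+
+# A construction sketch (endpoint and limits are illustrative placeholders):
+# wrap the single tool exposed by an SSE MCP server, capped at 4 in-flight
+# calls and ~2 requests/second.
+#
+#   action = AsyncMCPClient(
+#       server_type="sse",
+#       url="http://localhost:8000/sse",
+#       rate_limit=2.0,
+#       max_concurrency=4,
+#   )
+#   # result = await action.run(query="...")   # kwargs are forwarded to the MCP tool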
diff --git a/lagent/actions/web_visitor.py b/lagent/actions/web_visitor.py
new file mode 100644
index 0000000..ff8f1ea
--- /dev/null
+++ b/lagent/actions/web_visitor.py
@@ -0,0 +1,194 @@
+import asyncio
+import json
+import re
+import traceback
+import warnings
+from typing import Any, List
+
+from transformers import AutoTokenizer
+
+from lagent.actions import AsyncActionMixin, BaseAction
+from lagent.schema import ActionStatusCode, ActionValidCode, AgentMessage
+from lagent.utils import create_object
+
+
+def extract_last_json(text: str) -> dict | None:
+ """
+ Extracts the last valid JSON object from a string.
+ Handles Markdown code blocks (```json ... ```) and raw JSON strings.
+ """
+ try:
+ # 1. Try to find JSON within Markdown code blocks first
+ # Look for ```json ... ``` or just ``` ... ```
+ code_block_pattern = re.compile(r'```(?:json)?\s*(\{.*?\})\s*```', re.DOTALL)
+ matches = code_block_pattern.findall(text)
+ if matches:
+ return json.loads(matches[-1])
+
+ # 2. If no code blocks, try to find the last outermost pair of braces
+ # This regex looks for { ... } lazily but we want the last one.
+ # A simple approach for nested JSON is tricky with regex,
+ # so we scan from right to left for the last '}' and find its matching '{'.
+
+ stack, end_idx = 0, -1
+ # Reverse search to find the last valid JSON structure
+ for i in range(len(text) - 1, -1, -1):
+ char = text[i]
+ if char == '}':
+ if stack == 0:
+ end_idx = i
+ stack += 1
+ elif char == '{':
+ if stack > 0:
+ stack -= 1
+ if stack == 0 and end_idx != -1:
+ # Found a potential outermost JSON object
+ candidate = text[i : end_idx + 1]
+ try:
+ return json.loads(candidate)
+ except json.JSONDecodeError:
+ # If this chunk isn't valid, reset and keep searching backwards
+ # (or you might decide to stop here depending on strictness)
+ stack, end_idx = 0, -1
+ return None
+ except Exception:
+ return None
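+
+
+# Behavior sketch (illustrative): fenced blocks win over raw braces, and the
+# LAST valid object is returned.
+#
+#   extract_last_json('noise {"a": 1} tail {"b": {"c": 2}}')  -> {'b': {'c': 2}}
+#   extract_last_json('```json\n{"x": 1}\n```')               -> {'x': 1}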
+
+
+class WebVisitor(AsyncActionMixin, BaseAction):
+
+ EXTRACTION_PROMPT = """Please process the following webpage content and user goal to extract relevant information:
+
+## **Webpage Content**
+{webpage_content}
+
+## **User Goal**
+{goal}
+
+## **Task Guidelines**
+1. **Content Scanning for Rationale**: Locate the **specific sections/data** directly related to the user's goal within the webpage content
+2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content. Never miss any important information; output the **full original context** as far as possible (it may span more than three paragraphs).
+3. **Summary Output for Summary**: Organize the findings into a concise paragraph with logical flow, prioritizing clarity, and judge the contribution of the information to the goal.
+
+**Final output: a JSON object with "rational", "evidence", and "summary" fields**
+"""
+
+ def __init__(
+ self,
+ browse_tool: BaseAction | dict,
+ llm: Any,
+ max_browse_attempts: int = 3,
+ max_extract_attempts: int = 3,
+ sleep_interval: int = 3,
+ truncate_browse_response_length: int | None = None,
+ tokenizer_path: str | None = None,
+ name: str = 'visit',
+ ):
+ super().__init__(
+ description={
+ 'name': name,
+ 'description': 'Visit webpage(s) and return the summary of the content.',
+ 'parameters': [
+ {
+ 'name': 'url',
+ 'type': ['STRING', 'ARRAY'],
+ "items": {"type": "string"},
+ "minItems": 1,
+ 'description': 'The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs.',
+ },
+ {'name': 'goal', 'type': 'STRING', 'description': 'The goal of the visit for webpage(s).'},
+ ],
+ 'required': ['url', 'goal'],
+ }
+ )
+ browse_tool = create_object(browse_tool)
+ assert not browse_tool.is_toolkit and browse_tool.description['required'] == [
+ 'url'
+ ], "browse_tool must be a single-tool action with only 'url' as required argument."
+ self.browse_tool = browse_tool
+ self.llm = create_object(llm)
+ self.max_browse_attempts = max_browse_attempts
+ self.max_extract_attempts = max_extract_attempts
+ self.sleep_interval = sleep_interval
+ self.truncate_browse_response_length = truncate_browse_response_length
+ self.tokenizer = (
+ AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) if tokenizer_path else None
+ )
+ if self.truncate_browse_response_length is not None and self.tokenizer is None:
+ warnings.warn(
+ 'truncate_browse_response_length is set but tokenizer_path is not provided. '
+ 'The raw webpage content will be truncated by characters instead of tokens.'
+ )
+
+ async def run(self, url: str | List[str], goal: str) -> str:
+ if isinstance(url, str):
+ url = [url]
+
+ async def _inner_call(single_url: str) -> str:
+ try:
+ return await self._read_webpage(single_url, goal)
+ except Exception as e:
+ return f"Error fetching {single_url}: {str(e)}"
+
+ response = await asyncio.gather(*[_inner_call(single_url) for single_url in url])
+ return "\n=======\n".join(response).strip()
+
+ async def _read_webpage(self, url: str, goal: str) -> str:
+ tool_response = compressed = None
+ return_template = (
+ f"The useful information in {url} for user goal {goal} as follows: \n\n"
+ f"Evidence in page: \n{{evidence}}\n\nSummary: \n{{summary}}\n\n"
+ )
+ for _ in range(self.max_browse_attempts):
+ resp = await self.browse_tool({'url': url})
+ if resp.valid == ActionValidCode.OPEN and resp.state == ActionStatusCode.SUCCESS:
+ tool_response = resp.format_result()
+ break
+ await asyncio.sleep(self.sleep_interval)
+ else:
+ return return_template.format(
+ evidence="The provided webpage content could not be accessed. Please check the URL or file format.",
+ summary="The webpage content could not be processed, and therefore, no information is available.",
+ )
+
+ if self.truncate_browse_response_length is not None:
+ tool_response = (
+ self.tokenizer.decode(
+ self.tokenizer.encode(
+ tool_response,
+ max_length=self.truncate_browse_response_length,
+ truncation=True,
+ add_special_tokens=False,
+ )
+ )
+ if self.tokenizer is not None
+ else tool_response[: self.truncate_browse_response_length]
+ )
+
+ for _ in range(self.max_extract_attempts):
+ try:
+ prompt = self.EXTRACTION_PROMPT.format(webpage_content=tool_response, goal=goal)
+ llm_response = await self.llm.chat([{'role': 'user', 'content': prompt}])
+ if llm_response and not isinstance(llm_response, str):
+ llm_response = (
+ llm_response.content
+ if isinstance(llm_response, AgentMessage)
+ else llm_response.choices[0].message.content
+ )
+ if not llm_response or len(llm_response) < 10:
+ tool_response = tool_response[: int(len(tool_response) * 0.7)]
+ continue
+ compressed = extract_last_json(llm_response)
+ if isinstance(compressed, dict) and all(
+ key in compressed for key in ['rational', 'evidence', 'summary']
+ ):
+ break
+ except Exception:
+ print(f"Error in extracting information: {traceback.format_exc()}")
+ await asyncio.sleep(self.sleep_interval)
+ else:
+ return return_template.format(
+ evidence="Failed to extract relevant information from the webpage content.",
+ summary="The webpage content could not be processed, and therefore, no information is available.",
+ )
+ return return_template.format(evidence=compressed['evidence'], summary=compressed['summary'])
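+
+
+# A wiring sketch (all values, including the type paths, are illustrative):
+# both `browse_tool` and `llm` may be passed as lazy dict configs resolved by
+# create_object.
+#
+#   visitor = WebVisitor(
+#       browse_tool=dict(type='lagent.actions.AsyncMCPClient', server_type='sse',
+#                        url='http://localhost:8000/sse'),          # hypothetical config
+#       llm=dict(type='lagent.llms.AsyncOpenAIWrapper', model='qwen2.5-72b-instruct',
+#                base_url='http://localhost:8001/v1'),              # hypothetical config
+#       truncate_browse_response_length=8192,
+#       tokenizer_path='Qwen/Qwen2.5-7B-Instruct',
+#   )
+#   # summary = await visitor.run(url=['https://example.com'], goal='...')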
diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py
index ae99d6d..40271ce 100644
--- a/lagent/agents/agent.py
+++ b/lagent/agents/agent.py
@@ -73,7 +73,10 @@ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessa
self.update_memory(message, session_id=session_id)
response_message = self.forward(*message, session_id=session_id, **kwargs)
if not isinstance(response_message, AgentMessage):
- response_message = AgentMessage(sender=self.name, content=response_message)
+ if isinstance(response_message, str):
+ response_message = AgentMessage(sender=self.name, content=response_message)
+ else:
+ response_message = AgentMessage.from_model_response(response_message, self.name)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
@@ -158,6 +161,28 @@ def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = F
for agent in getattr(self, '_agents', {}).values():
agent.reset(session_id, recursive=True)
+ def get_messages(self, session_id=0, keypath: Optional[str] = None) -> List[dict]:
+ """Get OpenAI format messages from memory.
+
+ Args:
+ session_id (int): The session id of the memory.
+ keypath (Optional[str]): The keypath of the sub-agent to get messages from. Default is None.
+
+ Returns:
+ List[dict]: The messages from the memory including the sub-agent's system prompt.
+ """
+ if keypath:
+ keys, agent = keypath.split('.'), self
+ for key in keys:
+ agents = getattr(agent, '_agents', {})
+ if key not in agents:
+ raise KeyError(f'No sub-agent named {key} in {agent}')
+ agent = agents[key]
+ return agent.get_messages(session_id=session_id)
+ if self.aggregator:
+ return self.aggregator.aggregate(self.memory.get(session_id), self.name, self.output_format, self.template)
+ raise ValueError(f'{self.name} has no aggregator to get messages')
+
def __repr__(self):
def _rcsv_repr(agent, n_indent=1):
@@ -186,7 +211,10 @@ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Agen
self.update_memory(message, session_id=session_id)
response_message = await self.forward(*message, session_id=session_id, **kwargs)
if not isinstance(response_message, AgentMessage):
- response_message = AgentMessage(sender=self.name, content=response_message)
+ if isinstance(response_message, str):
+ response_message = AgentMessage(sender=self.name, content=response_message)
+ else:
+ response_message = AgentMessage.from_model_response(response_message, self.name)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
diff --git a/lagent/agents/aggregator/default_aggregator.py b/lagent/agents/aggregator/default_aggregator.py
index 0888aef..720355c 100644
--- a/lagent/agents/aggregator/default_aggregator.py
+++ b/lagent/agents/aggregator/default_aggregator.py
@@ -1,31 +1,42 @@
-from typing import Dict, List
+from typing import List
from lagent.memory import Memory
from lagent.prompts import StrParser
+from lagent.schema import ActionReturn
class DefaultAggregator:
- def aggregate(self,
- messages: Memory,
- name: str,
- parser: StrParser = None,
- system_instruction: str = None) -> List[Dict[str, str]]:
+ def aggregate(self, messages: Memory, name: str, parser: StrParser = None, system_instruction=None) -> List[dict]:
_message = []
messages = messages.get_memory()
if system_instruction:
- _message.extend(
- self.aggregate_system_intruction(system_instruction))
+ _message.extend(self.aggregate_system_intruction(system_instruction))
for message in messages:
if message.sender == name:
- _message.append(
- dict(role='assistant', content=str(message.content)))
+ _message.append(message.to_model_request())
else:
- user_message = message.content
- if len(_message) > 0 and _message[-1]['role'] == 'user':
- _message[-1]['content'] += user_message
+ user_message, extra_info = message.content, message.extra_info
+ if isinstance(user_message, list):
+ for m in user_message:
+ if isinstance(m, dict):
+ m = ActionReturn(**m)
+ assert isinstance(m, ActionReturn), f"Expected m to be ActionReturn, but got {type(m)}"
+ _message.append(
+ dict(
+ role='tool',
+ tool_call_id=m.tool_call_id,
+ content=m.format_result(),
+ name=m.type,
+ extra_info=extra_info,
+ )
+ )
else:
- _message.append(dict(role='user', content=user_message))
+ if len(_message) > 0 and _message[-1]['role'] == 'user':
+ _message[-1]['content'] += user_message
+ _message[-1]['extra_info'] = extra_info
+ else:
+ _message.append(dict(role='user', content=user_message, extra_info=extra_info))
return _message
@staticmethod
@@ -39,6 +50,5 @@ def aggregate_system_intruction(system_intruction) -> List[dict]:
if not isinstance(msg, dict):
raise TypeError(f'Unsupported message type: {type(msg)}')
if not ('role' in msg and 'content' in msg):
- raise KeyError(
- f"Missing required key 'role' or 'content': {msg}")
+ raise KeyError(f"Missing required key 'role' or 'content': {msg}")
return system_intruction
diff --git a/lagent/agents/fc_agent.py b/lagent/agents/fc_agent.py
new file mode 100644
index 0000000..f3437a8
--- /dev/null
+++ b/lagent/agents/fc_agent.py
@@ -0,0 +1,160 @@
+import asyncio
+import json
+from copy import deepcopy
+from dataclasses import asdict
+from typing import Dict, List, Literal, Optional, Union
+
+from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed
+
+from lagent.actions import AsyncActionExecutor
+from lagent.hooks import Hook
+from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode, AgentMessage, AgentStatusCode
+from lagent.utils import create_object, truncate_text
+from .agent import AsyncAgent
+
+DEFAULT_TOOL_TEMPLATE = """# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{tools}
+</tools>
+
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{{"name": <function-name>, "arguments": <args-json-object>}}
+</tool_call>
+"""
+
+
+def get_tool_prompt(actions: list, exclude_arguments: list = None, template: str = DEFAULT_TOOL_TEMPLATE) -> str:
+ exclude_arguments = exclude_arguments or ['session_id']
+
+ def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') -> dict:
+ properties = {}
+ for param in action_description['parameters']:
+ param = deepcopy(param)
+ param_name, param_type = param.pop('name'), param.pop('type')
+ if param_name in exclude_arguments:
+ continue
+ param_type = [t.lower() for t in param_type] if isinstance(param_type, list) else param_type.lower()
+ properties[param_name] = {'type': param_type, **param}
+ return {
+ 'type': 'function',
+ 'function': {
+ 'name': name_pattern.format(action_description['name']),
+ 'description': action_description['description'],
+ 'parameters': {'type': 'object', 'properties': properties, 'required': action_description['required']},
+ },
+ }
+
+ tools = []
+ for action in actions if isinstance(actions, list) else [actions]:
+ action = create_object(action)
+ action_desc = action.description
+ if action.is_toolkit:
+ for api in action_desc['api_list']:
+ tools.append(_convert_tool_schema(api, f"{action.name}.{{}}"))
+ else:
+ tools.append(_convert_tool_schema(action_desc))
+ return template.format(tools='\n'.join([json.dumps(tool, ensure_ascii=False) for tool in tools]))
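+
+
+# Illustrative output: for a single-tool action named "visit" declaring one
+# required STRING parameter "url", the rendered entry inside {tools} looks like
+#
+#   {"type": "function", "function": {"name": "visit", "description": "...",
+#     "parameters": {"type": "object",
+#                    "properties": {"url": {"type": "string", "description": "..."}},
+#                    "required": ["url"]}}}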
+
+
+class FunctionCallAgent(AsyncAgent):
+ def __init__(
+ self,
+ select_agent: Union[Dict, AsyncAgent],
+ env_agent: Union[Dict, AsyncAgent],
+ finish_condition: callable = lambda x, _: x and not x.tool_calls,
+ max_turn: Optional[int] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name)
+ self.select_agent = create_object(select_agent)
+ self.env_agent = create_object(env_agent)
+ self.finish_condition = finish_condition
+ self.max_turn = max_turn
+
+ async def forward(self, env_message: AgentMessage, session_id: str | int, **kwargs):
+ selection_message: AgentMessage = None
+ current_turn = 0
+        while (
+            self.finish_condition is None or not self.finish_condition(selection_message, env_message)
+        ) and (self.max_turn is None or current_turn < self.max_turn):
+ selection_message = await self.select_agent(env_message, session_id=session_id, **kwargs)
+ if selection_message.stream_state == AgentStatusCode.SERVER_ERR:
+ raise ValueError("Rollout response error: state is neither completed nor aborted!")
+ if selection_message.stream_state == AgentStatusCode.SESSION_OUT_OF_LIMIT:
+ for _ in range(2): # remove the last two messages
+ self.select_agent.memory.get(session_id).delete(-1)
+                return AgentMessage(sender='env', content='Context length exceeds the limit')
+ env_message = await self.env_agent(selection_message, session_id=session_id)
+ current_turn += 1
+        return AgentMessage(sender="env", content="Finished")
+
+
+class EnvAgent(AsyncAgent):
+ def __init__(
+ self,
+ actions: list,
+ stateful_tools: List[str] = None,
+ max_tool_response_length: int = None,
+ tool_response_truncate_side: Literal['left', 'right', 'middle'] = 'middle',
+ action_hooks: List[Union[dict, Hook]] = None,
+ name: Optional[str] = None,
+ ):
+ super().__init__(name=name)
+ self.actions = AsyncActionExecutor(actions, hooks=action_hooks)
+ self.stateful_tools = stateful_tools or []
+ self.max_tool_response_length = max_tool_response_length
+ self.tool_response_truncate_side = tool_response_truncate_side
+ self._retry_mechanism = retry(
+ stop=stop_after_attempt(3),
+ wait=wait_fixed(2),
+ retry=retry_if_result(
+ lambda r: r.valid == ActionValidCode.OPEN
+ and r.state not in [ActionStatusCode.SUCCESS, ActionStatusCode.ARGS_ERROR]
+ ),
+ retry_error_callback=lambda retry_state: retry_state.outcome.result(),
+ )
+
+ async def forward(self, selection_message: AgentMessage, session_id: str | int, **kwargs):
+ if not selection_message.tool_calls:
+ return AgentMessage(sender=self.name, content='No tool call')
+
+ tool_responses = await asyncio.gather(
+ *[
+ self._retry_mechanism(self.execute_tool)(tool_call, session_id)
+ for tool_call in selection_message.tool_calls
+ ]
+ )
+ for tool_call_id, tool_response in zip(selection_message.tool_calls_ids, tool_responses):
+ tool_response.tool_call_id = tool_call_id
+ res = tool_response.format_result()
+ if self.max_tool_response_length is not None and len(res) > self.max_tool_response_length:
+ res = truncate_text(res, max_num=self.max_tool_response_length, side=self.tool_response_truncate_side)
+ tool_response.result = [{'type': 'text', 'content': res}]
+ return AgentMessage(sender=self.name, content=[asdict(resp) for resp in tool_responses])
+
+ async def execute_tool(self, tool_call: dict, session_id: str | int) -> ActionReturn:
+ try:
+ if tool_call['name'].split('.', 1)[0] not in self.actions:
+ return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Tool {tool_call["name"]} Not Found')
+ if isinstance(tool_call['arguments'], str):
+ tool_call['arguments'] = json.loads(tool_call['arguments'])
+ if tool_call['name'] in self.stateful_tools:
+ tool_call = deepcopy(tool_call)
+ tool_call['arguments']['session_id'] = session_id
+ except Exception as e:
+ return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Invalid tool call format: {str(e)}')
+ tool_response: ActionReturn = (
+ await self.actions(
+ AgentMessage(
+ sender='assistant', content=dict(name=tool_call['name'], parameters=tool_call['arguments'])
+ ),
+ session_id=session_id,
+ )
+ ).content
+ return tool_response
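+
+
+# A wiring sketch (names and values illustrative): the select agent proposes
+# tool calls, EnvAgent executes them, and the loop stops once the model replies
+# with no tool_calls or max_turn is reached.
+#
+#   env = EnvAgent(actions=[...], max_tool_response_length=8192)
+#   agent = FunctionCallAgent(select_agent=..., env_agent=env, max_turn=16)
+#   # final = await agent(AgentMessage(sender='user', content='...'), session_id=0)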
diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 7418a65..e83b7eb 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -1,8 +1,10 @@
import asyncio
import json
import os
+import random
import time
import traceback
+import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
@@ -10,7 +12,12 @@
from typing import AsyncGenerator, Dict, List, Optional, Union
import aiohttp
+import httpx
import requests
+from openai import NOT_GIVEN, APITimeoutError, AsyncOpenAI
+from openai.types.chat import ChatCompletion
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
from ..schema import ModelStatusCode
from ..utils import filter_suffix
@@ -308,7 +315,9 @@ def streaming(raw_response):
response = dict()
try:
- raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True)
+ raw_response = requests.post(
+ self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True
+ )
return streaming(raw_response)
except requests.ConnectionError:
errmsg = 'Got connection error ' + str(traceback.format_exc())
@@ -809,3 +818,79 @@ def tokenize(self, prompt: str) -> list:
self.tiktoken = tiktoken
enc = self.tiktoken.encoding_for_model(self.model_type)
return enc.encode(prompt)
+
+
+class AsyncOpenAIWrapper:
+ def __init__(
+ self,
+ model: str,
+ base_url: str | List[str],
+ api_key: str = None,
+ sample_params: dict = None,
+ proxy: str = None,
+ timeout: int = 600,
+ max_retry: int = 50,
+ sleep_interval: int = 5,
+ extra_body: dict = None,
+ ):
+ self.model = model
+ http_client = httpx.AsyncClient(proxy=proxy, timeout=timeout) if proxy else None
+ self.clients = [
+ AsyncOpenAI(api_key=api_key, base_url=url, http_client=http_client)
+ for url in (base_url if isinstance(base_url, list) else [base_url])
+ ]
+ self.sample_params = sample_params or {}
+ self.timeout = timeout
+ self.max_retry = max_retry
+ self.sleep_interval = sleep_interval
+ self.extra_body = extra_body
+
+ async def chat(self, messages: list[dict], session_id: str | int = None, **kwargs) -> ChatCompletion:
+ fallback_response = ChatCompletion(
+ id=f'chatcmpl-{uuid.uuid4()}',
+ object='chat.completion',
+ created=int(time.time()),
+ model=self.model,
+ choices=[
+ Choice(message=ChatCompletionMessage(role='assistant', content=''), finish_reason='stop', index=0)
+ ],
+ )
+ for attempt in range(self.max_retry):
+ try:
+ client = random.choice(self.clients)
+ response = await client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ stream=False,
+                    temperature=kwargs.get('temperature', self.sample_params.get('temperature', 0.7)),
+                    top_p=kwargs.get('top_p', self.sample_params.get('top_p', 1.0)),
+                    timeout=self.timeout,
+                    extra_body=self.extra_body,
+                    max_tokens=kwargs.get('max_tokens', self.sample_params.get('max_tokens', 64 * 1024)),
+                    reasoning_effort=kwargs.get(
+                        'reasoning_effort', self.sample_params.get('reasoning_effort', NOT_GIVEN)
+                    ),
+ )
+ return response
+ except (APITimeoutError, TimeoutError) as e:
+ print(f"LLM Call Timeout: {e}")
+ if attempt == self.max_retry - 1:
+ return fallback_response
+ await asyncio.sleep(self.sleep_interval)
+ except Exception as e:
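+                # Only errors matching the known transient patterns below are retried
+                # (the Chinese entries match provider quota-error messages); any other
+                # exception returns the empty fallback response immediately.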
+ for val in [
+ "用户额度不足",
+ "剩余额度",
+ "TimeoutError",
+ "litellm.BadRequestError",
+ "litellm.APIError: APIError",
+ "Call `/v1/models` to view available models",
+ ]:
+ if val in str(e):
+ print(f"LLM Call Error: {e}")
+ if attempt == self.max_retry - 1:
+ return fallback_response
+ await asyncio.sleep(self.sleep_interval)
+ break
+ else:
+ return fallback_response
+ return fallback_response
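+
+
+# A minimal usage sketch (endpoint and model are placeholders). Requests are
+# randomly load-balanced across the configured base_urls:
+#
+#   llm = AsyncOpenAIWrapper(
+#       model='qwen2.5-72b-instruct',
+#       base_url=['http://10.0.0.1:8000/v1', 'http://10.0.0.2:8000/v1'],
+#       sample_params={'temperature': 0.7, 'max_tokens': 4096},
+#   )
+#   # completion = await llm.chat([{'role': 'user', 'content': 'hi'}])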
diff --git a/lagent/memory/base_memory.py b/lagent/memory/base_memory.py
index 3c8fcf0..f7d72be 100644
--- a/lagent/memory/base_memory.py
+++ b/lagent/memory/base_memory.py
@@ -1,12 +1,11 @@
from typing import Callable, Dict, List, Optional, Union
from lagent.schema import AgentMessage
+from lagent.utils import load_class_from_string
class Memory:
- _item_cls = AgentMessage
-
def __init__(self, recent_n=None) -> None:
self.memory: List[AgentMessage] = []
self.recent_n = recent_n
@@ -25,39 +24,43 @@ def get_memory(
memory = [m for i, m in enumerate(memory) if filter_func(i, m)]
return memory
- def add(self, memories: Union[List[Dict], Dict, None]) -> None:
+ def add(self, memories: Union[List[AgentMessage | str], AgentMessage, str]) -> None:
for memory in memories if isinstance(memories, (list, tuple)) else [memories]:
if isinstance(memory, str):
- memory = self._item_cls(sender='user', content=memory)
+ memory = AgentMessage(sender='user', content=memory)
if isinstance(memory, AgentMessage):
- if not isinstance(memory, self._item_cls):
- memory = self._item_cls.model_validate(memory, from_attributes=True)
self.memory.append(memory)
- def delete(self, index: Union[List, int]) -> None:
+ def delete(self, index: Union[List[int], int]) -> None:
if isinstance(index, int):
del self.memory[index]
else:
for i in index:
del self.memory[i]
- def load(
- self,
- memories: Union[str, Dict, List],
- overwrite: bool = True,
- ) -> None:
+ def load(self, memories: Union[str, dict, List], overwrite: bool = True) -> None:
if overwrite:
self.memory = []
if isinstance(memories, dict):
- self.memory.append(self._item_cls.model_validate(memories))
+ memories = memories.copy()
+ _cls = (
+ load_class_from_string(memories.pop('__model_spec__'))
+ if '__model_spec__' in memories
+ else AgentMessage
+ )
+ self.memory.append(_cls.model_validate(memories))
elif isinstance(memories, list):
for m in memories:
- self.memory.append(self._item_cls.model_validate(m))
+ m = m.copy()
+ _cls = load_class_from_string(m.pop('__model_spec__')) if '__model_spec__' in m else AgentMessage
+ self.memory.append(_cls.model_validate(m))
else:
raise TypeError(f'{type(memories)} is not supported')
def save(self) -> List[dict]:
memory = []
for m in self.memory:
- memory.append(m.model_dump())
+ m_dumped = m.model_dump()
+ m_dumped['__model_spec__'] = f'{m.__module__}.{m.__class__.__name__}'
+ memory.append(m_dumped)
return memory
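+
+    # Persistence sketch (illustrative): save() stamps every message with its
+    # import path so load() can rebuild subclass instances via load_class_from_string.
+    #
+    #   dumped = mem.save()   # each dict gains '__model_spec__', e.g. 'lagent.schema.AgentMessage'
+    #   mem.load(dumped)      # re-instantiates the right class per message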
diff --git a/lagent/schema.py b/lagent/schema.py
index 668846f..ca2d135 100644
--- a/lagent/schema.py
+++ b/lagent/schema.py
@@ -2,12 +2,12 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional, Union
+from openai.types.chat import ChatCompletion
from pydantic import BaseModel
def enum_dict_factory(inputs):
- inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i
- for i in inputs]
+ inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i for i in inputs]
return dict(inputs)
@@ -47,6 +47,7 @@ class ActionReturn:
state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS
thought: Optional[str] = None
valid: Optional[ActionValidCode] = ActionValidCode.OPEN
+ tool_call_id: Optional[str] = None
def format_result(self) -> str:
"""Concatenate items in result."""
@@ -89,9 +90,47 @@ class AgentStatusCode(IntEnum):
class AgentMessage(BaseModel):
content: Any
+ thinking: Optional[str] = None
sender: str = 'user'
+ tool_calls: Optional[List[dict]] = None
+ tool_calls_ids: Optional[List[str]] = None
formatted: Optional[Any] = None
extra_info: Optional[Any] = None
type: Optional[str] = None
receiver: Optional[str] = None
stream_state: Union[ModelStatusCode, AgentStatusCode] = AgentStatusCode.END
+ finish_reason: Optional[str] = None
+
+ @classmethod
+ def from_model_response(cls, response: ChatCompletion, sender: str) -> "AgentMessage":
+ """Convert model response dict to AgentMessage."""
+ chat_message = response.choices[0].message
+ tool_calls = chat_message.tool_calls and [tool_call.model_dump() for tool_call in chat_message.tool_calls]
+ return cls(
+ sender=sender,
+ content=chat_message.content or "",
+ thinking=getattr(chat_message, 'reasoning_content', None),
+ tool_calls=[tool_call['function'] for tool_call in tool_calls] if tool_calls else None,
+ tool_calls_ids=[tool_call['id'] for tool_call in tool_calls] if tool_calls else None,
+ stream_state=(
+ ModelStatusCode.SESSION_OUT_OF_LIMIT
+ if response.choices[0].finish_reason == 'length'
+ else ModelStatusCode.END
+ ),
+ finish_reason=response.choices[0].finish_reason,
+ )
+
+ def to_model_request(self, role: str = 'assistant') -> dict:
+ """Convert AgentMessage to model request dict."""
+ tool_calls = [
+ {'id': tool_call_id, 'function': tool_call, 'type': 'function'}
+ for tool_call, tool_call_id in zip(self.tool_calls or [], self.tool_calls_ids or [])
+ ]
+ return {
+ "role": role,
+ "content": self.content,
+ "reasoning_content": self.thinking,
+ "tool_calls": tool_calls if tool_calls else None,
+ "extra_info": self.extra_info,
+ "stream_state": self.stream_state,
+ }
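+
+    # Conversion sketch (illustrative):
+    #
+    #   msg = AgentMessage.from_model_response(completion, sender='assistant')
+    #   payload = msg.to_model_request()   # OpenAI-style assistant message dict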
diff --git a/lagent/utils/__init__.py b/lagent/utils/__init__.py
index a0ac549..64abd61 100644
--- a/lagent/utils/__init__.py
+++ b/lagent/utils/__init__.py
@@ -6,9 +6,16 @@
filter_suffix,
get_logger,
load_class_from_string,
+ truncate_text,
)
__all__ = [
- 'is_module_exist', 'filter_suffix', 'create_object', 'get_logger',
- 'load_class_from_string', 'async_as_completed', 'GeneratorWithReturn'
+ 'is_module_exist',
+ 'filter_suffix',
+ 'create_object',
+ 'get_logger',
+ 'load_class_from_string',
+ 'async_as_completed',
+ 'GeneratorWithReturn',
+ 'truncate_text',
]
diff --git a/lagent/utils/util.py b/lagent/utils/util.py
index 609382e..9b913b2 100644
--- a/lagent/utils/util.py
+++ b/lagent/utils/util.py
@@ -4,11 +4,12 @@
import logging
import os
import os.path as osp
+import re
import sys
import time
from functools import partial
from logging.handlers import RotatingFileHandler
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Union, cast
def load_class_from_string(class_path: str, path=None):
@@ -33,6 +34,8 @@ def create_object(config: Union[Dict, Any] = None):
preserved key to indicate the class (path). When accepting non-dictionary
input, the function degenerates to an identity.
"""
+    try:
+        from ray.actor import ActorClass
+    except ImportError:  # ray is an optional dependency
+        ActorClass = None
+
if config is None or not isinstance(config, dict):
return config
assert isinstance(config, dict) and 'type' in config
@@ -41,7 +44,9 @@ def create_object(config: Union[Dict, Any] = None):
obj_type = config.pop('type')
if isinstance(obj_type, str):
obj_type = load_class_from_string(obj_type)
- if inspect.isclass(obj_type):
+    if ActorClass is not None and isinstance(obj_type, ActorClass):
+ obj = cast(ActorClass, obj_type).remote(**config)
+ elif inspect.isclass(obj_type):
obj = obj_type(**config)
else:
assert callable(obj_type)
@@ -123,6 +128,79 @@ def get_logger(
return logger
+def truncate_text(text, max_num=4000, side='middle'):
+ """
+ 中英文混合场景下,根据 side 参数截断文本。总共保留 approx max_num 个“词/字”。
+
+ 定义“单位”逻辑:
+ 1. 连续的英文/数字被视为 1 个单位 (如 "Python", "123")
+ 2. 单个汉字或标点被视为 1 个单位 (如 "中", "。", ",")
+
+ Args:
+ text (str): 原始文本
+ max_num (int): 截取的单位数量
+ side (str): 截断模式,可选 'left', 'right', 'middle'
+ 'left': 保留尾部(截断头部)
+ 'right': 保留头部(截断尾部)
+ 'middle': 保留头尾(截断中间)
+ """
+ if not text or max_num <= 0:
+ return ""
+
+    # --- Core regex ---
+    # Match either a word made of letters/digits/underscores/hyphens, or any single non-whitespace character.
+    # Note: the word alternative must come first so full words are matched preferentially.
+ pattern = re.compile(r"[a-zA-Z0-9_'-]+|[^\s]")
+
+    # Collect all match objects (they carry position information).
+ matches = list(pattern.finditer(text))
+ total_units = len(matches)
+
+    # If there are not enough units, return the full text.
+ if total_units <= max_num:
+ return text
+
+ parts = []
+
+ if side == 'left':
+        # Keep the last max_num units (truncate the head).
+        # matches[-max_num] is the first unit of the retained part.
+ start_idx = total_units - max_num
+ start_pos = matches[start_idx].start()
+ parts.append("(truncated)...")
+ parts.append(text[start_pos:])
+
+ elif side == 'right':
+        # Keep the first max_num units (truncate the tail).
+        # matches[max_num - 1] is the last unit of the retained part.
+ end_pos = matches[max_num - 1].end()
+ parts.append(text[:end_pos])
+ parts.append("...(truncated)")
+
+ else: # middle
+        # --- Keep head and tail; truncate the middle ---
+        head_count = max_num // 2
+        tail_count = max_num - head_count
+
+        # 1. Extract the head.
+        if head_count > 0:
+            # matches[head_count - 1] is the last unit the head should keep.
+            head_span_end = matches[head_count - 1].end()
+            parts.append(text[:head_span_end])
+
+        # 2. Insert the truncation marker.
+        parts.append("...(truncated)...")
+
+        # 3. Extract the tail.
+        if tail_count > 0:
+            # matches[-tail_count] is the first unit the tail should keep.
+            tail_idx = total_units - tail_count
+            tail_span_start = matches[tail_idx].start()
+            parts.append(text[tail_span_start:])
+
+ return "".join(parts)
+
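+# Behavior sketch for truncate_text (illustrative inputs/outputs):
+#
+#   truncate_text("one two three four five", max_num=3, side='right')
+#   -> "one two three...(truncated)"
+#   truncate_text("one two three four five", max_num=3, side='left')
+#   -> "(truncated)...three four five"
+#   truncate_text("one two three four five", max_num=4, side='middle')
+#   -> "one two...(truncated)...four five"
+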
+
class GeneratorWithReturn:
"""Generator wrapper to capture the return value."""