diff --git a/lagent/actions/mcp_client.py b/lagent/actions/mcp_client.py new file mode 100644 index 0000000..dfd1e56 --- /dev/null +++ b/lagent/actions/mcp_client.py @@ -0,0 +1,409 @@ +import asyncio +import logging +import random +import threading +import time +from collections import deque +from contextlib import AsyncExitStack, nullcontext +from typing import Deque, Literal, Optional, TypeAlias + +from lagent.actions.base_action import AsyncActionMixin, BaseAction +from lagent.actions.parser import JsonParser, ParseError +from lagent.schema import ActionReturn, ActionStatusCode + +ServerType: TypeAlias = Literal["stdio", "sse", "http"] + +logger = logging.getLogger(__name__) +_loop = None + + +def _get_event_loop(): + try: + event_loop = asyncio.get_event_loop() + except Exception: + logger.warning('Can not found event loop in current thread. Create a new event loop.') + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + if event_loop.is_running(): + global _loop + if _loop: + return _loop + + from threading import Thread + + def _start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + event_loop = asyncio.new_event_loop() + Thread(target=_start_loop, args=(event_loop,), daemon=True).start() + _loop = event_loop + return event_loop + + +logger = logging.getLogger(__file__) + + +class TokenBucket: + def __init__(self, rate_limit: float): + self.rate_limit = rate_limit # tokens per second + self.tokens = rate_limit + self.last_update = time.time() + self.lock = threading.Lock() + + def acquire(self) -> bool: + with self.lock: + now = time.time() + # Add new tokens based on time elapsed + new_tokens = (now - self.last_update) * self.rate_limit + self.tokens = min(self.rate_limit, self.tokens + new_tokens) + self.last_update = now + + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + +class AsyncTokenBucket: + def __init__(self, rate_limit: float): + self.rate_limit = rate_limit + self.capacity = rate_limit + self.tokens = rate_limit + self.last_update = time.monotonic() + self._lock = asyncio.Lock() + + def _refill(self): + now = time.monotonic() + elapsed = now - self.last_update + if elapsed <= 0: + return + self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_limit) + self.last_update = now + + async def acquire(self): + while True: + async with self._lock: + self._refill() + if self.tokens >= 1: + self.tokens -= 1 + return + missing = 1 - self.tokens + wait_time = missing / self.rate_limit + await asyncio.sleep(wait_time) + + +class FairAsyncTokenBucket: + def __init__(self, rate_limit: float, capacity: Optional[float] = None): + """ + rate_limit: 每秒生成多少个 token + capacity: 桶容量(最大可累积多少 token),默认和 rate_limit 一样 + """ + self.rate_limit = float(rate_limit) + self.capacity = float(capacity) if capacity is not None else float(rate_limit) + + self.tokens = self.capacity + self.last_update = time.monotonic() + + self._lock = asyncio.Lock() + self._waiters: Deque[asyncio.Future] = deque() + self._drainer_running = False # 是否已有后台协程在发 token + + # ---------- 内部工具方法 ---------- + + def _refill_unlocked(self) -> None: + """ + 在不持锁的前提下不要调用。 + 根据时间流逝计算当前 token 数。 + """ + now = time.monotonic() + elapsed = now - self.last_update + if elapsed <= 0: + return + self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_limit) + self.last_update = now + + async def _drain_waiters(self) -> None: + """ + 后台协程:按 FIFO 顺序给排队的协程发 token。 + - 没 token 时,就 sleep 到下一个 token 产生的时间点。 + - 有 token 且有排队,就唤醒队头的一个,再继续循环。 + """ + try: + 
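+            # Drain loop: refill the bucket from elapsed time, hand one token to
+            # the waiter at the head of the queue when one is available, otherwise
+            # sleep until the next token is due; exit once the queue is empty.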
while True: + fut_to_wake: Optional[asyncio.Future] = None + sleep_time: Optional[float] = None + + async with self._lock: + self._refill_unlocked() + + # 队列空了,没什么好做的了,退出 drainer + if not self._waiters: + self._drainer_running = False + return + + if self.tokens >= 1: + # 有 token,按 FIFO 唤醒一个排队的协程 + self.tokens -= 1 + fut_to_wake = self._waiters.popleft() + sleep_time = 0.0 + else: + # 没 token,算一下距离下一个 token 的时间 + missing = 1.0 - self.tokens # 还差多少 token 才能发下一枚 + sleep_time = max(0.0, missing / self.rate_limit) + + # 出锁之后再唤醒,避免在锁里执行用户代码 / 回调 + if fut_to_wake is not None and not fut_to_wake.done(): + fut_to_wake.set_result(None) + + # 如果刚刚唤醒了一个协程,立刻回到循环,看是否还能继续发 + if sleep_time == 0.0: + continue + + # 没 token,就等到有 token 再继续 + await asyncio.sleep(sleep_time) + finally: + # 兜底,避免异常时 drainer_running 一直是 True 导致无法重启 + async with self._lock: + self._drainer_running = False + + # ---------- 对外接口 ---------- + + async def acquire(self) -> None: + """ + 获取一个 token(公平:排队 FIFO) + """ + loop = asyncio.get_running_loop() + + # 先尝试直接拿 token(快速路径) + async with self._lock: + self._refill_unlocked() + + # 如果有 token 且没有历史排队的协程,直接拿走返回 + if self.tokens >= 1 and not self._waiters: + self.tokens -= 1 + return + + # 否则需要排队 + fut = loop.create_future() + self._waiters.append(fut) + + # 启动 drainer(只要一个就够了) + if not self._drainer_running: + self._drainer_running = True + asyncio.create_task(self._drain_waiters()) + + # 等待被 drainer 唤醒,唤醒后说明自己拿到了 token + await fut + + +# --- 复用你原本的辅助工具 --- +_loop = None + + +def _get_event_loop(): + try: + event_loop = asyncio.get_event_loop() + except Exception: + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + if event_loop.is_running(): + global _loop + if _loop: + return _loop + from threading import Thread + + def _start_loop(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + event_loop = asyncio.new_event_loop() + Thread(target=_start_loop, args=(event_loop,), daemon=True).start() + _loop = event_loop + return event_loop + + +class AsyncMCPClient(AsyncActionMixin, BaseAction): + """ + Standard Lagent Action that wraps a SINGLE tool from an MCP Server. + + Robustness Fix: + Creates a new connection for every request and closes it immediately after. + This prevents connection leaks and 'ConnectTimeout' in high-concurrency RL environments. + """ + + is_stateful = False + + def __init__( + self, + server_type: ServerType, + rate_limit: float = None, + max_concurrency: int = None, + # 注意:这里的 name 主要用于 Lagent 注册,但工具的实际元数据来自 MCP Server + name: Optional[str] = None, + **server_params, + ): + self._is_toolkit = False + self.server_type = server_type + self.server_params = server_params + + # 并发控制组件 + self.rate_limiter = FairAsyncTokenBucket(rate_limit) if rate_limit is not None else None + self._sem = asyncio.Semaphore(max_concurrency) if max_concurrency is not None else nullcontext() + + # 1. 临时连接获取工具元数据 (Metadata) + # 必须在 __init__ 完成,因为 Lagent 需要 self.description + loop = _get_event_loop() + if loop.is_running(): + fut = asyncio.run_coroutine_threadsafe(self._fetch_tool_metadata(), loop) + tools = fut.result() + else: + tools = loop.run_until_complete(self._fetch_tool_metadata()) + + # Single Action 约束:一个 Action 实例对应一个 MCP 工具 + if len(tools) != 1: + logger.warning( + f"MCP Server returned {len(tools)} tools, but AsyncMCPAction is designed for a Single Action. " + f"Using the first one: {tools[0].name}" + ) + + self.tool_info = tools[0] + tool_name = self.tool_info.name + + # 2. 
初始化父类 BaseAction + super().__init__( + description={ + 'name': tool_name, + 'description': self.tool_info.description, + 'parameters': [ + {'name': k, 'type': v['type'].upper(), 'description': v.get('description', '')} + for k, v in self.tool_info.inputSchema['properties'].items() + ], + 'required': self.tool_info.inputSchema.get('required', []), + }, + parser=JsonParser, + ) + self._is_toolkit = False + + async def _connect(self, stack: AsyncExitStack): + """ + 内部辅助:建立连接并注册关闭回调。 + 所有网络资源都注册到 `stack` 中,确保自动释放。 + """ + from mcp import ClientSession, StdioServerParameters + + # --- Transport Layer --- + if self.server_type == "stdio": + from mcp.client.stdio import stdio_client + + logger.info( + f"Connecting to stdio MCP server with command: {self.server_params['command']} " + f"{self.server_params.get('args', [])}" + ) + client_kwargs = {"command": self.server_params["command"]} + for key in ["args", "env", "cwd"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + + server_params_obj = StdioServerParameters(**client_kwargs) + read, write = await stack.enter_async_context(stdio_client(server_params_obj)) + + elif self.server_type == "sse": + from mcp.client.sse import sse_client + + logger.info(f"Connecting to SSE MCP server at: {self.server_params['url']}") + + url = self.server_params["url"] + target_url = random.choice(url) if isinstance(url, list) else url + + client_kwargs = {"url": target_url} + for key in ["headers", "timeout", "sse_read_timeout"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + + read, write = await stack.enter_async_context(sse_client(**client_kwargs)) + + elif self.server_type == "http": + from mcp.client.streamable_http import streamablehttp_client + + logger.info(f"Connecting to StreamableHTTP MCP server at: {self.server_params['url']}") + + url = self.server_params["url"] + target_url = random.choice(url) if isinstance(url, list) else url + + client_kwargs = {"url": target_url} + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: + if self.server_params.get(key) is not None: + client_kwargs[key] = self.server_params[key] + + read, write, _ = await stack.enter_async_context(streamablehttp_client(**client_kwargs)) + + else: + raise ValueError(f"Unsupported server type: {self.server_type}") + + # --- Protocol Layer --- + session = await stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + return session + + async def _fetch_tool_metadata(self): + """在 init 阶段使用一次性连接获取工具定义""" + async with AsyncExitStack() as stack: + session = await self._connect(stack) + result = await session.list_tools() + return result.tools + + async def run(self, **kwargs) -> ActionReturn: + """ + Standard Lagent Action Entrypoint. + """ + fallback_args = kwargs.copy() + + try: + # 1. 并发/速率控制 + async with self._sem: + if self.rate_limiter is not None: + await self.rate_limiter.acquire() + + # 2. 
执行逻辑 (Critical Resource Scope) + # 使用 AsyncExitStack 确保本次请求结束后,HTTP连接/进程管道被彻底关闭 + async with AsyncExitStack() as stack: + session = await self._connect(stack) + + # 调用 MCP 工具 + # 注意:Lagent 传入的是 kwargs 字典,MCP call_tool 正好接受字典 + outputs_obj = await session.call_tool(self.tool_info.name, kwargs) + + # 提取文本结果 + if outputs_obj.content and hasattr(outputs_obj.content[0], 'text'): + outputs = outputs_obj.content[0].text + else: + outputs = str(outputs_obj) + + except ParseError as exc: + return ActionReturn(fallback_args, type=self.name, errmsg=exc.err_msg, state=ActionStatusCode.ARGS_ERROR) + except Exception as exc: + # 记录详细堆栈以便调试 RL 过程中的错误 + logger.warning(f"MCP Action {self.name} failed: {exc}") + return ActionReturn(fallback_args, type=self.name, errmsg=str(exc), state=ActionStatusCode.API_ERROR) + + # 3. 结果封装 + if isinstance(outputs, ActionReturn): + action_return = outputs + if not action_return.args: + action_return.args = kwargs + if not action_return.type: + action_return.type = self.name + else: + # 尝试使用 JsonParser 解析结果(如果 MCP 返回的是 JSON 字符串) + # 否则直接作为字符串返回 + try: + result = self._parser.parse_outputs(outputs) + except: + result = str(outputs) + + action_return = ActionReturn(fallback_args, type=self.name, result=result) + + return action_return diff --git a/lagent/actions/web_visitor.py b/lagent/actions/web_visitor.py new file mode 100644 index 0000000..ff8f1ea --- /dev/null +++ b/lagent/actions/web_visitor.py @@ -0,0 +1,194 @@ +import asyncio +import json +import re +import traceback +import warnings +from typing import Any, List + +from transformers import AutoTokenizer + +from lagent.actions import AsyncActionMixin, BaseAction +from lagent.schema import ActionStatusCode, ActionValidCode, AgentMessage +from lagent.utils import create_object + + +def extract_last_json(text: str) -> dict | None: + """ + Extracts the last valid JSON object from a string. + Handles Markdown code blocks (```json ... ```) and raw JSON strings. + """ + try: + # 1. Try to find JSON within Markdown code blocks first + # Look for ```json ... ``` or just ``` ... ``` + code_block_pattern = re.compile(r'```(?:json)?\s*(\{.*?\})\s*```', re.DOTALL) + matches = code_block_pattern.findall(text) + if matches: + return json.loads(matches[-1]) + + # 2. If no code blocks, try to find the last outermost pair of braces + # This regex looks for { ... } lazily but we want the last one. + # A simple approach for nested JSON is tricky with regex, + # so we scan from right to left for the last '}' and find its matching '{'. + + stack, end_idx = 0, -1 + # Reverse search to find the last valid JSON structure + for i in range(len(text) - 1, -1, -1): + char = text[i] + if char == '}': + if stack == 0: + end_idx = i + stack += 1 + elif char == '{': + if stack > 0: + stack -= 1 + if stack == 0 and end_idx != -1: + # Found a potential outermost JSON object + candidate = text[i : end_idx + 1] + try: + return json.loads(candidate) + except json.JSONDecodeError: + # If this chunk isn't valid, reset and keep searching backwards + # (or you might decide to stop here depending on strictness) + stack, end_idx = 0, -1 + return None + except Exception: + return None + + +class WebVisitor(AsyncActionMixin, BaseAction): + + EXTRACTION_PROMPT = """Please process the following webpage content and user goal to extract relevant information: + +## **Webpage Content** +{webpage_content} + +## **User Goal** +{goal} + +## **Task Guidelines** +1. 
**Content Scanning for Rationale**: Locate the **specific sections/data** directly related to the user's goal within the webpage content +2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content, you never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs. +3. **Summary Output for Summary**: Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal. + +**Final Output Format using JSON format has "rational", "evidence", "summary" feilds** +""" + + def __init__( + self, + browse_tool: BaseAction | dict, + llm: Any, + max_browse_attempts: int = 3, + max_extract_attempts: int = 3, + sleep_interval: int = 3, + truncate_browse_response_length: int | None = None, + tokenizer_path: str | None = None, + name: str = 'visit', + ): + super().__init__( + description={ + 'name': name, + 'description': 'Visit webpage(s) and return the summary of the content.', + 'parameters': [ + { + 'name': 'url', + 'type': ['STRING', 'ARRAY'], + "items": {"type": "string"}, + "minItems": 1, + 'description': 'The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs.', + }, + {'name': 'goal', 'type': 'STRING', 'description': 'The goal of the visit for webpage(s).'}, + ], + 'required': ['url', 'goal'], + } + ) + browse_tool = create_object(browse_tool) + assert not browse_tool.is_toolkit and browse_tool.description['required'] == [ + 'url' + ], "browse_tool must be a single-tool action with only 'url' as required argument." + self.browse_tool = browse_tool + self.llm = create_object(llm) + self.max_browse_attempts = max_browse_attempts + self.max_extract_attempts = max_extract_attempts + self.sleep_interval = sleep_interval + self.truncate_browse_response_length = truncate_browse_response_length + self.tokenizer = ( + AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) if tokenizer_path else None + ) + if self.truncate_browse_response_length is not None and self.tokenizer is None: + warnings.warn( + 'truncate_browse_response_length is set but tokenizer_path is not provided. ' + 'The raw webpage content will be truncated by characters instead of tokens.' + ) + + async def run(self, url: str | List[str], goal: str) -> str: + if isinstance(url, str): + url = [url] + + async def _inner_call(single_url: str) -> str: + try: + return await self._read_webpage(single_url, goal) + except Exception as e: + return f"Error fetching {single_url}: {str(e)}" + + response = await asyncio.gather(*[_inner_call(single_url) for single_url in url]) + return "\n=======\n".join(response).strip() + + async def _read_webpage(self, url: str, goal: str) -> str: + tool_response = compressed = None + return_template = ( + f"The useful information in {url} for user goal {goal} as follows: \n\n" + f"Evidence in page: \n{{evidence}}\n\nSummary: \n{{summary}}\n\n" + ) + for _ in range(self.max_browse_attempts): + resp = await self.browse_tool({'url': url}) + if resp.valid == ActionValidCode.OPEN and resp.state == ActionStatusCode.SUCCESS: + tool_response = resp.format_result() + break + await asyncio.sleep(self.sleep_interval) + else: + return return_template.format( + evidence="The provided webpage content could not be accessed. 
Please check the URL or file format.", + summary="The webpage content could not be processed, and therefore, no information is available.", + ) + + if self.truncate_browse_response_length is not None: + tool_response = ( + self.tokenizer.decode( + self.tokenizer.encode( + tool_response, + max_length=self.truncate_browse_response_length, + truncation=True, + add_special_tokens=False, + ) + ) + if self.tokenizer is not None + else tool_response[: self.truncate_browse_response_length] + ) + + for _ in range(self.max_extract_attempts): + try: + prompt = self.EXTRACTION_PROMPT.format(webpage_content=tool_response, goal=goal) + llm_response = await self.llm.chat([{'role': 'user', 'content': prompt}]) + if llm_response and not isinstance(llm_response, str): + llm_response = ( + llm_response.content + if isinstance(llm_response, AgentMessage) + else llm_response.choices[0].message.content + ) + if not llm_response or len(llm_response) < 10: + tool_response = tool_response[: int(len(tool_response) * 0.7)] + continue + compressed = extract_last_json(llm_response) + if isinstance(compressed, dict) and all( + key in compressed for key in ['rational', 'evidence', 'summary'] + ): + break + except Exception: + print(f"Error in extracting information: {traceback.format_exc()}") + await asyncio.sleep(self.sleep_interval) + else: + return return_template.format( + evidence="Failed to extract relevant information from the webpage content.", + summary="The webpage content could not be processed, and therefore, no information is available.", + ) + return return_template.format(evidence=compressed['evidence'], summary=compressed['summary']) diff --git a/lagent/agents/agent.py b/lagent/agents/agent.py index ae99d6d..40271ce 100644 --- a/lagent/agents/agent.py +++ b/lagent/agents/agent.py @@ -73,7 +73,10 @@ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessa self.update_memory(message, session_id=session_id) response_message = self.forward(*message, session_id=session_id, **kwargs) if not isinstance(response_message, AgentMessage): - response_message = AgentMessage(sender=self.name, content=response_message) + if isinstance(response_message, str): + response_message = AgentMessage(sender=self.name, content=response_message) + else: + response_message = AgentMessage.from_model_response(response_message, self.name) self.update_memory(response_message, session_id=session_id) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): @@ -158,6 +161,28 @@ def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = F for agent in getattr(self, '_agents', {}).values(): agent.reset(session_id, recursive=True) + def get_messages(self, session_id=0, keypath: Optional[str] = None) -> List[dict]: + """Get OpenAI format messages from memory. + + Args: + session_id (int): The session id of the memory. + keypath (Optional[str]): The keypath of the sub-agent to get messages from. Default is None. + + Returns: + List[dict]: The messages from the memory including the sub-agent's system prompt. 
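+
+        Example (assuming a sub-agent is registered under the attribute name
+        ``select_agent``):
+
+            >>> agent.get_messages(session_id=0, keypath='select_agent')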
+ """ + if keypath: + keys, agent = keypath.split('.'), self + for key in keys: + agents = getattr(agent, '_agents', {}) + if key not in agents: + raise KeyError(f'No sub-agent named {key} in {agent}') + agent = agents[key] + return agent.get_messages(session_id=session_id) + if self.aggregator: + return self.aggregator.aggregate(self.memory.get(session_id), self.name, self.output_format, self.template) + raise ValueError(f'{self.name} has no aggregator to get messages') + def __repr__(self): def _rcsv_repr(agent, n_indent=1): @@ -186,7 +211,10 @@ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Agen self.update_memory(message, session_id=session_id) response_message = await self.forward(*message, session_id=session_id, **kwargs) if not isinstance(response_message, AgentMessage): - response_message = AgentMessage(sender=self.name, content=response_message) + if isinstance(response_message, str): + response_message = AgentMessage(sender=self.name, content=response_message) + else: + response_message = AgentMessage.from_model_response(response_message, self.name) self.update_memory(response_message, session_id=session_id) response_message = copy.deepcopy(response_message) for hook in self._hooks.values(): diff --git a/lagent/agents/aggregator/default_aggregator.py b/lagent/agents/aggregator/default_aggregator.py index 0888aef..720355c 100644 --- a/lagent/agents/aggregator/default_aggregator.py +++ b/lagent/agents/aggregator/default_aggregator.py @@ -1,31 +1,42 @@ -from typing import Dict, List +from typing import List from lagent.memory import Memory from lagent.prompts import StrParser +from lagent.schema import ActionReturn class DefaultAggregator: - def aggregate(self, - messages: Memory, - name: str, - parser: StrParser = None, - system_instruction: str = None) -> List[Dict[str, str]]: + def aggregate(self, messages: Memory, name: str, parser: StrParser = None, system_instruction=None) -> List[dict]: _message = [] messages = messages.get_memory() if system_instruction: - _message.extend( - self.aggregate_system_intruction(system_instruction)) + _message.extend(self.aggregate_system_intruction(system_instruction)) for message in messages: if message.sender == name: - _message.append( - dict(role='assistant', content=str(message.content))) + _message.append(message.to_model_request()) else: - user_message = message.content - if len(_message) > 0 and _message[-1]['role'] == 'user': - _message[-1]['content'] += user_message + user_message, extra_info = message.content, message.extra_info + if isinstance(user_message, list): + for m in user_message: + if isinstance(m, dict): + m = ActionReturn(**m) + assert isinstance(m, ActionReturn), f"Expected m to be ActionReturn, but got {type(m)}" + _message.append( + dict( + role='tool', + tool_call_id=m.tool_call_id, + content=m.format_result(), + name=m.type, + extra_info=extra_info, + ) + ) else: - _message.append(dict(role='user', content=user_message)) + if len(_message) > 0 and _message[-1]['role'] == 'user': + _message[-1]['content'] += user_message + _message[-1]['extra_info'] = extra_info + else: + _message.append(dict(role='user', content=user_message, extra_info=extra_info)) return _message @staticmethod @@ -39,6 +50,5 @@ def aggregate_system_intruction(system_intruction) -> List[dict]: if not isinstance(msg, dict): raise TypeError(f'Unsupported message type: {type(msg)}') if not ('role' in msg and 'content' in msg): - raise KeyError( - f"Missing required key 'role' or 'content': {msg}") + raise KeyError(f"Missing 
required key 'role' or 'content': {msg}") return system_intruction diff --git a/lagent/agents/fc_agent.py b/lagent/agents/fc_agent.py new file mode 100644 index 0000000..f3437a8 --- /dev/null +++ b/lagent/agents/fc_agent.py @@ -0,0 +1,160 @@ +import asyncio +import json +from copy import deepcopy +from dataclasses import asdict +from typing import Dict, List, Literal, Optional, Union + +from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed + +from lagent.actions import AsyncActionExecutor +from lagent.hooks import Hook +from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode, AgentMessage, AgentStatusCode +from lagent.utils import create_object, truncate_text +from .agent import AsyncAgent + +DEFAULT_TOOL_TEMPLATE = """# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{tools} + + +For each function call, return a json object with function name and arguments within XML tags: + +{{"name": , "arguments": }} +""" + + +def get_tool_prompt(actions: list, exclude_arguments: list = None, template: str = DEFAULT_TOOL_TEMPLATE) -> str: + exclude_arguments = exclude_arguments or ['session_id'] + + def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') -> dict: + properties = {} + for param in action_description['parameters']: + param = deepcopy(param) + param_name, param_type = param.pop('name'), param.pop('type') + if param_name in exclude_arguments: + continue + param_type = [t.lower() for t in param_type] if isinstance(param_type, list) else param_type.lower() + properties[param_name] = {'type': param_type, **param} + return { + 'type': 'function', + 'function': { + 'name': name_pattern.format(action_description['name']), + 'description': action_description['description'], + 'parameters': {'type': 'object', 'properties': properties, 'required': action_description['required']}, + }, + } + + tools = [] + for action in actions if isinstance(actions, list) else [actions]: + action = create_object(action) + action_desc = action.description + if action.is_toolkit: + for api in action_desc['api_list']: + tools.append(_convert_tool_schema(api, f"{action.name}.{{}}")) + else: + tools.append(_convert_tool_schema(action_desc)) + return template.format(tools='\n'.join([json.dumps(tool, ensure_ascii=False) for tool in tools])) + + +class FunctionCallAgent(AsyncAgent): + def __init__( + self, + select_agent: Union[Dict, AsyncAgent], + env_agent: Union[Dict, AsyncAgent], + finish_condition: callable = lambda x, _: x and not x.tool_calls, + max_turn: Optional[int] = None, + name: Optional[str] = None, + ): + super().__init__(name=name) + self.select_agent = create_object(select_agent) + self.env_agent = create_object(env_agent) + self.finish_condition = finish_condition + self.max_turn = max_turn + + async def forward(self, env_message: AgentMessage, session_id: str | int, **kwargs): + selection_message: AgentMessage = None + current_turn = 0 + while ( + self.finish_condition is None + or not self.finish_condition(selection_message, env_message) + and (self.max_turn is None or current_turn < self.max_turn) + ): + selection_message = await self.select_agent(env_message, session_id=session_id, **kwargs) + if selection_message.stream_state == AgentStatusCode.SERVER_ERR: + raise ValueError("Rollout response error: state is neither completed nor aborted!") + if selection_message.stream_state == AgentStatusCode.SESSION_OUT_OF_LIMIT: + for _ in range(2): # remove the 
last two messages + self.select_agent.memory.get(session_id).delete(-1) + return AgentMessage(role='env', content='Context length exceeds the limit') + env_message = await self.env_agent(selection_message, session_id=session_id) + current_turn += 1 + return AgentMessage(role="env", content="Finished") + + +class EnvAgent(AsyncAgent): + def __init__( + self, + actions: list, + stateful_tools: List[str] = None, + max_tool_response_length: int = None, + tool_response_truncate_side: Literal['left', 'right', 'middle'] = 'middle', + action_hooks: List[Union[dict, Hook]] = None, + name: Optional[str] = None, + ): + super().__init__(name=name) + self.actions = AsyncActionExecutor(actions, hooks=action_hooks) + self.stateful_tools = stateful_tools or [] + self.max_tool_response_length = max_tool_response_length + self.tool_response_truncate_side = tool_response_truncate_side + self._retry_mechanism = retry( + stop=stop_after_attempt(3), + wait=wait_fixed(2), + retry=retry_if_result( + lambda r: r.valid == ActionValidCode.OPEN + and r.state not in [ActionStatusCode.SUCCESS, ActionStatusCode.ARGS_ERROR] + ), + retry_error_callback=lambda retry_state: retry_state.outcome.result(), + ) + + async def forward(self, selection_message: AgentMessage, session_id: str | int, **kwargs): + if not selection_message.tool_calls: + return AgentMessage(sender=self.name, content='No tool call') + + tool_responses = await asyncio.gather( + *[ + self._retry_mechanism(self.execute_tool)(tool_call, session_id) + for tool_call in selection_message.tool_calls + ] + ) + for tool_call_id, tool_response in zip(selection_message.tool_calls_ids, tool_responses): + tool_response.tool_call_id = tool_call_id + res = tool_response.format_result() + if self.max_tool_response_length is not None and len(res) > self.max_tool_response_length: + res = truncate_text(res, max_num=self.max_tool_response_length, side=self.tool_response_truncate_side) + tool_response.result = [{'type': 'text', 'content': res}] + return AgentMessage(sender=self.name, content=[asdict(resp) for resp in tool_responses]) + + async def execute_tool(self, tool_call: dict, session_id: str | int) -> ActionReturn: + try: + if tool_call['name'].split('.', 1)[0] not in self.actions: + return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Tool {tool_call["name"]} Not Found') + if isinstance(tool_call['arguments'], str): + tool_call['arguments'] = json.loads(tool_call['arguments']) + if tool_call['name'] in self.stateful_tools: + tool_call = deepcopy(tool_call) + tool_call['arguments']['session_id'] = session_id + except Exception as e: + return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Invalid tool call format: {str(e)}') + tool_response: ActionReturn = ( + await self.actions( + AgentMessage( + sender='assistant', content=dict(name=tool_call['name'], parameters=tool_call['arguments']) + ), + session_id=session_id, + ) + ).content + return tool_response diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py index 7418a65..e83b7eb 100644 --- a/lagent/llms/openai.py +++ b/lagent/llms/openai.py @@ -1,8 +1,10 @@ import asyncio import json import os +import random import time import traceback +import uuid import warnings from concurrent.futures import ThreadPoolExecutor from logging import getLogger @@ -10,7 +12,12 @@ from typing import AsyncGenerator, Dict, List, Optional, Union import aiohttp +import httpx import requests +from openai import NOT_GIVEN, APITimeoutError, AsyncOpenAI +from openai.types.chat import ChatCompletion +from 
openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage from ..schema import ModelStatusCode from ..utils import filter_suffix @@ -308,7 +315,9 @@ def streaming(raw_response): response = dict() try: - raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True) + raw_response = requests.post( + self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True + ) return streaming(raw_response) except requests.ConnectionError: errmsg = 'Got connection error ' + str(traceback.format_exc()) @@ -809,3 +818,79 @@ def tokenize(self, prompt: str) -> list: self.tiktoken = tiktoken enc = self.tiktoken.encoding_for_model(self.model_type) return enc.encode(prompt) + + +class AsyncOpenAIWrapper: + def __init__( + self, + model: str, + base_url: str | List[str], + api_key: str = None, + sample_params: dict = None, + proxy: str = None, + timeout: int = 600, + max_retry: int = 50, + sleep_interval: int = 5, + extra_body: dict = None, + ): + self.model = model + http_client = httpx.AsyncClient(proxy=proxy, timeout=timeout) if proxy else None + self.clients = [ + AsyncOpenAI(api_key=api_key, base_url=url, http_client=http_client) + for url in (base_url if isinstance(base_url, list) else [base_url]) + ] + self.sample_params = sample_params or {} + self.timeout = timeout + self.max_retry = max_retry + self.sleep_interval = sleep_interval + self.extra_body = extra_body + + async def chat(self, messages: list[dict], session_id: str | int = None, **kwargs) -> ChatCompletion: + fallback_response = ChatCompletion( + id=f'chatcmpl-{uuid.uuid4()}', + object='chat.completion', + created=int(time.time()), + model=self.model, + choices=[ + Choice(message=ChatCompletionMessage(role='assistant', content=''), finish_reason='stop', index=0) + ], + ) + for attempt in range(self.max_retry): + try: + client = random.choice(self.clients) + response = await client.chat.completions.create( + model=self.model, + messages=messages, + stream=False, + temperature=kwargs.get('temperature') or self.sample_params.get("temperature", 0.7), + top_p=kwargs.get('top_p') or self.sample_params.get("top_p", 1.0), + timeout=self.timeout, + extra_body=self.extra_body, + max_tokens=kwargs.get('max_tokens') or self.sample_params.get("max_tokens", 64 * 1024), + reasoning_effort=kwargs.get('reasoning_effort') + or self.sample_params.get("reasoning_effort", NOT_GIVEN), + ) + return response + except (APITimeoutError, TimeoutError) as e: + print(f"LLM Call Timeout: {e}") + if attempt == self.max_retry - 1: + return fallback_response + await asyncio.sleep(self.sleep_interval) + except Exception as e: + for val in [ + "用户额度不足", + "剩余额度", + "TimeoutError", + "litellm.BadRequestError", + "litellm.APIError: APIError", + "Call `/v1/models` to view available models", + ]: + if val in str(e): + print(f"LLM Call Error: {e}") + if attempt == self.max_retry - 1: + return fallback_response + await asyncio.sleep(self.sleep_interval) + break + else: + return fallback_response + return fallback_response diff --git a/lagent/memory/base_memory.py b/lagent/memory/base_memory.py index 3c8fcf0..f7d72be 100644 --- a/lagent/memory/base_memory.py +++ b/lagent/memory/base_memory.py @@ -1,12 +1,11 @@ from typing import Callable, Dict, List, Optional, Union from lagent.schema import AgentMessage +from lagent.utils import load_class_from_string class Memory: - _item_cls = AgentMessage - def __init__(self, recent_n=None) -> None: self.memory: 
List[AgentMessage] = [] self.recent_n = recent_n @@ -25,39 +24,43 @@ def get_memory( memory = [m for i, m in enumerate(memory) if filter_func(i, m)] return memory - def add(self, memories: Union[List[Dict], Dict, None]) -> None: + def add(self, memories: Union[List[AgentMessage | str], AgentMessage, str]) -> None: for memory in memories if isinstance(memories, (list, tuple)) else [memories]: if isinstance(memory, str): - memory = self._item_cls(sender='user', content=memory) + memory = AgentMessage(sender='user', content=memory) if isinstance(memory, AgentMessage): - if not isinstance(memory, self._item_cls): - memory = self._item_cls.model_validate(memory, from_attributes=True) self.memory.append(memory) - def delete(self, index: Union[List, int]) -> None: + def delete(self, index: Union[List[int], int]) -> None: if isinstance(index, int): del self.memory[index] else: for i in index: del self.memory[i] - def load( - self, - memories: Union[str, Dict, List], - overwrite: bool = True, - ) -> None: + def load(self, memories: Union[str, dict, List], overwrite: bool = True) -> None: if overwrite: self.memory = [] if isinstance(memories, dict): - self.memory.append(self._item_cls.model_validate(memories)) + memories = memories.copy() + _cls = ( + load_class_from_string(memories.pop('__model_spec__')) + if '__model_spec__' in memories + else AgentMessage + ) + self.memory.append(_cls.model_validate(memories)) elif isinstance(memories, list): for m in memories: - self.memory.append(self._item_cls.model_validate(m)) + m = m.copy() + _cls = load_class_from_string(m.pop('__model_spec__')) if '__model_spec__' in m else AgentMessage + self.memory.append(_cls.model_validate(m)) else: raise TypeError(f'{type(memories)} is not supported') def save(self) -> List[dict]: memory = [] for m in self.memory: - memory.append(m.model_dump()) + m_dumped = m.model_dump() + m_dumped['__model_spec__'] = f'{m.__module__}.{m.__class__.__name__}' + memory.append(m_dumped) return memory diff --git a/lagent/schema.py b/lagent/schema.py index 668846f..ca2d135 100644 --- a/lagent/schema.py +++ b/lagent/schema.py @@ -2,12 +2,12 @@ from enum import IntEnum from typing import Any, Dict, List, Optional, Union +from openai.types.chat import ChatCompletion from pydantic import BaseModel def enum_dict_factory(inputs): - inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i - for i in inputs] + inputs = [(i[0], i[-1].value) if isinstance(i[-1], IntEnum) else i for i in inputs] return dict(inputs) @@ -47,6 +47,7 @@ class ActionReturn: state: Union[ActionStatusCode, int] = ActionStatusCode.SUCCESS thought: Optional[str] = None valid: Optional[ActionValidCode] = ActionValidCode.OPEN + tool_call_id: Optional[str] = None def format_result(self) -> str: """Concatenate items in result.""" @@ -89,9 +90,47 @@ class AgentStatusCode(IntEnum): class AgentMessage(BaseModel): content: Any + thinking: Optional[str] = None sender: str = 'user' + tool_calls: Optional[List[dict]] = None + tool_calls_ids: Optional[List[str]] = None formatted: Optional[Any] = None extra_info: Optional[Any] = None type: Optional[str] = None receiver: Optional[str] = None stream_state: Union[ModelStatusCode, AgentStatusCode] = AgentStatusCode.END + finish_reason: Optional[str] = None + + @classmethod + def from_model_response(cls, response: ChatCompletion, sender: str) -> "AgentMessage": + """Convert model response dict to AgentMessage.""" + chat_message = response.choices[0].message + tool_calls = chat_message.tool_calls and [tool_call.model_dump() for 
tool_call in chat_message.tool_calls]
+        return cls(
+            sender=sender,
+            content=chat_message.content or "",
+            thinking=getattr(chat_message, 'reasoning_content', None),
+            tool_calls=[tool_call['function'] for tool_call in tool_calls] if tool_calls else None,
+            tool_calls_ids=[tool_call['id'] for tool_call in tool_calls] if tool_calls else None,
+            stream_state=(
+                ModelStatusCode.SESSION_OUT_OF_LIMIT
+                if response.choices[0].finish_reason == 'length'
+                else ModelStatusCode.END
+            ),
+            finish_reason=response.choices[0].finish_reason,
+        )
+
+    def to_model_request(self, role: str = 'assistant') -> dict:
+        """Convert AgentMessage to model request dict."""
+        tool_calls = [
+            {'id': tool_call_id, 'function': tool_call, 'type': 'function'}
+            for tool_call, tool_call_id in zip(self.tool_calls or [], self.tool_calls_ids or [])
+        ]
+        return {
+            "role": role,
+            "content": self.content,
+            "reasoning_content": self.thinking,
+            "tool_calls": tool_calls if tool_calls else None,
+            "extra_info": self.extra_info,
+            "stream_state": self.stream_state,
+        }
diff --git a/lagent/utils/__init__.py b/lagent/utils/__init__.py
index a0ac549..64abd61 100644
--- a/lagent/utils/__init__.py
+++ b/lagent/utils/__init__.py
@@ -6,9 +6,16 @@
     filter_suffix,
     get_logger,
     load_class_from_string,
+    truncate_text,
 )
 
 __all__ = [
-    'is_module_exist', 'filter_suffix', 'create_object', 'get_logger',
-    'load_class_from_string', 'async_as_completed', 'GeneratorWithReturn'
+    'is_module_exist',
+    'filter_suffix',
+    'create_object',
+    'get_logger',
+    'load_class_from_string',
+    'async_as_completed',
+    'GeneratorWithReturn',
+    'truncate_text',
 ]
diff --git a/lagent/utils/util.py b/lagent/utils/util.py
index 609382e..9b913b2 100644
--- a/lagent/utils/util.py
+++ b/lagent/utils/util.py
@@ -4,11 +4,12 @@
 import logging
 import os
 import os.path as osp
+import re
 import sys
 import time
 from functools import partial
 from logging.handlers import RotatingFileHandler
-from typing import Any, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Union, cast
 
 
 def load_class_from_string(class_path: str, path=None):
@@ -33,6 +34,8 @@ def create_object(config: Union[Dict, Any] = None):
         preserved key to indicate the class (path). When accepting non-dictionary
         input, the function degenerates to an identity.
     """
+    from ray.actor import ActorClass
+
     if config is None or not isinstance(config, dict):
         return config
     assert isinstance(config, dict) and 'type' in config
@@ -41,7 +44,9 @@ def create_object(config: Union[Dict, Any] = None):
     obj_type = config.pop('type')
     if isinstance(obj_type, str):
         obj_type = load_class_from_string(obj_type)
-    if inspect.isclass(obj_type):
+    if isinstance(obj_type, ActorClass):
+        obj = cast(ActorClass, obj_type).remote(**config)
+    elif inspect.isclass(obj_type):
         obj = obj_type(**config)
     else:
         assert callable(obj_type)
@@ -123,6 +128,79 @@ def get_logger(
     return logger
+
+
+def truncate_text(text, max_num=4000, side='middle'):
+    """
+    Truncate mixed Chinese/English text according to ``side``, keeping approximately ``max_num`` units ("words/characters").
+
+    A "unit" is defined as:
+    1. A consecutive run of ASCII letters/digits counts as one unit (e.g. "Python", "123")
+    2. A single CJK character or punctuation mark counts as one unit (e.g. "中", "。", ",")
+
+    Args:
+        text (str): The original text.
+        max_num (int): The number of units to keep.
+        side (str): Truncation mode, one of 'left', 'right', 'middle'.
+            'left': keep the tail (truncate the head)
+            'right': keep the head (truncate the tail)
+            'middle': keep head and tail (truncate the middle)
+    """
+    if not text or max_num <= 0:
+        return ""
+
+    # --- Core regex ---
+    # Match either a word made of ASCII letters/digits/underscore/hyphen, or any single non-whitespace character.
+    # The word alternative must come first so that complete words are matched preferentially.
+    pattern = re.compile(r"[a-zA-Z0-9_'-]+|[^\s]")
+
+    # Collect all match objects (with position information)
+    matches = list(pattern.finditer(text))
+    total_units = len(matches)
+
+    # If the text is already short enough, return it unchanged
+    if total_units <= max_num:
+        return text
+
+    parts = []
+
+    if side == 'left':
+        # Keep the last max_num units (truncate the head)
+        # matches[-max_num] is the first unit that is kept
+        start_idx = total_units - max_num
+        start_pos = matches[start_idx].start()
+        parts.append("(truncated)...")
+        parts.append(text[start_pos:])
+
+    elif side == 'right':
+        # Keep the first max_num units (truncate the tail)
+        # matches[max_num - 1] is the last unit that is kept
+        end_pos = matches[max_num - 1].end()
+        parts.append(text[:end_pos])
+        parts.append("...(truncated)")
+
+    else:  # middle
+        # --- Keep head and tail, drop the middle ---
+        head_count = max_num // 2
+        tail_count = max_num - head_count
+
+        # 1. Extract the head
+        if head_count > 0:
+            # matches[head_count - 1] is the last unit kept in the head
+            head_span_end = matches[head_count - 1].end()
+            parts.append(text[:head_span_end])
+
+        # 2. Insert the truncation marker
+        parts.append("...(truncated)...")
+
+        # 3. Extract the tail
+        if tail_count > 0:
+            # matches[-tail_count] is the first unit kept in the tail
+            tail_idx = total_units - tail_count
+            tail_span_start = matches[tail_idx].start()
+            parts.append(text[tail_span_start:])
+
+    return "".join(parts)
+
+
 class GeneratorWithReturn:
     """Generator wrapper to capture the return value."""
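End-to-end usage sketch (reviewer note, not part of the patch): one way the classes introduced above are intended to compose, with `AsyncOpenAIWrapper` as the chat backend, an `AsyncMCPClient` tool behind `EnvAgent`, and `FunctionCallAgent` driving the select/execute loop. The MCP endpoint, model name, and base URL are placeholders; it also assumes that `AsyncAgent`'s default `forward` aggregates memory and awaits `llm.chat`, and that the serving backend parses the `<tool_call>` output requested by `DEFAULT_TOOL_TEMPLATE` into `ChatCompletion.tool_calls` (e.g. a vLLM server launched with a tool-call parser).

```python
import asyncio

from lagent.actions.mcp_client import AsyncMCPClient
from lagent.agents.agent import AsyncAgent
from lagent.agents.fc_agent import EnvAgent, FunctionCallAgent, get_tool_prompt
from lagent.llms.openai import AsyncOpenAIWrapper
from lagent.schema import AgentMessage

# Constructing the tool fetches its schema from the MCP server (placeholder URL).
mcp_tool = AsyncMCPClient(
    server_type='sse',
    url='http://localhost:8000/sse',
    rate_limit=5,        # FairAsyncTokenBucket: ~5 calls per second
    max_concurrency=8,   # bounded by an asyncio.Semaphore
)

# Agent that selects tools: the tool schemas are injected as the system template.
select_agent = AsyncAgent(
    llm=AsyncOpenAIWrapper(model='my-model', base_url='http://localhost:8001/v1', api_key='EMPTY'),
    template=get_tool_prompt([mcp_tool]),
)

# Agent that executes the selected tool calls and truncates long tool outputs.
env_agent = EnvAgent(actions=[mcp_tool], max_tool_response_length=8192)

agent = FunctionCallAgent(select_agent=select_agent, env_agent=env_agent, max_turn=16)
result = asyncio.run(agent(AgentMessage(sender='user', content='What tools does the server expose?'), session_id=0))
print(result)
```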