Skip to content

Commit f67b899

Browse files
committed
fix: Enhance Wecom adapter with improved message handling and error resilience
- Modify message content generation to add line breaks for text messages - Improve JSON encoding error handling with logging in memory persistence - Refactor Wecom adapter to support both enterprise and standard WeChat APIs - Add support for duplicate message handling and passive reply mechanism - Improve error logging and server startup/shutdown processes - Simplify server task cancellation in WebServer
1 parent 154c7bc commit f67b899

File tree

4 files changed

+107
-54
lines changed

4 files changed

+107
-54
lines changed

kirara_ai/im/message.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,12 @@ def __repr__(self):
333333
@property
334334
def content(self) -> str:
335335
"""获取消息的纯文本内容"""
336-
return "".join([element.to_plain() for element in self.message_elements])
336+
content = ""
337+
for element in self.message_elements:
338+
content += element.to_plain()
339+
if isinstance(element, TextMessage):
340+
content += "\n"
341+
return content.strip()
337342

338343
@property
339344
def images(self) -> List[ImageMessage]:

kirara_ai/memory/persistences/codecs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from types import FunctionType
44

55
from kirara_ai.im.sender import ChatSender, ChatType
6+
from kirara_ai.logger import get_logger
67

78

89
class MemoryJSONEncoder(json.JSONEncoder):
@@ -29,7 +30,11 @@ def default(self, obj):
2930
"kwdefaults": obj.__kwdefaults__,
3031
"doc": obj.__doc__,
3132
}
32-
return super().default(obj)
33+
try:
34+
return super().default(obj)
35+
except Exception as e:
36+
get_logger("MemoryJSONEncoder").warning(f"failed to encode object: {e}")
37+
return None
3338

3439

3540
def memory_json_decoder(obj):

kirara_ai/plugins/im_wecom_adapter/adapter.py

Lines changed: 94 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,27 @@
1-
import uuid
2-
from typing import Any, Optional
3-
4-
from kirara_ai.im.sender import ChatSender
5-
from kirara_ai.web.app import WebServer
6-
from kirara_ai.workflow.core.dispatch.dispatcher import WorkflowDispatcher
7-
8-
# 兼容新旧版本的 wechatpy 导入
9-
try:
10-
from wechatpy.enterprise import parse_message
11-
from wechatpy.enterprise.client import WeChatClient
12-
from wechatpy.enterprise.crypto import WeChatCrypto
13-
from wechatpy.enterprise.exceptions import InvalidCorpIdException
14-
except ImportError:
15-
from wechatpy.work.crypto import WeChatCrypto
16-
from wechatpy.work.client import WeChatClient
17-
from wechatpy.work.exceptions import InvalidCorpIdException
18-
from wechatpy.work import parse_message
19-
201
import asyncio
212
import base64
223
import os
4+
import uuid
235
from io import BytesIO
6+
from typing import Any, Optional
247

258
import aiohttp
269
from fastapi import FastAPI, HTTPException, Request, Response
2710
from pydantic import BaseModel, ConfigDict, Field
2811
from wechatpy.exceptions import InvalidSignatureException
12+
from wechatpy.replies import create_reply
2913

3014
from kirara_ai.im.adapter import IMAdapter
3115
from kirara_ai.im.message import FileElement, ImageMessage, IMMessage, TextMessage, VideoElement, VoiceMessage
16+
from kirara_ai.im.sender import ChatSender
3217
from kirara_ai.logger import HypercornLoggerWrapper, get_logger
18+
from kirara_ai.web.app import WebServer
19+
from kirara_ai.workflow.core.dispatch.dispatcher import WorkflowDispatcher
3320

3421
WECOM_TEMP_DIR = os.path.join(os.getcwd(), 'data', 'temp', 'wecom')
3522

3623
WEBHOOK_URL_PREFIX = "/im/webhook/wechat"
3724

38-
3925
def make_webhook_url():
4026
return f"{WEBHOOK_URL_PREFIX}/{str(uuid.uuid4())[:8]}"
4127

@@ -120,18 +106,15 @@ class WecomAdapter(IMAdapter):
120106

121107
dispatcher: WorkflowDispatcher
122108
web_server: WebServer
123-
109+
124110
def __init__(self, config: WecomConfig):
125111
self.config = config
126112
if self.config.host:
127113
self.app = FastAPI()
128114
else:
129115
self.app = self.web_server.app
130-
131-
self.crypto = WeChatCrypto(
132-
config.token, config.encoding_aes_key, config.corp_id or config.app_id
133-
)
134-
self.client = WeChatClient(config.corp_id, config.secret)
116+
117+
self.setup_wechat_api()
135118
self.logger = get_logger("Wecom-Adapter")
136119
self.is_running = False
137120
if not self.config.host:
@@ -141,35 +124,70 @@ def __init__(self, config: WecomConfig):
141124
self.config.port = 15650
142125
if not self.config.webhook_url:
143126
self.config.webhook_url = make_webhook_url()
144-
127+
128+
self.reply_tasks = {}
129+
130+
def setup_wechat_api(self):
131+
if self.config.corp_id:
132+
from wechatpy.enterprise import parse_message
133+
from wechatpy.enterprise.client import WeChatClient
134+
from wechatpy.enterprise.crypto import WeChatCrypto
135+
self.crypto = WeChatCrypto(
136+
self.config.token, self.config.encoding_aes_key, self.config.corp_id
137+
)
138+
self.client = WeChatClient(self.config.corp_id, self.config.secret)
139+
self.parse_message = parse_message
140+
else:
141+
from wechatpy import WeChatClient
142+
from wechatpy.crypto import WeChatCrypto
143+
from wechatpy.parser import parse_message
144+
self.crypto = WeChatCrypto(
145+
self.config.token, self.config.encoding_aes_key, self.config.app_id
146+
)
147+
self.client = WeChatClient(self.config.app_id, self.config.secret)
148+
self.parse_message = parse_message
149+
145150
def setup_routes(self):
146-
if "host" in self.config.__pydantic_extra__:
151+
if self.config.host:
147152
webhook_url = '/wechat'
148153
else:
149154
webhook_url = self.config.webhook_url
150-
155+
# unregister old route if exists
156+
for route in self.app.routes:
157+
if route.path == webhook_url:
158+
self.app.routes.remove(route)
159+
151160
@self.app.get(webhook_url)
152161
async def handle_check_request(request: Request):
153162
"""处理 GET 请求"""
154163
if not self.is_running:
164+
self.logger.warning("Wecom-Adapter is not running, skipping check request.")
155165
raise HTTPException(status_code=404)
156166

157167
signature = request.query_params.get("msg_signature", "")
158168
timestamp = request.query_params.get("timestamp", "")
159169
nonce = request.query_params.get("nonce", "")
160170
echo_str = request.query_params.get("echostr", "")
161-
try:
162-
echo_str = self.crypto.check_signature(
163-
signature, timestamp, nonce, echo_str
164-
)
171+
172+
173+
try:
174+
if self.config.corp_id:
175+
echo_str = self.crypto.check_signature(
176+
signature, timestamp, nonce, echo_str
177+
)
178+
else:
179+
from wechatpy.utils import check_signature
180+
check_signature(self.config.token, signature, timestamp, nonce)
165181
return Response(content=echo_str, media_type="text/plain")
166182
except InvalidSignatureException:
183+
self.logger.error("failed to check signature, please check your settings.")
167184
raise HTTPException(status_code=403)
168185

169186
@self.app.post(webhook_url)
170187
async def handle_message(request: Request):
171188
"""处理 POST 请求"""
172189
if not self.is_running:
190+
self.logger.warning("Wecom-Adapter is not running, skipping message request.")
173191
raise HTTPException(status_code=404)
174192
signature = request.query_params.get("msg_signature", "")
175193
timestamp = request.query_params.get("timestamp", "")
@@ -178,10 +196,17 @@ async def handle_message(request: Request):
178196
msg = self.crypto.decrypt_message(
179197
await request.body(), signature, timestamp, nonce
180198
)
181-
except (InvalidSignatureException, InvalidCorpIdException):
199+
except InvalidSignatureException:
200+
self.logger.error("failed to check signature, please check your settings.")
182201
raise HTTPException(status_code=403)
183-
msg = parse_message(msg)
184-
202+
msg = self.parse_message(msg)
203+
204+
if msg.id in self.reply_tasks:
205+
self.logger.debug(f"skip processing due to duplicate msgid: {msg.id}")
206+
reply = await self.reply_tasks[msg.id]
207+
del self.reply_tasks[msg.id]
208+
return Response(content=create_reply(reply, msg).render(), media_type="text/xml")
209+
185210
# 预处理媒体消息
186211
media_path = None
187212
if msg.type in ["voice", "video", "file"]:
@@ -191,9 +216,13 @@ async def handle_message(request: Request):
191216

192217
# 转换消息
193218
message = self.convert_to_message(msg, media_path)
219+
self.reply_tasks[msg.id] = asyncio.Future()
220+
message.sender.raw_metadata["reply"] = self.reply_tasks[msg.id]
194221
# 分发消息
195-
await self.dispatcher.dispatch(self, message)
196-
return Response(content="ok", media_type="text/plain")
222+
asyncio.create_task(self.dispatcher.dispatch(self, message))
223+
reply = await message.sender.raw_metadata["reply"]
224+
del message.sender.raw_metadata["reply"]
225+
return Response(content=create_reply(reply, msg).render(), media_type="text/xml")
197226

198227
def convert_to_message(self, raw_message: Any, media_path: Optional[str] = None) -> IMMessage:
199228
"""将企业微信消息转换为统一消息格式"""
@@ -236,6 +265,7 @@ async def _send_text(self, user_id: str, text: str):
236265
return self.client.message.send_text(self.config.app_id, user_id, text)
237266
except Exception as e:
238267
self.logger.error(f"Failed to send text message: {e}")
268+
raise e
239269

240270
async def _send_media(self, user_id: str, media_data: str, media_type: str):
241271
"""发送媒体消息的通用方法"""
@@ -247,24 +277,33 @@ async def _send_media(self, user_id: str, media_data: str, media_type: str):
247277
return send_method(self.config.app, user_id, media_id)
248278
except Exception as e:
249279
self.logger.error(f"Failed to send {media_type} message: {e}")
280+
raise e
250281

251282
async def send_message(self, message: IMMessage, recipient: ChatSender):
252283
"""发送消息到企业微信"""
253284
user_id = recipient.user_id
254285
res = None
255-
for element in message.message_elements:
256-
if isinstance(element, TextMessage) and element.text:
257-
res = await self._send_text(user_id, element.text)
258-
elif isinstance(element, ImageMessage) and element.url:
259-
res = await self._send_media(user_id, element.url, "image")
260-
elif isinstance(element, VoiceMessage) and element.url:
261-
res = await self._send_media(user_id, element.url, "voice")
262-
elif isinstance(element, VideoElement) and element.file:
263-
res = await self._send_media(user_id, element.file, "video")
264-
elif isinstance(element, FileElement) and element.path:
265-
res = await self._send_media(user_id, element.path, "file")
266-
if res:
267-
print(res)
286+
287+
try:
288+
for element in message.message_elements:
289+
if isinstance(element, TextMessage) and element.text:
290+
res = await self._send_text(user_id, element.text)
291+
elif isinstance(element, ImageMessage) and element.url:
292+
res = await self._send_media(user_id, element.url, "image")
293+
elif isinstance(element, VoiceMessage) and element.url:
294+
res = await self._send_media(user_id, element.url, "voice")
295+
elif isinstance(element, VideoElement) and element.file:
296+
res = await self._send_media(user_id, element.file, "video")
297+
elif isinstance(element, FileElement) and element.path:
298+
res = await self._send_media(user_id, element.path, "file")
299+
if res:
300+
print(res)
301+
recipient.raw_metadata["reply"].set_result(None)
302+
except Exception as e:
303+
if 'Error code: 48001' in str(e):
304+
# 未开通主动回复能力
305+
self.logger.warning("未开通主动回复能力,将采用被动回复消息 API,此模式下只能回复一条消息。")
306+
recipient.raw_metadata["reply"].set_result(message.content)
268307

269308
async def _start_standalone_server(self):
270309
"""启动服务"""
@@ -294,12 +333,16 @@ async def _stop_standalone_server(self):
294333
self.logger.error(f"Error during server shutdown: {e}")
295334

296335
async def start(self):
336+
self.setup_wechat_api()
297337
if self.config.host:
298338
self.logger.warning("正在使用过时的启动模式,请尽快更新为 Webhook 模式。")
299339
await self._start_standalone_server()
300340
self.setup_routes()
341+
self.is_running = True
342+
self.logger.info("Wecom-Adapter 启动成功")
301343

302344
async def stop(self):
303345
if self.config.host:
304346
await self._stop_standalone_server()
305347
self.is_running = False
348+
self.logger.info("Wecom-Adapter 停止成功")

kirara_ai/web/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ async def start(self):
242242
async def stop(self):
243243
"""停止Web服务器"""
244244
self.shutdown_event.set()
245+
245246
if self.server_task:
246-
self.server_task.cancel()
247247
try:
248248
await self.server_task
249249
except asyncio.CancelledError:

0 commit comments

Comments
 (0)