-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathgateway.py
More file actions
1385 lines (1115 loc) · 48.7 KB
/
gateway.py
File metadata and controls
1385 lines (1115 loc) · 48.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""MiniCPMO45 推理 Gateway
请求分发网关,不加载模型,负责:
- 路由 Chat/Streaming/Duplex 请求到 Worker
- 会话映射和 KV Cache LRU 命中路由
- 统一 FIFO 请求排队(容量 1000,位置追踪 + ETA 估算)
- Worker 健康检查
启动方式:
cd /user/sunweiyue/lib/swy-dev/minicpmo45_service
PYTHONPATH=. .venv/base/bin/python gateway.py \\
--port 10024 \\
--workers localhost:22400,localhost:22401
"""
import os
import re
import json
import asyncio
import argparse
import logging
import time
from typing import Optional, List, Dict, Any
from datetime import datetime
from contextlib import asynccontextmanager
import zipfile
from io import BytesIO
import httpx
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request, UploadFile, File
from fastapi.responses import HTMLResponse, FileResponse, StreamingResponse, Response, RedirectResponse
from fastapi.staticfiles import StaticFiles
from gateway_modules.models import (
GatewayWorkerStatus,
ServiceStatus,
WorkersResponse,
QueueStatus,
EtaConfig,
EtaStatus,
)
from gateway_modules.worker_pool import WorkerPool, WorkerConnection
from gateway_modules.ref_audio_registry import (
RefAudioRegistry,
RefAudioListResponse,
UploadRefAudioRequest,
RefAudioResponse,
)
from gateway_modules.app_registry import (
AppRegistry,
AppToggleRequest,
AppsPublicResponse,
AppsAdminResponse,
)
# Configure the root logger at import time: INFO level, timestamped
# "<time> [<level>] <logger name>: <message>" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-level logger used by the handlers in this file.
logger = logging.getLogger("gateway")
_SESSION_ID_RE = re.compile(r'^[a-zA-Z0-9_\-]+$')
def _sanitize_session_id(session_id: str) -> str:
"""校验 session_id 只含安全字符,防止 path traversal"""
if not _SESSION_ID_RE.match(session_id):
safe = re.sub(r'[^a-zA-Z0-9_\-]', '_', session_id)
return safe
return session_id
# ============ Global state ============
# Worker pool; None until lifespan() creates and starts it.
worker_pool: Optional[WorkerPool] = None
# Reference-audio registry; None until lifespan() creates it.
ref_audio_registry: Optional[RefAudioRegistry] = None
# Registry the handlers query via is_enabled() before serving an app.
app_registry: AppRegistry = AppRegistry()
# Configuration (passed in via main())
GATEWAY_CONFIG: Dict[str, Any] = {}
# ============ Application initialization ============
# Handle of the background session-cleanup task started in lifespan().
_cleanup_task: Optional[asyncio.Task] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build shared state on startup, tear it down on exit.

    Startup creates and starts the worker pool from ``GATEWAY_CONFIG``,
    creates the reference-audio registry and launches the background
    session-cleanup task. Shutdown cancels that task, waits for it to
    finish, and stops the worker pool.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global worker_pool, ref_audio_registry, _cleanup_task

    workers = GATEWAY_CONFIG.get("workers", ["localhost:10031"])
    max_queue = GATEWAY_CONFIG.get("max_queue_size", 1000)
    timeout = GATEWAY_CONFIG.get("timeout", 300.0)

    # Read the ETA parameters from the config
    eta_config_data = GATEWAY_CONFIG.get("eta_config")
    eta_config = EtaConfig(**eta_config_data) if eta_config_data else EtaConfig()

    worker_pool = WorkerPool(
        worker_addresses=workers,
        max_queue_size=max_queue,
        request_timeout=timeout,
        eta_config=eta_config,
        ema_alpha=GATEWAY_CONFIG.get("eta_ema_alpha", 0.3),
        ema_min_samples=GATEWAY_CONFIG.get("eta_ema_min_samples", 3),
    )
    await worker_pool.start()

    # From here on the pool is running, so teardown must happen even if the
    # rest of startup fails or an exception is thrown into the generator.
    try:
        # Initialise the reference-audio registry
        data_dir = os.path.join(os.path.dirname(__file__), "data", "assets", "ref_audio")
        ref_audio_registry = RefAudioRegistry(storage_dir=data_dir)

        # Start the background session-cleanup task (runs once a day)
        _cleanup_task = asyncio.create_task(_session_cleanup_loop())

        logger.info(f"Gateway started, {len(worker_pool.workers)} workers, {ref_audio_registry.count} ref audios")
        yield
    finally:
        if _cleanup_task:
            # cancel() only requests cancellation; await the task so it has
            # actually finished before the pool goes away.
            _cleanup_task.cancel()
            try:
                await _cleanup_task
            except asyncio.CancelledError:
                pass  # expected outcome of the cancel() above
            except Exception as e:
                # The task had already died on its own; do not block shutdown.
                logger.warning(f"Session cleanup task ended with error: {e}")
        await worker_pool.stop()
        logger.info("Gateway stopped")
async def _session_cleanup_loop() -> None:
    """Run session cleanup once a day, forever.

    After an initial 60 s delay, each iteration reads the recording config
    and runs ``cleanup_sessions`` in a worker thread. Cleanup is skipped
    when both ``session_retention_days`` and ``max_storage_gb`` are
    negative. Failures are logged and the loop keeps going.
    """
    from session_cleanup import cleanup_sessions
    from config import get_config

    startup_delay_s = 60
    interval_s = 86400  # one day

    await asyncio.sleep(startup_delay_s)
    while True:
        try:
            cfg = get_config()
            retention_days = cfg.recording.session_retention_days
            storage_cap_gb = cfg.recording.max_storage_gb
            cleanup_disabled = retention_days < 0 and storage_cap_gb < 0
            if not cleanup_disabled:
                # cleanup_sessions is synchronous; keep it off the event loop.
                report = await asyncio.to_thread(
                    cleanup_sessions, cfg.data_dir, retention_days, storage_cap_gb,
                )
                logger.info(f"[Cleanup] {report}")
            else:
                logger.info("[Cleanup] Disabled (retention_days=-1, max_storage_gb=-1), sleeping")
        except Exception as e:
            logger.error(f"[Cleanup] Failed: {e}", exc_info=True)
        await asyncio.sleep(interval_s)
# The ASGI application; startup and shutdown are driven by lifespan().
app = FastAPI(
    title="MiniCPMO45 Gateway",
    description="MiniCPMO45 多模态推理网关",
    version="1.0.0-alpha.2",
    lifespan=lifespan,
)
# ============ 健康检查 ============
@app.get("/health")
async def health():
    """Health check: report "healthy" together with the current local time."""
    now_iso = datetime.now().isoformat()
    return dict(status="healthy", timestamp=now_iso)
@app.get("/status", response_model=ServiceStatus)
async def status():
    """Service status: worker counts per state, queue usage and running tasks.

    Raises:
        HTTPException: 503 when the worker pool has not been created yet.
    """
    pool = worker_pool
    if pool is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    snapshot = dict(
        gateway_healthy=True,
        total_workers=len(pool.workers),
        idle_workers=pool.idle_count,
        busy_workers=pool.busy_count,
        duplex_workers=pool.duplex_count,
        loading_workers=pool.loading_count,
        error_workers=pool.error_count,
        offline_workers=pool.offline_count,
        queue_length=pool.queue_length,
        max_queue_size=pool.max_queue_size,
        running_tasks=pool._get_running_tasks(),
    )
    return ServiceStatus(**snapshot)
@app.get("/workers", response_model=WorkersResponse)
async def list_workers():
    """Worker list: the total count plus every worker reported by the pool.

    Raises:
        HTTPException: 503 when the worker pool has not been created yet.
    """
    pool = worker_pool
    if pool is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    worker_total = len(pool.workers)
    all_workers = pool.get_all_workers()
    return WorkersResponse(total=worker_total, workers=all_workers)
# ============ Chat API(无状态,HTTP 代理到 Worker) ============
@app.post("/api/chat")
async def chat(request: Request):
    """Chat inference (stateless HTTP proxy to a Worker).

    The request is enqueued in the FIFO queue; if a Worker is idle it is
    assigned immediately, otherwise the handler waits, polling every 2 s for
    a client disconnect. The JSON body is then forwarded to the Worker's
    ``/chat`` endpoint and the Worker's JSON reply is returned with
    ``queue_wait_ms`` and ``estimated_queue_wait_s`` added.

    Raises:
        HTTPException: 403 when the turn-based app is disabled; 503 when the
            service is not ready, the queue is full, the wait is cancelled or
            no worker is assigned; 500 when the Worker request fails.
    """
    if not app_registry.is_enabled("turnbased"):
        raise HTTPException(status_code=403, detail="Turn-based Chat is currently disabled")
    if worker_pool is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    request_body = await request.json()
    queue_start = datetime.now()

    # Enqueue (a Worker is assigned immediately if one is idle)
    try:
        ticket, future = worker_pool.enqueue("chat")
    except WorkerPool.QueueFullError:
        raise HTTPException(
            status_code=503,
            detail=f"Queue full ({worker_pool.max_queue_size} requests)",
        )

    # Wait for a Worker (while checking whether the client went away)
    worker: Optional[WorkerConnection] = None
    try:
        while not future.done():
            if await request.is_disconnected():
                worker_pool.cancel(ticket.ticket_id)
                return  # client has already disconnected
            try:
                # shield(): a poll timeout must not cancel the pool's future.
                worker = await asyncio.wait_for(asyncio.shield(future), timeout=2.0)
                break
            except asyncio.TimeoutError:
                continue
        if worker is None and future.done():
            worker = future.result()
    except asyncio.CancelledError:
        # Drop the ticket so it does not stay queued with nobody waiting on
        # it (same handling as the half-duplex and duplex handlers).
        worker_pool.cancel(ticket.ticket_id)
        raise HTTPException(status_code=503, detail="Request cancelled")

    if worker is None:
        raise HTTPException(status_code=503, detail="No worker available")

    # Mark the Worker as BUSY
    queue_wait_ms = (datetime.now() - queue_start).total_seconds() * 1000
    estimated_queue_s = ticket.estimated_wait_s
    worker.mark_busy(GatewayWorkerStatus.BUSY_CHAT, "chat")
    task_start = datetime.now()

    try:
        async with httpx.AsyncClient(timeout=worker_pool.request_timeout) as client:
            resp = await client.post(
                f"{worker.url}/chat",
                json=request_body,
                timeout=worker_pool.request_timeout,
            )
            worker.total_requests += 1
            worker.last_heartbeat = datetime.now()
            result = resp.json()
            result["queue_wait_ms"] = round(queue_wait_ms)
            result["estimated_queue_wait_s"] = round(estimated_queue_s, 1)
            return result
    except Exception as e:
        logger.error(f"[{ticket.ticket_id}] Chat request failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always hand the Worker back, whether the request succeeded or not.
        duration = (datetime.now() - task_start).total_seconds()
        worker_pool.release_worker(worker, request_type="chat", duration_s=duration)
# ============ Chat WebSocket 代理 ============
@app.websocket("/ws/chat")
async def chat_ws_proxy(ws: WebSocket):
    """Chat WebSocket proxy: queue, then pass through to the Worker's /ws/chat.

    Reads one request message from the client, waits in the FIFO queue for a
    Worker, forwards the message to the Worker's WebSocket and relays every
    Worker message back to the client. The Worker is released and both
    sockets are closed when the exchange ends.
    """
    if not app_registry.is_enabled("turnbased"):
        await ws.close(code=1008, reason="Turn-based Chat is currently disabled")
        return
    if worker_pool is None:
        await ws.close(code=1013, reason="Service not ready")
        return

    await ws.accept()

    assigned_worker: Optional[WorkerConnection] = None
    worker_ws = None
    task_start: Optional[datetime] = None

    try:
        # Receive the request message sent by the frontend
        raw = await ws.receive_text()

        # Queue for a Worker
        try:
            ticket, future = worker_pool.enqueue("chat")
        except WorkerPool.QueueFullError:
            await ws.send_json({"type": "error", "error": "Queue full"})
            return

        # Wait for a Worker to be assigned
        while not future.done():
            try:
                # shield(): a poll timeout must not cancel the pool's future.
                assigned_worker = await asyncio.wait_for(asyncio.shield(future), timeout=2.0)
                break
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                # Drop the ticket so it does not stay queued with nobody
                # waiting on it (same handling as the duplex handlers).
                worker_pool.cancel(ticket.ticket_id)
                await ws.send_json({"type": "error", "error": "Cancelled"})
                return

        if assigned_worker is None and future.done():
            assigned_worker = future.result()
        if assigned_worker is None:
            await ws.send_json({"type": "error", "error": "No worker available"})
            return

        assigned_worker.mark_busy(GatewayWorkerStatus.BUSY_CHAT, "chat_ws")
        task_start = datetime.now()

        # Connect to the Worker WebSocket
        import websockets
        ws_url = f"ws://{assigned_worker.host}:{assigned_worker.port}/ws/chat"
        worker_ws = await websockets.connect(ws_url)

        # Forward the request
        await worker_ws.send(raw)

        # Pass every Worker response through to the frontend
        async for msg_data in worker_ws:
            await ws.send_text(msg_data)

    except WebSocketDisconnect:
        logger.info("Chat WS proxy: client disconnected")
    except Exception as e:
        logger.error(f"Chat WS proxy error: {e}", exc_info=True)
        try:
            await ws.send_json({"type": "error", "error": str(e)})
        except Exception:
            pass  # best effort: the client socket may already be gone
    finally:
        if worker_ws:
            try:
                await worker_ws.close()
            except Exception:
                pass  # best effort
        if assigned_worker and task_start:
            duration = (datetime.now() - task_start).total_seconds()
            worker_pool.release_worker(assigned_worker, request_type="chat_ws", duration_s=duration)
        try:
            await ws.close()
        except Exception:
            pass  # best effort: already closed or disconnected
# ============ Half-Duplex WebSocket(独占 Worker,FIFO 排队 + 代理到 Worker) ============
@app.websocket("/ws/half_duplex/{session_id}")
async def half_duplex_ws(ws: WebSocket, session_id: str):
    """Half-Duplex WebSocket proxy (exclusive Worker, FIFO queueing, then relay).

    Occupies one Worker exclusively until the user stops or the session
    times out (3 minutes by default, per the original note). The frontend
    sends audio chunks; the Worker detects speech with VAD and then runs
    prefill + generate.

    While queued, the client receives a ``queued`` message followed by
    ``queue_update`` messages every 3 s, then ``queue_done`` once a Worker
    is assigned. After that, text frames are relayed in both directions
    until either side ends.

    Args:
        ws: The client WebSocket.
        session_id: Session id from the path; sanitized before use.
    """
    if not app_registry.is_enabled("half_duplex_audio"):
        await ws.close(code=1008, reason="Half-Duplex Audio is currently disabled")
        return
    if worker_pool is None:
        await ws.close(code=1013, reason="Service not ready")
        return

    session_id = _sanitize_session_id(session_id)
    # Accept first so queue status can be pushed while waiting.
    await ws.accept()

    try:
        ticket, future = worker_pool.enqueue("half_duplex_audio", session_id=session_id)
    except WorkerPool.QueueFullError:
        await ws.send_json({
            "type": "error",
            "error": f"Queue full ({worker_pool.max_queue_size} requests)",
        })
        await ws.close(code=1013, reason="Queue full")
        return

    worker: Optional[WorkerConnection] = None
    if future.done():
        worker = future.result()
    else:
        try:
            await ws.send_json({
                "type": "queued",
                "position": ticket.position,
                "estimated_wait_s": ticket.estimated_wait_s,
                "ticket_id": ticket.ticket_id,
                "queue_length": worker_pool.queue_length,
            })
            while not future.done():
                try:
                    # shield(): a poll timeout must not cancel the pool's future.
                    worker = await asyncio.wait_for(
                        asyncio.shield(future), timeout=3.0
                    )
                    break
                except asyncio.TimeoutError:
                    updated = worker_pool.get_ticket(ticket.ticket_id)
                    if updated:
                        await ws.send_json({
                            "type": "queue_update",
                            "position": updated.position,
                            "estimated_wait_s": updated.estimated_wait_s,
                            "queue_length": worker_pool.queue_length,
                        })
                except asyncio.CancelledError:
                    worker_pool.cancel(ticket.ticket_id)
                    return
        except Exception as e:  # includes WebSocketDisconnect
            logger.info(f"Half-Duplex WS disconnected during queue: session={session_id} ({e})")
            worker_pool.cancel(ticket.ticket_id)
            return
        if worker is None and future.done():
            worker = future.result()

    if worker is None:
        await ws.send_json({"type": "error", "error": "No worker available"})
        await ws.close(code=1013, reason="No worker available")
        return

    await ws.send_json({"type": "queue_done"})
    logger.info(f"Half-Duplex WS connected: session={session_id} → {worker.worker_id}")
    worker.mark_busy(GatewayWorkerStatus.BUSY_HALF_DUPLEX, "half_duplex_audio", session_id=session_id)
    task_start = datetime.now()
    worker_ws = None

    try:
        import websockets
        ws_url = f"ws://{worker.host}:{worker.port}/ws/half_duplex?session_id={session_id}"

        # Retry the Worker connection a few times before giving up.
        max_retries = 5
        for attempt in range(max_retries):
            try:
                worker_ws = await websockets.connect(ws_url, open_timeout=5)
                break
            except Exception as conn_err:
                if attempt < max_retries - 1:
                    logger.warning(
                        f"Half-Duplex WS connect to {worker.worker_id} failed (attempt {attempt + 1}): "
                        f"{conn_err}, retrying in 1s..."
                    )
                    await asyncio.sleep(1.0)
                else:
                    raise

        async def client_to_worker():
            """Client → Worker"""
            try:
                async for raw in ws.iter_text():
                    await worker_ws.send(raw)
            except WebSocketDisconnect:
                pass

        async def worker_to_client():
            """Worker → Client"""
            try:
                async for raw in worker_ws:
                    await ws.send_text(raw)
            except Exception:
                pass  # either side closing ends the relay

        relay_tasks = [
            asyncio.create_task(client_to_worker()),
            asyncio.create_task(worker_to_client()),
        ]
        try:
            await asyncio.wait(relay_tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Stop whichever direction is still running and wait for it, so
            # no relay task outlives this handler (also when the handler
            # itself is cancelled). cancel() is a no-op on finished tasks.
            for task in relay_tasks:
                task.cancel()
            outcomes = await asyncio.gather(*relay_tasks, return_exceptions=True)
        for outcome in outcomes:
            # CancelledError is not an Exception subclass, so cancelled
            # directions are skipped here.
            if isinstance(outcome, Exception):
                logger.warning(f"Half-Duplex WS relay ended with error: session={session_id} ({outcome})")

    except Exception as e:
        logger.error(f"Half-Duplex WS error: {e}", exc_info=True)
    finally:
        if worker_ws:
            try:
                await worker_ws.close()
            except Exception:
                pass  # best effort
        duration = (datetime.now() - task_start).total_seconds()
        worker_pool.release_worker(worker, request_type="half_duplex_audio", duration_s=duration)
        logger.info(f"Half-Duplex WS ended: session={session_id}, Worker released ({duration:.1f}s)")
# ============ 前端诊断日志写入 ============
async def _write_diagnostic(path: str, msg: dict) -> None:
    """Append a frontend-reported diagnostic record to a JSONL file.

    The record is stamped in place with ``_server_recv_ts`` (server receive
    time), serialized as one JSON line and written in a worker thread so the
    event loop is not blocked. Write failures are logged, not raised.

    Args:
        path: Destination JSONL file.
        msg: Diagnostic payload; mutated to carry the receive timestamp.
    """
    msg.update(_server_recv_ts=time.time())
    serialized = json.dumps(msg, ensure_ascii=False)
    try:
        await asyncio.to_thread(_sync_append, path, f"{serialized}\n")
    except Exception as e:
        logger.warning(f"Failed to write diagnostic: {e}")
def _sync_append(path: str, line: str) -> None:
with open(path, "a", encoding="utf-8") as f:
f.write(line)
# ============ Duplex WebSocket(有状态,FIFO 排队 + 代理到 Worker) ============
@app.websocket("/ws/duplex/{session_id}")
async def duplex_ws(ws: WebSocket, session_id: str):
    """Duplex WebSocket proxy (stateful; FIFO queueing, then relay to a Worker).

    The WebSocket is accepted first so queue status can be pushed, then the
    session enters the FIFO queue and waits for a Worker. Duplex occupies
    one Worker exclusively until the user hangs up or a pause times out
    (per the original note).

    Session ids starting with ``adx_`` are treated as audio duplex, all
    others as omni duplex. Client frames of type ``client_diagnostic`` are
    written to a local JSONL file and not forwarded; ``pause`` / ``resume``
    update the Worker's duplex status; everything else is forwarded as-is.

    Args:
        ws: The client WebSocket.
        session_id: Session id from the path; sanitized before use.
    """
    duplex_app = "audio_duplex" if session_id.startswith("adx_") else "omni"
    if not app_registry.is_enabled(duplex_app):
        await ws.close(code=1008, reason=f"{duplex_app} is currently disabled")
        return
    if worker_pool is None:
        await ws.close(code=1013, reason="Service not ready")
        return

    session_id = _sanitize_session_id(session_id)

    # Accept first so status can be pushed while queued
    await ws.accept()

    # Enqueue
    duplex_type = "audio_duplex" if session_id.startswith("adx_") else "omni_duplex"
    try:
        ticket, future = worker_pool.enqueue(duplex_type, session_id=session_id)
    except WorkerPool.QueueFullError:
        await ws.send_json({
            "type": "error",
            "error": f"Queue full ({worker_pool.max_queue_size} requests)",
        })
        await ws.close(code=1013, reason="Queue full")
        return

    # Wait for a Worker (detect a frontend disconnect while queued and
    # cancel the ticket in that case)
    worker: Optional[WorkerConnection] = None
    if future.done():
        worker = future.result()
    else:
        try:
            await ws.send_json({
                "type": "queued",
                "position": ticket.position,
                "estimated_wait_s": ticket.estimated_wait_s,
                "ticket_id": ticket.ticket_id,
                "queue_length": worker_pool.queue_length,
            })
            while not future.done():
                try:
                    # shield(): a poll timeout must not cancel the pool's future.
                    worker = await asyncio.wait_for(
                        asyncio.shield(future), timeout=3.0
                    )
                    break
                except asyncio.TimeoutError:
                    updated = worker_pool.get_ticket(ticket.ticket_id)
                    if updated:
                        await ws.send_json({
                            "type": "queue_update",
                            "position": updated.position,
                            "estimated_wait_s": updated.estimated_wait_s,
                            "queue_length": worker_pool.queue_length,
                        })
                except asyncio.CancelledError:
                    worker_pool.cancel(ticket.ticket_id)
                    return
        except Exception as e:  # includes WebSocketDisconnect
            logger.info(f"Duplex WS disconnected during queue wait: session={session_id}, cancelling ticket {ticket.ticket_id} ({e})")
            worker_pool.cancel(ticket.ticket_id)
            return
        if worker is None and future.done():
            worker = future.result()

    if worker is None:
        await ws.send_json({"type": "error", "error": "No worker available"})
        await ws.close(code=1013, reason="No worker available")
        return

    # Tell the frontend that queueing is over
    await ws.send_json({"type": "queue_done"})

    logger.info(f"Duplex WS connected: session={session_id} → {worker.worker_id}")
    worker.mark_busy(GatewayWorkerStatus.DUPLEX_ACTIVE, duplex_type, session_id=session_id)
    task_start = datetime.now()
    worker_ws = None

    try:
        import websockets
        ws_url = f"ws://{worker.host}:{worker.port}/ws/duplex?session_id={session_id}"

        # The Worker may still be cleaning up the previous Duplex session
        # (releasing GPU memory etc.); retry briefly until it is ready.
        max_retries = 5
        for attempt in range(max_retries):
            try:
                worker_ws = await websockets.connect(ws_url, open_timeout=5)
                break
            except Exception as conn_err:
                if attempt < max_retries - 1:
                    logger.warning(
                        f"Duplex WS connect to {worker.worker_id} failed (attempt {attempt + 1}): "
                        f"{conn_err}, retrying in 1s..."
                    )
                    await asyncio.sleep(1.0)
                else:
                    raise

        diag_log_path = os.path.join("tmp", f"diag_{session_id}.jsonl")

        async def client_to_worker():
            """Client → Worker"""
            try:
                async for raw in ws.iter_text():
                    # Frames are inspected only to spot control messages. A
                    # frame that is not a JSON object is forwarded untouched
                    # instead of killing the relay with a parse error.
                    try:
                        msg = json.loads(raw)
                    except ValueError:
                        msg = None
                    msg_type = msg.get("type") if isinstance(msg, dict) else None

                    if msg_type == "client_diagnostic":
                        await _write_diagnostic(diag_log_path, msg)
                        continue  # diagnostics stay on the gateway
                    if msg_type == "pause":
                        worker.update_duplex_status(GatewayWorkerStatus.DUPLEX_PAUSED)
                    elif msg_type == "resume":
                        worker.update_duplex_status(GatewayWorkerStatus.DUPLEX_ACTIVE)
                    await worker_ws.send(raw)
            except WebSocketDisconnect:
                pass

        async def worker_to_client():
            """Worker → Client"""
            try:
                async for raw in worker_ws:
                    await ws.send_text(raw)
            except Exception:
                pass  # either side closing ends the relay

        relay_tasks = [
            asyncio.create_task(client_to_worker()),
            asyncio.create_task(worker_to_client()),
        ]
        try:
            await asyncio.wait(relay_tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Stop whichever direction is still running and wait for it, so
            # no relay task outlives this handler (also when the handler
            # itself is cancelled). cancel() is a no-op on finished tasks.
            for task in relay_tasks:
                task.cancel()
            outcomes = await asyncio.gather(*relay_tasks, return_exceptions=True)
        for outcome in outcomes:
            # CancelledError is not an Exception subclass, so cancelled
            # directions are skipped here.
            if isinstance(outcome, Exception):
                logger.warning(f"Duplex WS relay ended with error: session={session_id} ({outcome})")

    except Exception as e:
        logger.error(f"Duplex WS error: {e}", exc_info=True)
    finally:
        if worker_ws:
            try:
                await worker_ws.close()
            except Exception:
                pass  # best effort
        duration = (datetime.now() - task_start).total_seconds()
        worker_pool.release_worker(worker, request_type=duplex_type, duration_s=duration)
        logger.info(f"Duplex WS ended: session={session_id}, type={duplex_type}, Worker released ({duration:.1f}s)")
# ============ Frontend defaults ============
@app.get("/api/frontend_defaults")
async def get_frontend_defaults():
    """Return the default settings the frontend pages need.

    Pages call this on load to obtain configurable defaults such as
    playback_delay_ms instead of hard-coding them. The values come from
    config.json (per the original note).
    """
    from config import get_config

    cfg = get_config()
    return cfg.frontend_defaults()
# ============ System Prompt presets ============
# Preset metadata keyed by mode directory name; None until a preset endpoint
# loads it on first use.
_presets_cache: Optional[Dict[str, List[Dict[str, Any]]]] = None
def _get_audio_meta(rel_path: str, project_root: str) -> Dict[str, Any]:
"""获取音频文件的元数据(不加载 base64),用于预设列表"""
import librosa
if not rel_path:
return {"name": "", "duration": 0}
abs_path = rel_path if os.path.isabs(rel_path) else os.path.join(project_root, rel_path)
name = os.path.basename(abs_path)
if not os.path.exists(abs_path):
return {"name": name, "duration": 0}
try:
audio, sr = librosa.load(abs_path, sr=16000, mono=True)
return {"name": name, "duration": round(len(audio) / sr, 1)}
except Exception:
return {"name": name, "duration": 0}
def _load_audio_base64(rel_path: str, project_root: str) -> Optional[Dict[str, Any]]:
"""加载音频文件为 base64(按需调用)"""
import librosa
if not rel_path:
return None
abs_path = rel_path if os.path.isabs(rel_path) else os.path.join(project_root, rel_path)
if not os.path.exists(abs_path):
return None
try:
audio, sr = librosa.load(abs_path, sr=16000, mono=True)
audio_bytes = audio.astype(np.float32).tobytes()
import base64 as b64mod
return {
"data": b64mod.b64encode(audio_bytes).decode("ascii"),
"name": os.path.basename(abs_path),
"duration": round(len(audio) / sr, 1),
}
except Exception as e:
logger.error(f"Failed to load audio {abs_path}: {e}")
return None
def _load_presets_from_dir(project_root: str) -> Dict[str, List[Dict[str, Any]]]:
"""扫描 assets/presets/<mode>/*.yaml,返回元数据(不含音频 base64)"""
import yaml
presets_root = os.path.join(project_root, "assets", "presets")
result: Dict[str, List[Dict[str, Any]]] = {}
if not os.path.isdir(presets_root):
return result
for mode_dir in sorted(os.listdir(presets_root)):
mode_path = os.path.join(presets_root, mode_dir)
if not os.path.isdir(mode_path):
continue
mode_presets = []
for fname in sorted(os.listdir(mode_path)):
if not fname.endswith((".yaml", ".yml")):
continue
fpath = os.path.join(mode_path, fname)
try:
with open(fpath, "r", encoding="utf-8") as f:
preset = yaml.safe_load(f)
if not preset or not isinstance(preset, dict):
continue
if "system_content" in preset:
resolved = []
for item in preset["system_content"]:
if item.get("type") == "audio" and item.get("path"):
meta = _get_audio_meta(item["path"], project_root)
resolved.append({
"type": "audio",
"data": None,
"path": item["path"],
"name": meta["name"],
"duration": meta["duration"],
})
else:
resolved.append(item)
preset["system_content"] = resolved
if "ref_audio_path" in preset:
meta = _get_audio_meta(preset["ref_audio_path"], project_root)
preset["ref_audio"] = {
"data": None,
"path": preset["ref_audio_path"],
"name": meta["name"],
"duration": meta["duration"],
}
del preset["ref_audio_path"]
mode_presets.append(preset)
except Exception as e:
logger.error(f"Failed to load preset {fpath}: {e}")
if mode_presets:
mode_presets.sort(key=lambda p: p.get("order", 999))
result[mode_dir] = mode_presets
total = sum(len(v) for v in result.values())
logger.info(f"Loaded {total} presets (metadata only) across {len(result)} modes")
return result
@app.get("/api/presets")
async def get_presets():
    """Return preset metadata (no audio base64).

    Audio is loaded on demand through /api/presets/{mode}/{id}/audio. The
    metadata is scanned once on the first request and cached afterwards.
    """
    global _presets_cache
    if _presets_cache is None:
        project_root = os.path.dirname(__file__)
        # The scan reads files and decodes audio for durations; run it in a
        # thread so the event loop keeps serving the WebSocket relays.
        _presets_cache = await asyncio.to_thread(_load_presets_from_dir, project_root)
    return _presets_cache
@app.get("/api/presets/{mode}/{preset_id}/audio")
async def get_preset_audio(mode: str, preset_id: str):
    """Load the audio data of a single preset on demand.

    Args:
        mode: Preset mode (directory name under assets/presets).
        preset_id: Value of the preset's ``id`` field.

    Returns:
        A dict that may contain ``system_content_audio`` (one entry per audio
        item in the preset's ``system_content``) and ``ref_audio``. Entries
        that fail to load carry ``data: None`` and ``duration: 0``.

    Raises:
        HTTPException: 404 when the mode/preset combination is unknown.
    """
    global _presets_cache
    project_root = os.path.dirname(__file__)
    if _presets_cache is None:
        # Blocking scan: keep it off the event loop.
        _presets_cache = await asyncio.to_thread(_load_presets_from_dir, project_root)

    mode_presets = _presets_cache.get(mode, [])
    preset = next((p for p in mode_presets if p.get("id") == preset_id), None)
    if not preset:
        raise HTTPException(status_code=404, detail=f"Preset not found: {mode}/{preset_id}")

    result: Dict[str, Any] = {}

    if "system_content" in preset:
        audio_items = []
        for item in preset["system_content"]:
            if item.get("type") == "audio" and item.get("path"):
                # Decoding audio is blocking work; run it in a thread.
                loaded = await asyncio.to_thread(_load_audio_base64, item["path"], project_root)
                audio_items.append(loaded or {"data": None, "name": item.get("name", ""), "duration": 0})
        result["system_content_audio"] = audio_items

    if preset.get("ref_audio") and preset["ref_audio"].get("path"):
        loaded = await asyncio.to_thread(_load_audio_base64, preset["ref_audio"]["path"], project_root)
        result["ref_audio"] = loaded or {"data": None, "name": preset["ref_audio"].get("name", ""), "duration": 0}

    return result
# Cache: loaded on the first request after startup, returned directly afterwards
_default_ref_audio_cache: Optional[Dict[str, Any]] = None
@app.get("/api/default_ref_audio")
async def get_default_ref_audio():
    """Return the default reference audio (PCM float32, 16 kHz mono, base64).

    The frontend calls this on page load to obtain the default ref audio;
    afterwards every request passes audio data via ref_audio_base64 (per the
    original note). The payload is built once and then served from a cache.

    Returns:
        Dict with ``name``, ``duration`` (seconds), ``sample_rate``,
        ``samples`` and ``base64``.

    Raises:
        HTTPException: 404 when no default ref audio is configured or the
            file is missing; 500 when loading it fails.
    """
    global _default_ref_audio_cache
    if _default_ref_audio_cache is not None:
        return _default_ref_audio_cache

    from config import get_config
    cfg = get_config()

    if not cfg.ref_audio_path:
        raise HTTPException(status_code=404, detail="No default ref audio configured")

    # Resolve the path (relative paths are relative to this file's directory)
    ref_path = cfg.ref_audio_path
    if not os.path.isabs(ref_path):
        ref_path = os.path.join(os.path.dirname(__file__), ref_path)

    if not os.path.exists(ref_path):
        raise HTTPException(status_code=404, detail=f"Default ref audio not found: {cfg.ref_audio_path}")

    def _build_payload() -> Dict[str, Any]:
        """Decode the file and build the response payload (blocking)."""
        import base64
        import librosa
        import numpy as np

        # Load and resample to 16 kHz mono float32 (same format as frontend uploads)
        audio, _sr = librosa.load(ref_path, sr=16000, mono=True)
        # Encode as base64 (PCM float32)
        audio_bytes = audio.astype(np.float32).tobytes()
        return {
            "name": os.path.basename(cfg.ref_audio_path),
            "duration": round(len(audio) / 16000, 1),
            "sample_rate": 16000,
            "samples": len(audio),
            "base64": base64.b64encode(audio_bytes).decode("ascii"),
        }

    try:
        # Decoding is blocking work; run it in a thread so the event loop
        # keeps serving the WebSocket relays.
        payload = await asyncio.to_thread(_build_payload)
    except Exception as e:
        logger.error(f"Failed to load default ref audio: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to load ref audio: {e}")

    _default_ref_audio_cache = payload
    logger.info(
        f"Default ref audio loaded: {payload['name']} "
        f"({payload['duration']:.1f}s, {payload['samples']} samples)"
    )
    return payload
# ============ 素材管理 API ============
@app.get("/api/assets/ref_audio", response_model=RefAudioListResponse)
async def list_ref_audios():
    """List the registered reference audios.

    Raises:
        HTTPException: 503 when the registry has not been created yet.
    """
    registry = ref_audio_registry
    if registry is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    audio_total = registry.count
    entries = registry.list_all()
    return RefAudioListResponse(total=audio_total, ref_audios=entries)
@app.post("/api/assets/ref_audio", response_model=RefAudioResponse)
async def upload_ref_audio(request: UploadRefAudioRequest):
    """Upload a reference audio.

    Args:
        request: Carries the audio ``name`` and its ``audio_base64`` payload.

    Raises:
        HTTPException: 503 when the registry is not ready; 400 when a
            ValueError is raised while handling the upload; 500 on any other
            failure.
    """
    registry = ref_audio_registry
    if registry is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    try:
        info = registry.upload(name=request.name, audio_base64=request.audio_base64)
        summary = f"Uploaded successfully, duration={info.duration_ms}ms"
        return RefAudioResponse(success=True, id=info.id, name=info.name, message=summary)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Upload ref audio failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/api/assets/ref_audio/{ref_id}", response_model=RefAudioResponse)
async def delete_ref_audio(ref_id: str):
    """Delete a reference audio.

    Args:
        ref_id: Id of the reference audio to remove.

    Raises:
        HTTPException: 503 when the registry is not ready; 404 when the id
            is unknown.
    """
    registry = ref_audio_registry
    if registry is None:
        raise HTTPException(status_code=503, detail="Service not ready")
    if not registry.exists(ref_id):
        raise HTTPException(status_code=404, detail=f"Ref audio not found: {ref_id}")

    deleted = registry.delete(ref_id)
    if deleted:
        outcome = "Deleted"
    else:
        outcome = "Failed to delete"
    return RefAudioResponse(success=deleted, id=ref_id, message=outcome)
# ============ 队列状态 API ============
@app.get("/api/queue", response_model=QueueStatus)
async def get_queue():
    """Return the current queue status as reported by the worker pool.

    Raises:
        HTTPException: 503 when the worker pool has not been created yet.
    """
    pool = worker_pool
    if pool is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    queue_snapshot = pool.get_queue_status()
    return queue_snapshot
@app.get("/api/queue/{ticket_id}")
async def get_queue_ticket(ticket_id: str):
"""获取指定排队项的状态(前端轮询用)"""
if worker_pool is None:
raise HTTPException(status_code=503, detail="Service not ready")
ticket = worker_pool.get_ticket(ticket_id)