|
16 | 16 | import time
|
17 | 17 | from abc import ABC, abstractmethod
|
18 | 18 | from concurrent.futures import ThreadPoolExecutor, as_completed
|
19 |
| -from typing import Dict, List, Optional, Tuple |
| 19 | +from typing import Dict, List, Optional, Tuple, Union |
20 | 20 |
|
21 | 21 | from langchain_core.language_models import BaseChatModel
|
22 | 22 | from langchain_core.messages import BaseMessage, HumanMessage
|
|
33 | 33 |
|
34 | 34 |
|
35 | 35 | class StateBasedConfig(BaseModel):
|
36 |
| - aggregators: Dict[str, List[BaseAggregator]] |
| 36 | + aggregators: Dict[Union[str, Tuple[str, str]], List[BaseAggregator]] = Field( |
| 37 | + description="Dict of topic : aggregator or (topic, msg_type) : aggragator" |
| 38 | + ) |
37 | 39 | time_interval: float = Field(default=5.0)
|
38 | 40 | max_workers: int = 8
|
39 | 41 |
|
@@ -96,9 +98,13 @@ def setup_connector(self) -> BaseConnector:
|
96 | 98 |
|
97 | 99 | def _configure_state_sources(self):
|
98 | 100 | for source, aggregators in self.config.aggregators.items():
|
| 101 | + if isinstance(source, tuple): |
| 102 | + source, msg_type = source |
| 103 | + else: |
| 104 | + msg_type = None |
99 | 105 | for aggregator in aggregators:
|
100 | 106 | callback_id = self._connector.register_callback(
|
101 |
| - source, aggregator, raw=True |
| 107 | + source, aggregator, raw=True, msg_type=msg_type |
102 | 108 | )
|
103 | 109 | self._registered_callbacks.add(callback_id)
|
104 | 110 |
|
|
0 commit comments