|
2 | 2 | import logging |
3 | 3 | import secrets |
4 | 4 | import threading |
| 5 | +import types |
5 | 6 | from typing import Generator, Sequence |
6 | 7 |
|
7 | 8 | from flask import ( |
|
36 | 37 | prefix_base_url, |
37 | 38 | serialize, |
38 | 39 | ) |
| 40 | +from mesop.utils.async_utils import run_async_generator, run_coroutine |
39 | 41 | from mesop.utils.url_utils import remove_url_query_param |
40 | 42 | from mesop.warn import warn |
41 | 43 |
|
|
44 | 46 | logger = logging.getLogger(__name__) |
45 | 47 |
|
46 | 48 |
|
| 49 | +def _process_on_load_result(result) -> Generator[None, None, None]: |
| 50 | + """Process on_load result, handling sync generators, async generators, and coroutines.""" |
| 51 | + if result is not None: |
| 52 | + if isinstance(result, types.AsyncGeneratorType): |
| 53 | + yield from run_async_generator(result) |
| 54 | + elif isinstance(result, types.CoroutineType): |
| 55 | + yield run_coroutine(result) |
| 56 | + else: |
| 57 | + # Regular generator |
| 58 | + yield from result |
| 59 | + |
| 60 | + |
47 | 61 | def configure_flask_app( |
48 | 62 | *, prod_mode: bool = True, exceptions_to_propagate: Sequence[type] = () |
49 | 63 | ) -> Flask: |
@@ -172,9 +186,9 @@ def generate_data(ui_request: pb.UiRequest) -> Generator[str, None, None]: |
172 | 186 | ) |
173 | 187 | ) |
174 | 188 | # on_load is a generator function then we need to iterate through |
175 | | - # the generator object. |
| 189 | + # the generator object. This also handles async generators and coroutines. |
176 | 190 | if result: |
177 | | - for _ in result: |
| 191 | + for _ in _process_on_load_result(result): |
178 | 192 | yield from render_loop(path=ui_request.path, init_request=True) |
179 | 193 | runtime().context().set_previous_node_from_current_node() |
180 | 194 | runtime().context().reset_current_node() |
@@ -277,9 +291,9 @@ def run_page_load(*, path: str): |
277 | 291 | assert page_config and page_config.on_load |
278 | 292 | result = page_config.on_load(LoadEvent(path=path)) |
279 | 293 | # on_load is a generator function then we need to iterate through |
280 | | - # the generator object. |
| 294 | + # the generator object. This also handles async generators and coroutines. |
281 | 295 | if result: |
282 | | - for _ in result: |
| 296 | + for _ in _process_on_load_result(result): |
283 | 297 | yield from render_loop(path=path, init_request=True) |
284 | 298 | runtime().context().set_previous_node_from_current_node() |
285 | 299 | runtime().context().reset_current_node() |
|
0 commit comments