Skip to content

Commit 617e976

Browse files
richard-toclaude
andauthored
Fix on_load event to support async functions (#1300) (#1323)
Co-authored-by: Claude <[email protected]>
1 parent 7fd0c83 commit 617e976

File tree

7 files changed

+147
-31
lines changed

7 files changed

+147
-31
lines changed

mesop/examples/testing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from mesop.examples.testing import (
2+
async_onload as async_onload,
3+
)
14
from mesop.examples.testing import (
25
click_is_target as click_is_target,
36
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
Test script to verify async on_load support.
3+
This reproduces the issue from GitHub issue #1300.
4+
"""
5+
6+
import asyncio
7+
8+
import mesop as me
9+
10+
11+
@me.stateclass
12+
class State:
13+
message: str = ""
14+
15+
16+
async def async_on_load(e: me.LoadEvent):
17+
"""Async generator on_load handler."""
18+
state = me.state(State)
19+
state.message = "Loading..."
20+
yield
21+
22+
# Simulate async operation
23+
await asyncio.sleep(0.1)
24+
25+
state.message = "Loaded from async generator!"
26+
yield
27+
28+
29+
async def async_coroutine_on_load(e: me.LoadEvent):
30+
"""Async coroutine on_load handler (no yield)."""
31+
state = me.state(State)
32+
await asyncio.sleep(0.1)
33+
state.message = "Loaded from coroutine!"
34+
35+
36+
@me.page(path="/async_gen", on_load=async_on_load)
37+
def page_async_gen():
38+
state = me.state(State)
39+
me.text(f"Async Generator: {state.message}")
40+
41+
42+
@me.page(path="/async_coro", on_load=async_coroutine_on_load)
43+
def page_async_coro():
44+
state = me.state(State)
45+
me.text(f"Async Coroutine: {state.message}")

mesop/runtime/context.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import copy
32
import threading
43
import types
@@ -18,6 +17,7 @@
1817
MesopException,
1918
)
2019
from mesop.server.state_session import state_session
20+
from mesop.utils.async_utils import run_async_generator, run_coroutine
2121

2222
T = TypeVar("T")
2323

@@ -363,9 +363,9 @@ def run_event_handler(
363363
result = handler(payload)
364364
if result is not None:
365365
if isinstance(result, types.AsyncGeneratorType):
366-
yield from _run_async_generator(result)
366+
yield from run_async_generator(result)
367367
elif isinstance(result, types.CoroutineType):
368-
yield _run_coroutine(result)
368+
yield run_coroutine(result)
369369
else:
370370
yield from result
371371
else:
@@ -374,26 +374,3 @@ def run_event_handler(
374374
raise MesopException(
375375
f"Unknown handler id: {event.handler_id} from event {event}"
376376
)
377-
378-
379-
def _run_async_generator(agen: types.AsyncGeneratorType[None, None]):
380-
loop = _get_or_create_event_loop()
381-
try:
382-
while True:
383-
yield loop.run_until_complete(agen.__anext__())
384-
except StopAsyncIteration:
385-
pass
386-
387-
388-
def _run_coroutine(coroutine: types.CoroutineType):
389-
loop = _get_or_create_event_loop()
390-
return loop.run_until_complete(coroutine)
391-
392-
393-
def _get_or_create_event_loop():
394-
try:
395-
return asyncio.get_running_loop()
396-
except RuntimeError:
397-
loop = asyncio.new_event_loop()
398-
asyncio.set_event_loop(loop)
399-
return loop

mesop/runtime/runtime.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import AsyncGenerator, Coroutine
12
from copy import deepcopy
23
from dataclasses import dataclass
34
from typing import Any, Callable, Generator, Type, TypeVar, cast
@@ -23,7 +24,13 @@ class EmptyState:
2324
pass
2425

2526

26-
OnLoadHandler = Callable[[LoadEvent], None | Generator[None, None, None]]
27+
OnLoadHandler = Callable[
28+
[LoadEvent],
29+
None
30+
| Generator[None, None, None]
31+
| AsyncGenerator[None, None]
32+
| Coroutine[Any, Any, None],
33+
]
2734

2835

2936
@dataclass(kw_only=True)

mesop/server/server.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import secrets
44
import threading
5+
import types
56
from typing import Generator, Sequence
67

78
from flask import (
@@ -36,6 +37,7 @@
3637
prefix_base_url,
3738
serialize,
3839
)
40+
from mesop.utils.async_utils import run_async_generator, run_coroutine
3941
from mesop.utils.url_utils import remove_url_query_param
4042
from mesop.warn import warn
4143

@@ -44,6 +46,18 @@
4446
logger = logging.getLogger(__name__)
4547

4648

49+
def _process_on_load_result(result) -> Generator[None, None, None]:
50+
"""Process on_load result, handling sync generators, async generators, and coroutines."""
51+
if result is not None:
52+
if isinstance(result, types.AsyncGeneratorType):
53+
yield from run_async_generator(result)
54+
elif isinstance(result, types.CoroutineType):
55+
yield run_coroutine(result)
56+
else:
57+
# Regular generator
58+
yield from result
59+
60+
4761
def configure_flask_app(
4862
*, prod_mode: bool = True, exceptions_to_propagate: Sequence[type] = ()
4963
) -> Flask:
@@ -172,9 +186,9 @@ def generate_data(ui_request: pb.UiRequest) -> Generator[str, None, None]:
172186
)
173187
)
174188
# on_load is a generator function then we need to iterate through
175-
# the generator object.
189+
# the generator object. This also handles async generators and coroutines.
176190
if result:
177-
for _ in result:
191+
for _ in _process_on_load_result(result):
178192
yield from render_loop(path=ui_request.path, init_request=True)
179193
runtime().context().set_previous_node_from_current_node()
180194
runtime().context().reset_current_node()
@@ -277,9 +291,9 @@ def run_page_load(*, path: str):
277291
assert page_config and page_config.on_load
278292
result = page_config.on_load(LoadEvent(path=path))
279293
# on_load is a generator function then we need to iterate through
280-
# the generator object.
294+
# the generator object. This also handles async generators and coroutines.
281295
if result:
282-
for _ in result:
296+
for _ in _process_on_load_result(result):
283297
yield from render_loop(path=path, init_request=True)
284298
runtime().context().set_previous_node_from_current_node()
285299
runtime().context().reset_current_node()
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import {test, expect} from '@playwright/test';
2+
3+
test.describe('async onload events', () => {
4+
test('async generator', async ({page}) => {
5+
await page.goto('/async_gen');
6+
7+
await expect(
8+
page.getByText('Async Generator: Loaded from async generator!'),
9+
).toBeVisible();
10+
});
11+
12+
test('async coroutine', async ({page}) => {
13+
await page.goto('/async_coro');
14+
15+
await expect(
16+
page.getByText('Async Coroutine: Loaded from coroutine!'),
17+
).toBeVisible();
18+
});
19+
});

mesop/utils/async_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Utilities for handling async generators and coroutines in Mesop event handlers."""
2+
3+
import asyncio
4+
import types
5+
from typing import Generator
6+
7+
8+
def run_async_generator(
9+
agen: types.AsyncGeneratorType[None, None],
10+
) -> Generator[None, None, None]:
11+
"""Run an async generator by iterating through it using an event loop.
12+
13+
Args:
14+
agen: The async generator to run
15+
16+
Yields:
17+
None for each iteration of the async generator
18+
"""
19+
loop = get_or_create_event_loop()
20+
try:
21+
while True:
22+
yield loop.run_until_complete(agen.__anext__())
23+
except StopAsyncIteration:
24+
pass
25+
26+
27+
def run_coroutine(coroutine: types.CoroutineType):
28+
"""Run a coroutine using an event loop and return its result.
29+
30+
Args:
31+
coroutine: The coroutine to run
32+
33+
Returns:
34+
The result of the coroutine
35+
"""
36+
loop = get_or_create_event_loop()
37+
return loop.run_until_complete(coroutine)
38+
39+
40+
def get_or_create_event_loop():
41+
"""Get the current event loop or create a new one if none exists.
42+
43+
Returns:
44+
The asyncio event loop
45+
"""
46+
try:
47+
return asyncio.get_running_loop()
48+
except RuntimeError:
49+
loop = asyncio.new_event_loop()
50+
asyncio.set_event_loop(loop)
51+
return loop

0 commit comments

Comments
 (0)