Skip to content

Commit 77e7f98

Browse files
authored
define state as a generic named tuple (#827)
* define state as a generic named tuple * changelog entry
1 parent 8ff09d2 commit 77e7f98

File tree

5 files changed

+100
-59
lines changed

5 files changed

+100
-59
lines changed

Diff for: docs/source/about/changelog.rst

+22
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,28 @@ Unreleased
4141
``idom-router``, IDOM's server routes will always take priority.
4242
- :pull:`824` - Backend implementations now strip any URL prefix in the pathname for
4343
``use_location``.
44+
- :pull:`827` - ``use_state`` now returns a named tuple with ``value`` and ``set_value``
45+
fields. This is convenient for adding type annotations if the initial state value is
46+
not the same as the values you might pass to the state setter. Where previously you
47+
might have to do something like:
48+
49+
.. code-block::
50+
51+
value: int | None = None
52+
value, set_value = use_state(value)
53+
54+
Now you can annotate your state using the ``State`` class:
55+
56+
.. code-block::
57+
58+
state: State[int | None] = use_state(None)
59+
60+
# access value and setter
61+
state.value
62+
state.set_value
63+
64+
# can still destructure if you need to
65+
value, set_value = state
4466
4567
**Added**
4668

Diff for: src/idom/core/hooks.py

+42-59
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from idom.utils import Ref
2727

2828
from ._thread_local import ThreadLocal
29-
from .types import ComponentType, Key, VdomDict
29+
from .types import ComponentType, Key, State, VdomDict
3030
from .vdom import vdom
3131

3232

@@ -46,35 +46,20 @@
4646

4747
logger = getLogger(__name__)
4848

49-
_StateType = TypeVar("_StateType")
49+
_Type = TypeVar("_Type")
5050

5151

5252
@overload
53-
def use_state(
54-
initial_value: Callable[[], _StateType],
55-
) -> Tuple[
56-
_StateType,
57-
Callable[[_StateType | Callable[[_StateType], _StateType]], None],
58-
]:
53+
def use_state(initial_value: Callable[[], _Type]) -> State[_Type]:
5954
...
6055

6156

6257
@overload
63-
def use_state(
64-
initial_value: _StateType,
65-
) -> Tuple[
66-
_StateType,
67-
Callable[[_StateType | Callable[[_StateType], _StateType]], None],
68-
]:
58+
def use_state(initial_value: _Type) -> State[_Type]:
6959
...
7060

7161

72-
def use_state(
73-
initial_value: _StateType | Callable[[], _StateType],
74-
) -> Tuple[
75-
_StateType,
76-
Callable[[_StateType | Callable[[_StateType], _StateType]], None],
77-
]:
62+
def use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]:
7863
"""See the full :ref:`Use State` docs for details
7964
8065
Parameters:
@@ -87,16 +72,16 @@ def use_state(
8772
A tuple containing the current state and a function to update it.
8873
"""
8974
current_state = _use_const(lambda: _CurrentState(initial_value))
90-
return current_state.value, current_state.dispatch
75+
return State(current_state.value, current_state.dispatch)
9176

9277

93-
class _CurrentState(Generic[_StateType]):
78+
class _CurrentState(Generic[_Type]):
9479

9580
__slots__ = "value", "dispatch"
9681

9782
def __init__(
9883
self,
99-
initial_value: Union[_StateType, Callable[[], _StateType]],
84+
initial_value: Union[_Type, Callable[[], _Type]],
10085
) -> None:
10186
if callable(initial_value):
10287
self.value = initial_value()
@@ -105,9 +90,7 @@ def __init__(
10590

10691
hook = current_hook()
10792

108-
def dispatch(
109-
new: Union[_StateType, Callable[[_StateType], _StateType]]
110-
) -> None:
93+
def dispatch(new: Union[_Type, Callable[[_Type], _Type]]) -> None:
11194
if callable(new):
11295
next_value = new(self.value)
11396
else:
@@ -234,14 +217,14 @@ def use_debug_value(
234217
logger.debug(f"{current_hook().component} {new}")
235218

236219

237-
def create_context(default_value: _StateType) -> Context[_StateType]:
220+
def create_context(default_value: _Type) -> Context[_Type]:
238221
"""Return a new context type for use in :func:`use_context`"""
239222

240223
def context(
241224
*children: Any,
242-
value: _StateType = default_value,
225+
value: _Type = default_value,
243226
key: Key | None = None,
244-
) -> ContextProvider[_StateType]:
227+
) -> ContextProvider[_Type]:
245228
return ContextProvider(
246229
*children,
247230
value=value,
@@ -254,19 +237,19 @@ def context(
254237
return context
255238

256239

257-
class Context(Protocol[_StateType]):
240+
class Context(Protocol[_Type]):
258241
"""Returns a :class:`ContextProvider` component"""
259242

260243
def __call__(
261244
self,
262245
*children: Any,
263-
value: _StateType = ...,
246+
value: _Type = ...,
264247
key: Key | None = ...,
265-
) -> ContextProvider[_StateType]:
248+
) -> ContextProvider[_Type]:
266249
...
267250

268251

269-
def use_context(context: Context[_StateType]) -> _StateType:
252+
def use_context(context: Context[_Type]) -> _Type:
270253
"""Get the current value for the given context type.
271254
272255
See the full :ref:`Use Context` docs for more information.
@@ -282,7 +265,7 @@ def use_context(context: Context[_StateType]) -> _StateType:
282265
# lastly check that 'value' kwarg exists
283266
assert "value" in context.__kwdefaults__, f"{context} has no 'value' kwarg"
284267
# then we can safely access the context's default value
285-
return cast(_StateType, context.__kwdefaults__["value"])
268+
return cast(_Type, context.__kwdefaults__["value"])
286269

287270
subscribers = provider._subscribers
288271

@@ -294,13 +277,13 @@ def subscribe_to_context_change() -> Callable[[], None]:
294277
return provider._value
295278

296279

297-
class ContextProvider(Generic[_StateType]):
280+
class ContextProvider(Generic[_Type]):
298281
def __init__(
299282
self,
300283
*children: Any,
301-
value: _StateType,
284+
value: _Type,
302285
key: Key | None,
303-
type: Context[_StateType],
286+
type: Context[_Type],
304287
) -> None:
305288
self.children = children
306289
self.key = key
@@ -312,7 +295,7 @@ def render(self) -> VdomDict:
312295
current_hook().set_context_provider(self)
313296
return vdom("", *self.children)
314297

315-
def should_render(self, new: ContextProvider[_StateType]) -> bool:
298+
def should_render(self, new: ContextProvider[_Type]) -> bool:
316299
if not strictly_equal(self._value, new._value):
317300
for hook in self._subscribers:
318301
hook.set_context_provider(new)
@@ -328,9 +311,9 @@ def __repr__(self) -> str:
328311

329312

330313
def use_reducer(
331-
reducer: Callable[[_StateType, _ActionType], _StateType],
332-
initial_value: _StateType,
333-
) -> Tuple[_StateType, Callable[[_ActionType], None]]:
314+
reducer: Callable[[_Type, _ActionType], _Type],
315+
initial_value: _Type,
316+
) -> Tuple[_Type, Callable[[_ActionType], None]]:
334317
"""See the full :ref:`Use Reducer` docs for details
335318
336319
Parameters:
@@ -348,8 +331,8 @@ def use_reducer(
348331

349332

350333
def _create_dispatcher(
351-
reducer: Callable[[_StateType, _ActionType], _StateType],
352-
set_state: Callable[[Callable[[_StateType], _StateType]], None],
334+
reducer: Callable[[_Type, _ActionType], _Type],
335+
set_state: Callable[[Callable[[_Type], _Type]], None],
353336
) -> Callable[[_ActionType], None]:
354337
def dispatch(action: _ActionType) -> None:
355338
set_state(lambda last_state: reducer(last_state, action))
@@ -409,7 +392,7 @@ def setup(function: _CallbackFunc) -> _CallbackFunc:
409392
class _LambdaCaller(Protocol):
410393
"""MyPy doesn't know how to deal with TypeVars only used in function return"""
411394

412-
def __call__(self, func: Callable[[], _StateType]) -> _StateType:
395+
def __call__(self, func: Callable[[], _Type]) -> _Type:
413396
...
414397

415398

@@ -423,16 +406,16 @@ def use_memo(
423406

424407
@overload
425408
def use_memo(
426-
function: Callable[[], _StateType],
409+
function: Callable[[], _Type],
427410
dependencies: Sequence[Any] | ellipsis | None = ...,
428-
) -> _StateType:
411+
) -> _Type:
429412
...
430413

431414

432415
def use_memo(
433-
function: Optional[Callable[[], _StateType]] = None,
416+
function: Optional[Callable[[], _Type]] = None,
434417
dependencies: Sequence[Any] | ellipsis | None = ...,
435-
) -> Union[_StateType, Callable[[Callable[[], _StateType]], _StateType]]:
418+
) -> Union[_Type, Callable[[Callable[[], _Type]], _Type]]:
436419
"""See the full :ref:`Use Memo` docs for details
437420
438421
Parameters:
@@ -449,7 +432,7 @@ def use_memo(
449432
"""
450433
dependencies = _try_to_infer_closure_values(function, dependencies)
451434

452-
memo: _Memo[_StateType] = _use_const(_Memo)
435+
memo: _Memo[_Type] = _use_const(_Memo)
453436

454437
if memo.empty():
455438
# we need to initialize on the first run
@@ -471,17 +454,17 @@ def use_memo(
471454
else:
472455
changed = False
473456

474-
setup: Callable[[Callable[[], _StateType]], _StateType]
457+
setup: Callable[[Callable[[], _Type]], _Type]
475458

476459
if changed:
477460

478-
def setup(function: Callable[[], _StateType]) -> _StateType:
461+
def setup(function: Callable[[], _Type]) -> _Type:
479462
current_value = memo.value = function()
480463
return current_value
481464

482465
else:
483466

484-
def setup(function: Callable[[], _StateType]) -> _StateType:
467+
def setup(function: Callable[[], _Type]) -> _Type:
485468
return memo.value
486469

487470
if function is not None:
@@ -490,12 +473,12 @@ def setup(function: Callable[[], _StateType]) -> _StateType:
490473
return setup
491474

492475

493-
class _Memo(Generic[_StateType]):
476+
class _Memo(Generic[_Type]):
494477
"""Simple object for storing memoization data"""
495478

496479
__slots__ = "value", "deps"
497480

498-
value: _StateType
481+
value: _Type
499482
deps: Sequence[Any]
500483

501484
def empty(self) -> bool:
@@ -507,7 +490,7 @@ def empty(self) -> bool:
507490
return False
508491

509492

510-
def use_ref(initial_value: _StateType) -> Ref[_StateType]:
493+
def use_ref(initial_value: _Type) -> Ref[_Type]:
511494
"""See the full :ref:`Use State` docs for details
512495
513496
Parameters:
@@ -519,7 +502,7 @@ def use_ref(initial_value: _StateType) -> Ref[_StateType]:
519502
return _use_const(lambda: Ref(initial_value))
520503

521504

522-
def _use_const(function: Callable[[], _StateType]) -> _StateType:
505+
def _use_const(function: Callable[[], _Type]) -> _Type:
523506
return current_hook().use_state(function)
524507

525508

@@ -670,7 +653,7 @@ def schedule_render(self) -> None:
670653
self._schedule_render()
671654
return None
672655

673-
def use_state(self, function: Callable[[], _StateType]) -> _StateType:
656+
def use_state(self, function: Callable[[], _Type]) -> _Type:
674657
if not self._rendered_atleast_once:
675658
# since we're not intialized yet we're just appending state
676659
result = function()
@@ -689,8 +672,8 @@ def set_context_provider(self, provider: ContextProvider[Any]) -> None:
689672
self._context_providers[provider.type] = provider
690673

691674
def get_context_provider(
692-
self, context: Context[_StateType]
693-
) -> ContextProvider[_StateType] | None:
675+
self, context: Context[_Type]
676+
) -> ContextProvider[_Type] | None:
694677
return self._context_providers.get(context)
695678

696679
def affect_component_will_render(self, component: ComponentType) -> None:

Diff for: src/idom/core/types.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from __future__ import annotations
22

3+
import sys
4+
from collections import namedtuple
35
from types import TracebackType
46
from typing import (
7+
TYPE_CHECKING,
58
Any,
69
Callable,
710
Dict,
11+
Generic,
812
Iterable,
913
List,
1014
Mapping,
15+
NamedTuple,
1116
Optional,
1217
Sequence,
1318
Type,
@@ -18,6 +23,19 @@
1823
from typing_extensions import Protocol, TypedDict, runtime_checkable
1924

2025

26+
_Type = TypeVar("_Type")
27+
28+
29+
if TYPE_CHECKING or sys.version_info < (3, 9) or sys.version_info >= (3, 11):
30+
31+
class State(NamedTuple, Generic[_Type]): # pragma: no cover
32+
value: _Type
33+
set_value: Callable[[_Type | Callable[[_Type], _Type]], None]
34+
35+
else:
36+
State = namedtuple("State", ("value", "set_value"))
37+
38+
2139
ComponentConstructor = Callable[..., "ComponentType"]
2240
"""Simple function returning a new component"""
2341

Diff for: src/idom/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Key,
1919
LayoutType,
2020
RootComponentConstructor,
21+
State,
2122
VdomAttributes,
2223
VdomAttributesAndChildren,
2324
VdomChild,
@@ -43,6 +44,7 @@
4344
"LayoutType",
4445
"Location",
4546
"RootComponentConstructor",
47+
"State",
4648
"VdomAttributes",
4749
"VdomAttributesAndChildren",
4850
"VdomChild",

Diff for: tests/test_core/test_hooks.py

+16
Original file line numberDiff line numberDiff line change
@@ -1379,3 +1379,19 @@ def InnerComponent():
13791379
hook.latest.schedule_render()
13801380
await layout.render()
13811381
assert inner_render_count.current == 1
1382+
1383+
1384+
async def test_use_state_named_tuple():
1385+
state = idom.Ref()
1386+
1387+
@idom.component
1388+
def some_component():
1389+
state.current = idom.use_state(1)
1390+
return None
1391+
1392+
async with idom.Layout(some_component()) as layout:
1393+
await layout.render()
1394+
assert state.current.value == 1
1395+
state.current.set_value(2)
1396+
await layout.render()
1397+
assert state.current.value == 2

0 commit comments

Comments
 (0)