From c39fb9e5ee983599e3fe2f72e5369da21bf0b8a4 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Thu, 27 Mar 2025 10:22:42 +0000 Subject: [PATCH] Revert "feat: Add UnionType support to query method" --- src/textual/dom.py | 47 +++------------------------- tests/test_dom.py | 78 ---------------------------------------------- 2 files changed, 5 insertions(+), 120 deletions(-) diff --git a/src/textual/dom.py b/src/textual/dom.py index 893b8fa98a..9587bca586 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -24,13 +24,6 @@ overload, ) -try: - from types import UnionType - from typing import get_args -except ImportError: - UnionType = None # Type will not exist in earlier versions - get_args = None # Not needed for earlier versions - import rich.repr from rich.highlighter import ReprHighlighter from rich.style import Style @@ -102,9 +95,9 @@ def check_identifiers(description: str, *names: str) -> None: description: Description of where identifier is used for error message. *names: Identifiers to check. """ - fullmatch = _re_identifier.fullmatch + match = _re_identifier.fullmatch for name in names: - if fullmatch(name) is None: + if match(name) is None: raise BadIdentifier( f"{name!r} is an invalid {description}; " "identifiers must contain only letters, numbers, underscores, or hyphens, and must not begin with a number." @@ -1373,52 +1366,22 @@ def query(self, selector: str | None = None) -> DOMQuery[Widget]: ... @overload def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ... - if UnionType is not None: - - @overload - def query(self, selector: UnionType) -> DOMQuery[Widget]: ... - def query( - self, selector: str | type[QueryType] | UnionType | None = None + self, selector: str | type[QueryType] | None = None ) -> DOMQuery[Widget] | DOMQuery[QueryType]: - """Query the DOM for children that match a selector or widget type, or a union of widget types. + """Query the DOM for children that match a selector or widget type. Args: - selector: A CSS selector, widget type, a union of widget types, or `None` for all nodes. + selector: A CSS selector, widget type, or `None` for all nodes. Returns: A query object. - - Raises: - TypeError: If any type in a Union is not a Widget subclass. """ from textual.css.query import DOMQuery, QueryType from textual.widget import Widget if isinstance(selector, str) or selector is None: return DOMQuery[Widget](self, filter=selector) - elif UnionType is not None and isinstance(selector, UnionType): - # Get all types from the union, including nested unions - def get_all_types(union_type): - types = set() - for t in get_args(union_type): - if isinstance(t, UnionType): - types.update(get_all_types(t)) - else: - types.add(t) - return types - - # Validate all types in the union are Widget subclasses - types_in_union = get_args(selector) - if not all( - isinstance(t, type) and issubclass(t, Widget) for t in types_in_union - ): - raise TypeError("All types in Union must be Widget subclasses") - - # Convert Union type to comma-separated string of class names - type_names = [t.__name__ for t in types_in_union] - selector_str = ", ".join(type_names) - return DOMQuery[Widget](self, filter=selector_str) else: return DOMQuery[QueryType](self, filter=selector.__name__) diff --git a/tests/test_dom.py b/tests/test_dom.py index 6e93d8bedf..89bbe53c43 100644 --- a/tests/test_dom.py +++ b/tests/test_dom.py @@ -1,10 +1,7 @@ import pytest -from textual.app import App from textual.css.errors import StyleValueError from textual.dom import BadIdentifier, DOMNode -from textual.widget import Widget -from textual.widgets import Input, Select, Static def test_display_default(): @@ -283,78 +280,3 @@ def test_id_validation(identifier: str): """Regression tests for https://github.com/Textualize/textual/issues/3954.""" with pytest.raises(BadIdentifier): DOMNode(id=identifier) - - -class SimpleApp(App): - def compose(self): - yield Input(id="input1") - yield Select([], id="select1") - yield Static("Hello", id="static1") - yield Input(id="input2") - - -async def test_query_union_type(): - # Test with a UnionType - simple_app = SimpleApp() - async with simple_app.run_test(): - results = simple_app.query(Input | Select) - assert len(results) == 3 - assert {w.id for w in results} == {"input1", "select1", "input2"} - - # Test with a single type - results2 = simple_app.query(Input) - assert len(results2) == 2 - assert {w.id for w in results2} == {"input1", "input2"} - - # Test with string selector - results3 = simple_app.query("#input1") - assert len(results3) == 1 - assert results3[0].id == "input1" - - -async def test_query_nested_unions(): - """Test handling of nested unions.""" - - simple_app = SimpleApp() - async with simple_app.run_test(): - # Create nested union types - InputOrSelect = Input | Select - InputSelectOrStatic = InputOrSelect | Static - - # Test nested union query - results = simple_app.query(InputSelectOrStatic) - - # Verify that we find all our explicitly defined widgets - widget_ids = {w.id for w in results if w.id is not None} - expected_ids = {"input1", "select1", "static1", "input2"} - assert expected_ids.issubset(widget_ids), "Not all expected widgets were found" - - # Verify we get the right types of widgets - assert all( - isinstance(w, (Input, Select, Static)) for w in results - ), "Found unexpected widget types" - - # Verify each expected widget appears exactly once - for expected_id in expected_ids: - matching_widgets = [w for w in results if w.id == expected_id] - assert ( - len(matching_widgets) == 1 - ), f"Widget with id {expected_id} should appear exactly once" - - -async def test_query_empty_union(): - """Test querying with empty or invalid unions.""" - - class AnotherWidget(Widget): - pass - - simple_app = SimpleApp() - async with simple_app.run_test(): - - # Test with a type that exists but has no matches - results = simple_app.query(AnotherWidget) - assert len(results) == 0 - - # Test with widget union that has no matches - results = simple_app.query(AnotherWidget | AnotherWidget) - assert len(results) == 0