Skip to content

Revert "feat: Add UnionType support to query method" #5679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 5 additions & 42 deletions src/textual/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@
overload,
)

try:
from types import UnionType
from typing import get_args
except ImportError:
UnionType = None # Type will not exist in earlier versions
get_args = None # Not needed for earlier versions

import rich.repr
from rich.highlighter import ReprHighlighter
from rich.style import Style
Expand Down Expand Up @@ -102,9 +95,9 @@ def check_identifiers(description: str, *names: str) -> None:
description: Description of where identifier is used for error message.
*names: Identifiers to check.
"""
fullmatch = _re_identifier.fullmatch
match = _re_identifier.fullmatch
for name in names:
if fullmatch(name) is None:
if match(name) is None:
raise BadIdentifier(
f"{name!r} is an invalid {description}; "
"identifiers must contain only letters, numbers, underscores, or hyphens, and must not begin with a number."
Expand Down Expand Up @@ -1373,52 +1366,22 @@ def query(self, selector: str | None = None) -> DOMQuery[Widget]: ...
@overload
def query(self, selector: type[QueryType]) -> DOMQuery[QueryType]: ...

if UnionType is not None:

@overload
def query(self, selector: UnionType) -> DOMQuery[Widget]: ...

def query(
self, selector: str | type[QueryType] | UnionType | None = None
self, selector: str | type[QueryType] | None = None
) -> DOMQuery[Widget] | DOMQuery[QueryType]:
"""Query the DOM for children that match a selector or widget type, or a union of widget types.
"""Query the DOM for children that match a selector or widget type.

Args:
selector: A CSS selector, widget type, a union of widget types, or `None` for all nodes.
selector: A CSS selector, widget type, or `None` for all nodes.

Returns:
A query object.

Raises:
TypeError: If any type in a Union is not a Widget subclass.
"""
from textual.css.query import DOMQuery, QueryType
from textual.widget import Widget

if isinstance(selector, str) or selector is None:
return DOMQuery[Widget](self, filter=selector)
elif UnionType is not None and isinstance(selector, UnionType):
# Get all types from the union, including nested unions
def get_all_types(union_type):
types = set()
for t in get_args(union_type):
if isinstance(t, UnionType):
types.update(get_all_types(t))
else:
types.add(t)
return types

# Validate all types in the union are Widget subclasses
types_in_union = get_args(selector)
if not all(
isinstance(t, type) and issubclass(t, Widget) for t in types_in_union
):
raise TypeError("All types in Union must be Widget subclasses")

# Convert Union type to comma-separated string of class names
type_names = [t.__name__ for t in types_in_union]
selector_str = ", ".join(type_names)
return DOMQuery[Widget](self, filter=selector_str)
else:
return DOMQuery[QueryType](self, filter=selector.__name__)

Expand Down
78 changes: 0 additions & 78 deletions tests/test_dom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest

from textual.app import App
from textual.css.errors import StyleValueError
from textual.dom import BadIdentifier, DOMNode
from textual.widget import Widget
from textual.widgets import Input, Select, Static


def test_display_default():
Expand Down Expand Up @@ -283,78 +280,3 @@ def test_id_validation(identifier: str):
"""Regression tests for https://github.com/Textualize/textual/issues/3954."""
with pytest.raises(BadIdentifier):
DOMNode(id=identifier)


class SimpleApp(App):
def compose(self):
yield Input(id="input1")
yield Select([], id="select1")
yield Static("Hello", id="static1")
yield Input(id="input2")


async def test_query_union_type():
# Test with a UnionType
simple_app = SimpleApp()
async with simple_app.run_test():
results = simple_app.query(Input | Select)
assert len(results) == 3
assert {w.id for w in results} == {"input1", "select1", "input2"}

# Test with a single type
results2 = simple_app.query(Input)
assert len(results2) == 2
assert {w.id for w in results2} == {"input1", "input2"}

# Test with string selector
results3 = simple_app.query("#input1")
assert len(results3) == 1
assert results3[0].id == "input1"


async def test_query_nested_unions():
"""Test handling of nested unions."""

simple_app = SimpleApp()
async with simple_app.run_test():
# Create nested union types
InputOrSelect = Input | Select
InputSelectOrStatic = InputOrSelect | Static

# Test nested union query
results = simple_app.query(InputSelectOrStatic)

# Verify that we find all our explicitly defined widgets
widget_ids = {w.id for w in results if w.id is not None}
expected_ids = {"input1", "select1", "static1", "input2"}
assert expected_ids.issubset(widget_ids), "Not all expected widgets were found"

# Verify we get the right types of widgets
assert all(
isinstance(w, (Input, Select, Static)) for w in results
), "Found unexpected widget types"

# Verify each expected widget appears exactly once
for expected_id in expected_ids:
matching_widgets = [w for w in results if w.id == expected_id]
assert (
len(matching_widgets) == 1
), f"Widget with id {expected_id} should appear exactly once"


async def test_query_empty_union():
"""Test querying with empty or invalid unions."""

class AnotherWidget(Widget):
pass

simple_app = SimpleApp()
async with simple_app.run_test():

# Test with a type that exists but has no matches
results = simple_app.query(AnotherWidget)
assert len(results) == 0

# Test with widget union that has no matches
results = simple_app.query(AnotherWidget | AnotherWidget)
assert len(results) == 0
Loading