Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changes/2994.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Improve type hints for ``Container`` and ``Field``, IDEs and LSPs should
now better detect the correct types of container fields in code and no longer
complain about things like "Variable of type Field has not attribute length".
11 changes: 11 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def add_reference_type(prefix, objs):
"Sentinel",
"ObserveHandler",
"dict[K, V]",
"T",
"T1",
"T2",
"G",
"K",
"V",
Expand All @@ -144,6 +147,9 @@ def add_reference_type(prefix, objs):
"astropy.coordinates.baseframe.BaseCoordinateFrame",
"astropy.table.table.Table",
"eventio.simtel.simtelfile.SimTelFile",
"ctapipe.core.container.T1",
"ctapipe.core.container.T2",
"DTypeLike",
],
)
nitpick_ignore += add_reference_type(
Expand All @@ -161,6 +167,11 @@ def add_reference_type(prefix, objs):
"-v", # fix for wrong syntax in a traitlets docstring
"cls",
"name",
"T",
"T1",
"T2",
"K",
"V",
],
)
nitpick_ignore += add_reference_type(
Expand Down
116 changes: 110 additions & 6 deletions src/ctapipe/core/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from inspect import isclass
from pprint import pformat
from textwrap import dedent, wrap
from typing import Any, Callable, Self, Type, overload

import numpy as np
from astropy.units import Quantity, Unit, UnitConversionError
from astropy.units import Quantity, Unit, UnitBase, UnitConversionError
from numpy.typing import DTypeLike, NDArray

log = logging.getLogger(__name__)

Expand All @@ -22,7 +24,7 @@ class FieldValidationError(ValueError):
pass


class Field:
class Field[T]:
"""
Class for storing data in a `Container`.

Expand Down Expand Up @@ -55,18 +57,104 @@ class Field:
A callable providing a fresh instance as default value.
"""

# only default provided
@overload
def __init__(
self,
default: T,
description: str = "",
*,
unit: None = None,
ucd: Any = None,
dtype: None = None,
type: None = None,
ndim: None = None,
allow_none: bool = False,
max_length: None = None,
default_factory: None = None,
): ...

# only default_factory provided
@overload
def __init__(
self,
default: None = None,
description: str = "",
*,
default_factory: Type[T] | Callable[[], T],
unit: None = None,
ucd: Any = None,
dtype: None = None,
type: None = None,
ndim: None = None,
allow_none: bool = False,
max_length: None = None,
): ...

# default and type given
@overload
def __init__[T1, T2](
self: "Field[T1 | T2]",
default: T1,
description: str = "",
*,
type: Type[T2],
unit: None = None,
ucd: Any = None,
dtype: None = None,
ndim: None = None,
allow_none: bool = False,
max_length: None = None,
default_factory: None = None,
): ...

# None default but unit provided -> Quantity | None
@overload
def __init__(
self: "Field[Quantity | None]",
default: None,
description: str = "",
*,
unit: UnitBase,
type: None = None,
ucd: Any = None,
dtype: None = None,
ndim: None = None,
allow_none: bool = False,
max_length: None = None,
default_factory: None = None,
): ...

# array case
@overload
def __init__(
self: "Field[NDArray | None]",
default: None,
description: str = "",
*,
unit: None = None,
type: None = None,
ucd: Any = None,
dtype: None | DTypeLike = None,
ndim: None | int = None,
allow_none: bool = False,
max_length: None = None,
default_factory: None = None,
): ...

def __init__(
self,
default=None,
description="",
*,
default_factory: Type[T] | Callable[[], T] | None = None,
unit=None,
ucd=None,
dtype=None,
type=None,
ndim=None,
allow_none=True,
max_length=None,
default_factory=None,
allow_none: bool = True,
max_length: int | None = None,
):
self.default = default
self.default_factory = default_factory
Expand All @@ -82,6 +170,22 @@ def __init__(
if default_factory is not None and default is not None:
raise ValueError("Must only provide one of default or default_factory")

# we only specify the Descriptor protocol __get__ here has it helps type checkers
# and IDEs to provide insights on types of container fields. It is not actually used at runtime
# since the ContainerMeta turns Fields into __slots__ based access to member variables.
# 1. When accessed via the class (e.g., MyContainer.foo), only owner present
@overload
def __get__(self, instance: None, owner: Any) -> Self: ...

# 2. access via instance, both arguments present
@overload
def __get__(self, instance: "Container", owner: "Type[Container]") -> T: ...

def __get__(
self, instance: "Container | None", owner: "Type[Container]"
) -> T | Self:
raise NotImplementedError("Fields should only be used with Containers")

def __repr__(self):
if self.default_factory is not None:
if isclass(self.default_factory):
Expand Down Expand Up @@ -458,7 +562,7 @@ def validate(self):
)


class Map(defaultdict):
class Map[K, V](defaultdict[K, V]):
"""A dictionary of sub-containers that can be added to a Container. This
may be used e.g. to store a set of identical sub-Containers (e.g. indexed
by ``tel_id`` or algorithm name).
Expand Down
Loading