Skip to content

Commit 8705606

Browse files
authored
Merge pull request #2994 from cta-observatory/container-typing
Add proper typing support for Container fields
2 parents 05555e4 + 16e7a92 commit 8705606

3 files changed

Lines changed: 124 additions & 6 deletions

File tree

docs/changes/2994.feature.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Improve type hints for ``Container`` and ``Field``, IDEs and LSPs should
2+
now better detect the correct types of container fields in code and no longer
3+
complain about things like "Variable of type Field has not attribute length".

docs/conf.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ def add_reference_type(prefix, objs):
134134
"Sentinel",
135135
"ObserveHandler",
136136
"dict[K, V]",
137+
"T",
138+
"T1",
139+
"T2",
137140
"G",
138141
"K",
139142
"V",
@@ -144,6 +147,9 @@ def add_reference_type(prefix, objs):
144147
"astropy.coordinates.baseframe.BaseCoordinateFrame",
145148
"astropy.table.table.Table",
146149
"eventio.simtel.simtelfile.SimTelFile",
150+
"ctapipe.core.container.T1",
151+
"ctapipe.core.container.T2",
152+
"DTypeLike",
147153
],
148154
)
149155
nitpick_ignore += add_reference_type(
@@ -161,6 +167,11 @@ def add_reference_type(prefix, objs):
161167
"-v", # fix for wrong syntax in a traitlets docstring
162168
"cls",
163169
"name",
170+
"T",
171+
"T1",
172+
"T2",
173+
"K",
174+
"V",
164175
],
165176
)
166177
nitpick_ignore += add_reference_type(

src/ctapipe/core/container.py

Lines changed: 110 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from inspect import isclass
66
from pprint import pformat
77
from textwrap import dedent, wrap
8+
from typing import Any, Callable, Self, Type, overload
89

910
import numpy as np
10-
from astropy.units import Quantity, Unit, UnitConversionError
11+
from astropy.units import Quantity, Unit, UnitBase, UnitConversionError
12+
from numpy.typing import DTypeLike, NDArray
1113

1214
log = logging.getLogger(__name__)
1315

@@ -22,7 +24,7 @@ class FieldValidationError(ValueError):
2224
pass
2325

2426

25-
class Field:
27+
class Field[T]:
2628
"""
2729
Class for storing data in a `Container`.
2830
@@ -55,18 +57,104 @@ class Field:
5557
A callable providing a fresh instance as default value.
5658
"""
5759

60+
# only default provided
61+
@overload
62+
def __init__(
63+
self,
64+
default: T,
65+
description: str = "",
66+
*,
67+
unit: None = None,
68+
ucd: Any = None,
69+
dtype: None = None,
70+
type: None = None,
71+
ndim: None = None,
72+
allow_none: bool = False,
73+
max_length: None = None,
74+
default_factory: None = None,
75+
): ...
76+
77+
# only default_factory provided
78+
@overload
79+
def __init__(
80+
self,
81+
default: None = None,
82+
description: str = "",
83+
*,
84+
default_factory: Type[T] | Callable[[], T],
85+
unit: None = None,
86+
ucd: Any = None,
87+
dtype: None = None,
88+
type: None = None,
89+
ndim: None = None,
90+
allow_none: bool = False,
91+
max_length: None = None,
92+
): ...
93+
94+
# default and type given
95+
@overload
96+
def __init__[T1, T2](
97+
self: "Field[T1 | T2]",
98+
default: T1,
99+
description: str = "",
100+
*,
101+
type: Type[T2],
102+
unit: None = None,
103+
ucd: Any = None,
104+
dtype: None = None,
105+
ndim: None = None,
106+
allow_none: bool = False,
107+
max_length: None = None,
108+
default_factory: None = None,
109+
): ...
110+
111+
# None default but unit provided -> Quantity | None
112+
@overload
113+
def __init__(
114+
self: "Field[Quantity | None]",
115+
default: None,
116+
description: str = "",
117+
*,
118+
unit: UnitBase,
119+
type: None = None,
120+
ucd: Any = None,
121+
dtype: None = None,
122+
ndim: None = None,
123+
allow_none: bool = False,
124+
max_length: None = None,
125+
default_factory: None = None,
126+
): ...
127+
128+
# array case
129+
@overload
130+
def __init__(
131+
self: "Field[NDArray | None]",
132+
default: None,
133+
description: str = "",
134+
*,
135+
unit: None = None,
136+
type: None = None,
137+
ucd: Any = None,
138+
dtype: None | DTypeLike = None,
139+
ndim: None | int = None,
140+
allow_none: bool = False,
141+
max_length: None = None,
142+
default_factory: None = None,
143+
): ...
144+
58145
def __init__(
59146
self,
60147
default=None,
61148
description="",
149+
*,
150+
default_factory: Type[T] | Callable[[], T] | None = None,
62151
unit=None,
63152
ucd=None,
64153
dtype=None,
65154
type=None,
66155
ndim=None,
67-
allow_none=True,
68-
max_length=None,
69-
default_factory=None,
156+
allow_none: bool = True,
157+
max_length: int | None = None,
70158
):
71159
self.default = default
72160
self.default_factory = default_factory
@@ -82,6 +170,22 @@ def __init__(
82170
if default_factory is not None and default is not None:
83171
raise ValueError("Must only provide one of default or default_factory")
84172

173+
# we only specify the Descriptor protocol __get__ here has it helps type checkers
174+
# and IDEs to provide insights on types of container fields. It is not actually used at runtime
175+
# since the ContainerMeta turns Fields into __slots__ based access to member variables.
176+
# 1. When accessed via the class (e.g., MyContainer.foo), only owner present
177+
@overload
178+
def __get__(self, instance: None, owner: Any) -> Self: ...
179+
180+
# 2. access via instance, both arguments present
181+
@overload
182+
def __get__(self, instance: "Container", owner: "Type[Container]") -> T: ...
183+
184+
def __get__(
185+
self, instance: "Container | None", owner: "Type[Container]"
186+
) -> T | Self:
187+
raise NotImplementedError("Fields should only be used with Containers")
188+
85189
def __repr__(self):
86190
if self.default_factory is not None:
87191
if isclass(self.default_factory):
@@ -458,7 +562,7 @@ def validate(self):
458562
)
459563

460564

461-
class Map(defaultdict):
565+
class Map[K, V](defaultdict[K, V]):
462566
"""A dictionary of sub-containers that can be added to a Container. This
463567
may be used e.g. to store a set of identical sub-Containers (e.g. indexed
464568
by ``tel_id`` or algorithm name).

0 commit comments

Comments
 (0)