55from inspect import isclass
66from pprint import pformat
77from textwrap import dedent , wrap
8+ from typing import Any , Callable , Self , Type , overload
89
910import numpy as np
10- from astropy .units import Quantity , Unit , UnitConversionError
11+ from astropy .units import Quantity , Unit , UnitBase , UnitConversionError
12+ from numpy .typing import DTypeLike , NDArray
1113
1214log = logging .getLogger (__name__ )
1315
@@ -22,7 +24,7 @@ class FieldValidationError(ValueError):
2224 pass
2325
2426
25- class Field :
27+ class Field [ T ] :
2628 """
2729 Class for storing data in a `Container`.
2830
@@ -55,18 +57,104 @@ class Field:
5557 A callable providing a fresh instance as default value.
5658 """
5759
60+ # only default provided
61+ @overload
62+ def __init__ (
63+ self ,
64+ default : T ,
65+ description : str = "" ,
66+ * ,
67+ unit : None = None ,
68+ ucd : Any = None ,
69+ dtype : None = None ,
70+ type : None = None ,
71+ ndim : None = None ,
72+ allow_none : bool = False ,
73+ max_length : None = None ,
74+ default_factory : None = None ,
75+ ): ...
76+
77+ # only default_factory provided
78+ @overload
79+ def __init__ (
80+ self ,
81+ default : None = None ,
82+ description : str = "" ,
83+ * ,
84+ default_factory : Type [T ] | Callable [[], T ],
85+ unit : None = None ,
86+ ucd : Any = None ,
87+ dtype : None = None ,
88+ type : None = None ,
89+ ndim : None = None ,
90+ allow_none : bool = False ,
91+ max_length : None = None ,
92+ ): ...
93+
94+ # default and type given
95+ @overload
96+ def __init__ [T1 , T2 ](
97+ self : "Field[T1 | T2]" ,
98+ default : T1 ,
99+ description : str = "" ,
100+ * ,
101+ type : Type [T2 ],
102+ unit : None = None ,
103+ ucd : Any = None ,
104+ dtype : None = None ,
105+ ndim : None = None ,
106+ allow_none : bool = False ,
107+ max_length : None = None ,
108+ default_factory : None = None ,
109+ ): ...
110+
111+ # None default but unit provided -> Quantity | None
112+ @overload
113+ def __init__ (
114+ self : "Field[Quantity | None]" ,
115+ default : None ,
116+ description : str = "" ,
117+ * ,
118+ unit : UnitBase ,
119+ type : None = None ,
120+ ucd : Any = None ,
121+ dtype : None = None ,
122+ ndim : None = None ,
123+ allow_none : bool = False ,
124+ max_length : None = None ,
125+ default_factory : None = None ,
126+ ): ...
127+
128+ # array case
129+ @overload
130+ def __init__ (
131+ self : "Field[NDArray | None]" ,
132+ default : None ,
133+ description : str = "" ,
134+ * ,
135+ unit : None = None ,
136+ type : None = None ,
137+ ucd : Any = None ,
138+ dtype : None | DTypeLike = None ,
139+ ndim : None | int = None ,
140+ allow_none : bool = False ,
141+ max_length : None = None ,
142+ default_factory : None = None ,
143+ ): ...
144+
58145 def __init__ (
59146 self ,
60147 default = None ,
61148 description = "" ,
149+ * ,
150+ default_factory : Type [T ] | Callable [[], T ] | None = None ,
62151 unit = None ,
63152 ucd = None ,
64153 dtype = None ,
65154 type = None ,
66155 ndim = None ,
67- allow_none = True ,
68- max_length = None ,
69- default_factory = None ,
156+ allow_none : bool = True ,
157+ max_length : int | None = None ,
70158 ):
71159 self .default = default
72160 self .default_factory = default_factory
@@ -82,6 +170,22 @@ def __init__(
82170 if default_factory is not None and default is not None :
83171 raise ValueError ("Must only provide one of default or default_factory" )
84172
173+ # we only specify the Descriptor protocol __get__ here has it helps type checkers
174+ # and IDEs to provide insights on types of container fields. It is not actually used at runtime
175+ # since the ContainerMeta turns Fields into __slots__ based access to member variables.
176+ # 1. When accessed via the class (e.g., MyContainer.foo), only owner present
177+ @overload
178+ def __get__ (self , instance : None , owner : Any ) -> Self : ...
179+
180+ # 2. access via instance, both arguments present
181+ @overload
182+ def __get__ (self , instance : "Container" , owner : "Type[Container]" ) -> T : ...
183+
184+ def __get__ (
185+ self , instance : "Container | None" , owner : "Type[Container]"
186+ ) -> T | Self :
187+ raise NotImplementedError ("Fields should only be used with Containers" )
188+
85189 def __repr__ (self ):
86190 if self .default_factory is not None :
87191 if isclass (self .default_factory ):
@@ -458,7 +562,7 @@ def validate(self):
458562 )
459563
460564
461- class Map (defaultdict ):
565+ class Map [ K , V ] (defaultdict [ K , V ] ):
462566 """A dictionary of sub-containers that can be added to a Container. This
463567 may be used e.g. to store a set of identical sub-Containers (e.g. indexed
464568 by ``tel_id`` or algorithm name).
0 commit comments