Skip to content

Commit ce185c2

Browse files
committed
More typing (draft)
1 parent 5643c32 commit ce185c2

File tree

7 files changed

+159
-59
lines changed

7 files changed

+159
-59
lines changed

pint/_typing.py

+4
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,7 @@ def __setitem__(self, key: Any, value: Any) -> None:
6363

6464
FuncType = Callable[..., Any]
6565
F = TypeVar("F", bound=FuncType)
66+
67+
68+
# TODO: Improve or delete types
69+
QuantityArgument = Any

pint/facets/context/objects.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@
1010

1111
import weakref
1212
from collections import ChainMap, defaultdict
13-
from typing import Any
13+
from typing import Any, Callable
1414
from collections.abc import Iterable
1515

1616
from ...facets.plain import UnitDefinition
1717
from ...util import UnitsContainer, to_units_container
1818
from .definitions import ContextDefinition
19+
from ..._typing import Magnitude
20+
21+
Transformation = Callable[
22+
[
23+
Magnitude,
24+
],
25+
Magnitude,
26+
]
1927

2028

2129
class Context:
@@ -75,14 +83,14 @@ def __init__(
7583
aliases: tuple[str] = tuple(),
7684
defaults: dict[str, Any] | None = None,
7785
) -> None:
78-
self.name = name
79-
self.aliases = aliases
86+
self.name: str | None = name
87+
self.aliases: tuple[str] = aliases
8088

8189
#: Maps (src, dst) -> transformation function
82-
self.funcs = {}
90+
self.funcs: dict[tuple[UnitsContainer, UnitsContainer], Transformation] = {}
8391

8492
#: Maps defaults variable names to values
85-
self.defaults = defaults or {}
93+
self.defaults: dict[str, Any] = defaults or {}
8694

8795
# Store Definition objects that are context-specific
8896
self.redefinitions = []
@@ -154,7 +162,9 @@ def from_definition(cls, cd: ContextDefinition, to_base_func=None) -> Context:
154162

155163
return ctx
156164

157-
def add_transformation(self, src, dst, func) -> None:
165+
def add_transformation(
166+
self, src: UnitsContainer, dst: UnitsContainer, func: Transformation
167+
) -> None:
158168
"""Add a transformation function to the context."""
159169

160170
_key = self.__keytransform__(src, dst)
@@ -202,7 +212,7 @@ def _redefine(self, definition: UnitDefinition):
202212

203213
def hashable(
204214
self,
205-
) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]:
215+
) -> tuple[str | None, tuple[str], frozenset, frozenset, tuple]:
206216
"""Generate a unique hashable and comparable representation of self, which can
207217
be used as a key in a dict. This class cannot define ``__hash__`` because it is
208218
mutable, and the Python interpreter does cache the output of ``__hash__``.

pint/facets/group/objects.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,28 @@
88

99
from __future__ import annotations
1010

11+
from typing import Callable, Any, TYPE_CHECKING
12+
1113
from collections.abc import Generator, Iterable
1214
from ...util import SharedRegistryObject, getattr_maybe_raise
1315
from .definitions import GroupDefinition
1416

17+
if TYPE_CHECKING:
18+
from ..plain import UnitDefinition
19+
20+
DefineFunc = Callable[
21+
[
22+
Any,
23+
],
24+
None,
25+
]
26+
AddUnitFunc = Callable[
27+
[
28+
UnitDefinition,
29+
],
30+
None,
31+
]
32+
1533

1634
class Group(SharedRegistryObject):
1735
"""A group is a set of units.
@@ -57,7 +75,7 @@ def __init__(self, name: str):
5775
self._computed_members: frozenset[str] | None = None
5876

5977
@property
60-
def members(self):
78+
def members(self) -> frozenset[str]:
6179
"""Names of the units that are members of the group.
6280
6381
Calculated to include to all units in all included _used_groups.
@@ -143,7 +161,7 @@ def remove_groups(self, *group_names: str) -> None:
143161

144162
@classmethod
145163
def from_lines(
146-
cls, lines: Iterable[str], define_func, non_int_type: type = float
164+
cls, lines: Iterable[str], define_func: DefineFunc, non_int_type: type = float
147165
) -> Group:
148166
"""Return a Group object parsing an iterable of lines.
149167
@@ -160,11 +178,15 @@ def from_lines(
160178
161179
"""
162180
group_definition = GroupDefinition.from_lines(lines, non_int_type)
181+
182+
if group_definition is None:
183+
raise ValueError(f"Could not define group from {lines}")
184+
163185
return cls.from_definition(group_definition, define_func)
164186

165187
@classmethod
166188
def from_definition(
167-
cls, group_definition: GroupDefinition, add_unit_func=None
189+
cls, group_definition: GroupDefinition, add_unit_func: AddUnitFunc | None = None
168190
) -> Group:
169191
grp = cls(group_definition.name)
170192

pint/facets/plain/registry.py

+86-42
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from fractions import Fraction
2121
from numbers import Number
2222
from token import NAME, NUMBER
23+
from tokenize import TokenInfo
24+
2325
from typing import (
2426
TYPE_CHECKING,
2527
Any,
@@ -33,7 +35,7 @@
3335
from ..context import Context
3436
from ..._typing import Quantity, Unit
3537

36-
from ..._typing import QuantityOrUnitLike, UnitLike
38+
from ..._typing import QuantityOrUnitLike, UnitLike, QuantityArgument
3739
from ..._vendor import appdirs
3840
from ...compat import HAS_BABEL, babel_parse, tokenizer
3941
from ...errors import DimensionalityError, RedefinitionError, UndefinedUnitError
@@ -75,8 +77,10 @@
7577

7678

7779
@functools.lru_cache
78-
def pattern_to_regex(pattern):
79-
if hasattr(pattern, "finditer"):
80+
def pattern_to_regex(pattern: str | re.Pattern[str]) -> re.Pattern[str]:
81+
# TODO: This has been changed during typing improvements.
82+
# if hasattr(pattern, "finditer"):
83+
if not isinstance(pattern, str):
8084
pattern = pattern.pattern
8185

8286
# Replace "{unit_name}" match string with float regex with unit_name as group
@@ -197,7 +201,15 @@ def __init__(
197201
mpl_formatter: str = "{:P}",
198202
):
199203
#: Map a definition class to a adder methods.
200-
self._adders = {}
204+
self._adders: dict[
205+
type[T],
206+
Callable[
207+
[
208+
T,
209+
],
210+
None,
211+
],
212+
] = {}
201213
self._register_definition_adders()
202214
self._init_dynamic_classes()
203215

@@ -297,7 +309,16 @@ def _after_init(self) -> None:
297309
self._build_cache(loaded_files)
298310
self._initialized = True
299311

300-
def _register_adder(self, definition_class, adder_func):
312+
def _register_adder(
313+
self,
314+
definition_class: type[T],
315+
adder_func: Callable[
316+
[
317+
T,
318+
],
319+
None,
320+
],
321+
) -> None:
301322
"""Register a block definition."""
302323
self._adders[definition_class] = adder_func
303324

@@ -316,18 +337,18 @@ def __deepcopy__(self, memo) -> PlainRegistry:
316337
new._init_dynamic_classes()
317338
return new
318339

319-
def __getattr__(self, item):
340+
def __getattr__(self, item: str) -> Unit:
320341
getattr_maybe_raise(self, item)
321342
return self.Unit(item)
322343

323-
def __getitem__(self, item):
344+
def __getitem__(self, item: str):
324345
logger.warning(
325346
"Calling the getitem method from a UnitRegistry is deprecated. "
326347
"use `parse_expression` method or use the registry as a callable."
327348
)
328349
return self.parse_expression(item)
329350

330-
def __contains__(self, item) -> bool:
351+
def __contains__(self, item: str) -> bool:
331352
"""Support checking prefixed units with the `in` operator"""
332353
try:
333354
self.__getattr__(item)
@@ -390,7 +411,7 @@ def cache_folder(self) -> pathlib.Path | None:
390411
def non_int_type(self):
391412
return self._non_int_type
392413

393-
def define(self, definition):
414+
def define(self, definition: str | type) -> None:
394415
"""Add unit to the registry.
395416
396417
Parameters
@@ -413,7 +434,7 @@ def define(self, definition):
413434
# - then we define specific adder for each definition class. :-D
414435
############
415436

416-
def _helper_dispatch_adder(self, definition):
437+
def _helper_dispatch_adder(self, definition: Any) -> None:
417438
"""Helper function to add a single definition,
418439
choosing the appropiate method by class.
419440
"""
@@ -474,19 +495,19 @@ def _add_alias(self, definition: AliasDefinition):
474495
for alias in definition.aliases:
475496
self._helper_single_adder(alias, unit, self._units, self._units_casei)
476497

477-
def _add_dimension(self, definition: DimensionDefinition):
498+
def _add_dimension(self, definition: DimensionDefinition) -> None:
478499
self._helper_adder(definition, self._dimensions, None)
479500

480-
def _add_derived_dimension(self, definition: DerivedDimensionDefinition):
501+
def _add_derived_dimension(self, definition: DerivedDimensionDefinition) -> None:
481502
for dim_name in definition.reference.keys():
482503
if dim_name not in self._dimensions:
483504
self._add_dimension(DimensionDefinition(dim_name))
484505
self._helper_adder(definition, self._dimensions, None)
485506

486-
def _add_prefix(self, definition: PrefixDefinition):
507+
def _add_prefix(self, definition: PrefixDefinition) -> None:
487508
self._helper_adder(definition, self._prefixes, None)
488509

489-
def _add_unit(self, definition: UnitDefinition):
510+
def _add_unit(self, definition: UnitDefinition) -> None:
490511
if definition.is_base:
491512
self._base_units.append(definition.name)
492513
for dim_name in definition.reference.keys():
@@ -673,7 +694,7 @@ def _get_dimensionality_recurse(self, ref, exp, accumulator):
673694
if reg.reference is not None:
674695
self._get_dimensionality_recurse(reg.reference, exp2, accumulator)
675696

676-
def _get_dimensionality_ratio(self, unit1, unit2):
697+
def _get_dimensionality_ratio(self, unit1: UnitLike, unit2: UnitLike):
677698
"""Get the exponential ratio between two units, i.e. solve unit2 = unit1**x for x.
678699
679700
Parameters
@@ -780,7 +801,9 @@ def _get_root_units(self, input_units, check_nonmult=True):
780801
cache[input_units] = factor, units
781802
return factor, units
782803

783-
def get_base_units(self, input_units, check_nonmult=True, system=None):
804+
def get_base_units(
805+
self, input_units: UnitsContainer | str, check_nonmult: bool = True, system=None
806+
):
784807
"""Convert unit or dict of units to the plain units.
785808
786809
If any unit is non multiplicative and check_converter is True,
@@ -1104,7 +1127,32 @@ def _parse_units(
11041127

11051128
return ret
11061129

1107-
def _eval_token(self, token, case_sensitive=None, **values):
1130+
def _eval_token(
1131+
self,
1132+
token: TokenInfo,
1133+
case_sensitive: bool | None = None,
1134+
**values: QuantityArgument,
1135+
):
1136+
"""Evaluate a single token using the following rules:
1137+
1138+
1. numerical values as strings are replaced by their numeric counterparts
1139+
- integers are parsed as integers
1140+
- other numeric values are parses of non_int_type
1141+
2. strings in (inf, infinity, nan, dimensionless) with their numerical value.
1142+
3. strings in values.keys() are replaced by Quantity(values[key])
1143+
4. in other cases, the values are parsed as units and replaced by their canonical name.
1144+
1145+
Parameters
1146+
----------
1147+
token
1148+
Token to evaluate.
1149+
case_sensitive, optional
1150+
If true, a case sensitive matching of the unit name will be done in the registry.
1151+
If false, a case INsensitive matching of the unit name will be done in the registry.
1152+
(Default value = None, which uses registry setting)
1153+
**values
1154+
Other string that will be parsed using the Quantity constructor on their corresponding value.
1155+
"""
11081156
token_type = token[0]
11091157
token_text = token[1]
11101158
if token_type == NAME:
@@ -1139,28 +1187,25 @@ def parse_pattern(
11391187
11401188
Parameters
11411189
----------
1142-
input_string :
1190+
input_string
11431191
11441192
pattern_string:
1145-
The regex parse string
1146-
case_sensitive :
1147-
(Default value = None, which uses registry setting)
1148-
many :
1193+
The regex parse string
1194+
case_sensitive, optional
1195+
If true, a case sensitive matching of the unit name will be done in the registry.
1196+
If false, a case INsensitive matching of the unit name will be done in the registry.
1197+
(Default value = None, which uses registry setting)
1198+
many, optional
11491199
Match many results
11501200
(Default value = False)
1151-
1152-
1153-
Returns
1154-
-------
1155-
11561201
"""
11571202

11581203
if not input_string:
11591204
return [] if many else None
11601205

11611206
# Parse string
1162-
pattern = pattern_to_regex(pattern)
1163-
matched = re.finditer(pattern, input_string)
1207+
regex = pattern_to_regex(pattern)
1208+
matched = re.finditer(regex, input_string)
11641209

11651210
# Extract result(s)
11661211
results = []
@@ -1196,16 +1241,14 @@ def parse_expression(
11961241
11971242
Parameters
11981243
----------
1199-
input_string :
1200-
1201-
case_sensitive :
1202-
(Default value = None, which uses registry setting)
1203-
**values :
1204-
1205-
1206-
Returns
1207-
-------
1208-
1244+
input_string
1245+
1246+
case_sensitive, optional
1247+
If true, a case sensitive matching of the unit name will be done in the registry.
1248+
If false, a case INsensitive matching of the unit name will be done in the registry.
1249+
(Default value = None, which uses registry setting)
1250+
**values
1251+
Other string that will be parsed using the Quantity constructor on their corresponding value.
12091252
"""
12101253
if not input_string:
12111254
return self.Quantity(1)
@@ -1215,8 +1258,9 @@ def parse_expression(
12151258
input_string = string_preprocessor(input_string)
12161259
gen = tokenizer(input_string)
12171260

1218-
return build_eval_tree(gen).evaluate(
1219-
lambda x: self._eval_token(x, case_sensitive=case_sensitive, **values)
1220-
)
1261+
def _define_op(s: str):
1262+
return self._eval_token(s, case_sensitive=case_sensitive, **values)
1263+
1264+
return build_eval_tree(gen).evaluate(_define_op)
12211265

12221266
__call__ = parse_expression

0 commit comments

Comments
 (0)