Skip to content

Commit 038ac60

Browse files
authored
implement loc (#79)
* implement loc * fix loc.__setitem__ * add tests for loc.__getitem__ * add tests for loc.__setitem__ * expect KeyError instead of DimensionalityError for incompatible units * add loc to api.rst * add docstrings * update whats-new.rst [skip-ci] * fix the link to DataArray.pint.loc [skip-ci]
1 parent 901fa8b commit 038ac60

File tree

4 files changed

+459
-10
lines changed

4 files changed

+459
-10
lines changed

docs/api.rst

+8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ This page contains a auto-generated summary of ``pint-xarray``'s API.
99

1010
Dataset
1111
-------
12+
.. autosummary::
13+
:toctree: generated/
14+
:template: autosummary/accessor_attribute.rst
15+
16+
xarray.Dataset.pint.loc
17+
1218
.. autosummary::
1319
:toctree: generated/
1420
:template: autosummary/accessor_method.rst
@@ -31,6 +37,8 @@ DataArray
3137
:toctree: generated/
3238
:template: autosummary/accessor_attribute.rst
3339

40+
xarray.Dataset.pint.loc
41+
3442
xarray.DataArray.pint.magnitude
3543
xarray.DataArray.pint.units
3644
xarray.DataArray.pint.dimensionality

docs/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ What's new
2525
By `Mika Pflüger <https://github.com/mikapfl>`_.
2626
- implement :py:meth:`Dataset.pint.sel` and :py:meth:`DataArray.pint.sel` (:pull:`60`).
2727
By `Justus Magin <https://github.com/keewis>`_.
28+
- implement :py:attr:`Dataset.pint.loc` and :py:attr:`DataArray.pint.loc` (:pull:`79`).
29+
By `Justus Magin <https://github.com/keewis>`_.
2830
- implement :py:meth:`Dataset.pint.drop_sel` and :py:meth:`DataArray.pint.drop_sel` (:pull:`73`).
2931
By `Justus Magin <https://github.com/keewis>`_.
3032
- implement :py:meth:`Dataset.pint.reindex`, :py:meth:`Dataset.pint.reindex_like`,

pint_xarray/accessors.py

+196-8
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,164 @@ def _decide_units(units, registry, unit_attribute):
147147
return units
148148

149149

150+
class DatasetLocIndexer:
151+
__slots__ = ("ds",)
152+
153+
def __init__(self, ds):
154+
self.ds = ds
155+
156+
def __getitem__(self, indexers):
157+
if not is_dict_like(indexers):
158+
raise NotImplementedError("pandas-style indexing is not supported, yet")
159+
160+
indexer_units = {
161+
name: conversion.extract_indexer_units(indexer)
162+
for name, indexer in indexers.items()
163+
}
164+
165+
# make sure we only have compatible units
166+
dims = self.ds.dims
167+
unit_attrs = conversion.extract_unit_attributes(self.ds)
168+
index_units = {
169+
name: units for name, units in unit_attrs.items() if name in dims
170+
}
171+
172+
registry = get_registry(None, index_units, indexer_units)
173+
174+
units = zip_mappings(indexer_units, index_units)
175+
incompatible_units = [
176+
key
177+
for key, (indexer_unit, index_unit) in units.items()
178+
if (
179+
None not in (indexer_unit, index_unit)
180+
and not registry.is_compatible_with(indexer_unit, index_unit)
181+
)
182+
]
183+
if incompatible_units:
184+
raise KeyError(
185+
"not all values found in "
186+
+ (
187+
f"index {incompatible_units[0]!r}"
188+
if len(incompatible_units) == 1
189+
else f"indexes {', '.join(repr(_) for _ in incompatible_units)}"
190+
)
191+
)
192+
193+
# convert the indexes to the indexer's units
194+
converted = conversion.convert_units(self.ds, indexer_units)
195+
196+
# index
197+
stripped_indexers = {
198+
name: conversion.strip_indexer_units(indexer)
199+
for name, indexer in indexers.items()
200+
}
201+
return converted.loc[stripped_indexers]
202+
203+
204+
class DataArrayLocIndexer:
205+
__slots__ = ("da",)
206+
207+
def __init__(self, da):
208+
self.da = da
209+
210+
def __getitem__(self, indexers):
211+
if not is_dict_like(indexers):
212+
raise NotImplementedError("pandas-style indexing is not supported, yet")
213+
214+
indexer_units = {
215+
name: conversion.extract_indexer_units(indexer)
216+
for name, indexer in indexers.items()
217+
}
218+
219+
# make sure we only have compatible units
220+
dims = self.da.dims
221+
unit_attrs = conversion.extract_unit_attributes(self.da)
222+
index_units = {
223+
name: units for name, units in unit_attrs.items() if name in dims
224+
}
225+
226+
registry = get_registry(None, index_units, indexer_units)
227+
228+
units = zip_mappings(indexer_units, index_units)
229+
incompatible_units = [
230+
key
231+
for key, (indexer_unit, index_unit) in units.items()
232+
if (
233+
None not in (indexer_unit, index_unit)
234+
and not registry.is_compatible_with(indexer_unit, index_unit)
235+
)
236+
]
237+
if incompatible_units:
238+
raise KeyError(
239+
"not all values found in "
240+
+ (
241+
f"index {incompatible_units[0]!r}"
242+
if len(incompatible_units) == 1
243+
else f"indexes {', '.join(repr(_) for _ in incompatible_units)}"
244+
)
245+
)
246+
247+
# convert the indexes to the indexer's units
248+
converted = conversion.convert_units(self.da, indexer_units)
249+
250+
# index
251+
stripped_indexers = {
252+
name: conversion.strip_indexer_units(indexer)
253+
for name, indexer in indexers.items()
254+
}
255+
return converted.loc[stripped_indexers]
256+
257+
def __setitem__(self, indexers, values):
258+
if not is_dict_like(indexers):
259+
raise NotImplementedError("pandas-style indexing is not supported, yet")
260+
261+
indexer_units = {
262+
name: conversion.extract_indexer_units(indexer)
263+
for name, indexer in indexers.items()
264+
}
265+
266+
# make sure we only have compatible units
267+
dims = self.da.dims
268+
unit_attrs = conversion.extract_unit_attributes(self.da)
269+
index_units = {
270+
name: units for name, units in unit_attrs.items() if name in dims
271+
}
272+
273+
registry = get_registry(None, index_units, indexer_units)
274+
275+
units = zip_mappings(indexer_units, index_units)
276+
incompatible_units = [
277+
key
278+
for key, (indexer_unit, index_unit) in units.items()
279+
if (
280+
None not in (indexer_unit, index_unit)
281+
and not registry.is_compatible_with(indexer_unit, index_unit)
282+
)
283+
]
284+
if incompatible_units:
285+
raise KeyError(
286+
"not all values found in "
287+
+ (
288+
f"index {incompatible_units[0]!r}"
289+
if len(incompatible_units) == 1
290+
else f"indexes {', '.join(repr(_) for _ in incompatible_units)}"
291+
)
292+
)
293+
294+
# convert the indexers to the index units
295+
converted = {
296+
name: conversion.convert_indexer_units(indexer, index_units[name])
297+
for name, indexer in indexers.items()
298+
}
299+
300+
# index
301+
stripped_indexers = {
302+
name: conversion.strip_indexer_units(indexer)
303+
for name, indexer in converted.items()
304+
}
305+
self.da.loc[stripped_indexers] = values
306+
307+
150308
@register_dataarray_accessor("pint")
151309
class PintDataArrayAccessor:
152310
"""
@@ -707,9 +865,14 @@ def sel(
707865
)
708866
]
709867
if incompatible_units:
710-
units1 = {key: indexer_units[key] for key in incompatible_units}
711-
units2 = {key: index_units[key] for key in incompatible_units}
712-
raise DimensionalityError(units1, units2)
868+
raise KeyError(
869+
"not all values found in "
870+
+ (
871+
f"index {incompatible_units[0]!r}"
872+
if len(incompatible_units) == 1
873+
else f"indexes {', '.join(repr(_) for _ in incompatible_units)}"
874+
)
875+
)
713876

714877
# convert the indexes to the indexer's units
715878
converted = conversion.convert_units(self.da, indexer_units)
@@ -730,7 +893,16 @@ def sel(
730893

731894
@property
732895
def loc(self):
733-
...
896+
"""Unit-aware attribute for indexing
897+
898+
.. note::
899+
Position based indexing (e.g. ``ds.loc[1, 2:]``) is not supported, yet
900+
901+
See Also
902+
--------
903+
xarray.DataArray.loc
904+
"""
905+
return DataArrayLocIndexer(self.da)
734906

735907
def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
736908
"""unit-aware version of drop_sel
@@ -1391,9 +1563,14 @@ def sel(
13911563
)
13921564
]
13931565
if incompatible_units:
1394-
units1 = {key: indexer_units[key] for key in incompatible_units}
1395-
units2 = {key: index_units[key] for key in incompatible_units}
1396-
raise DimensionalityError(units1, units2)
1566+
raise KeyError(
1567+
"not all values found in "
1568+
+ (
1569+
f"index {incompatible_units[0]!r}"
1570+
if len(incompatible_units) == 1
1571+
else f"indexes {', '.join(repr(_) for _ in incompatible_units)}"
1572+
)
1573+
)
13971574

13981575
# convert the indexes to the indexer's units
13991576
converted = conversion.convert_units(self.ds, indexer_units)
@@ -1414,7 +1591,18 @@ def sel(
14141591

14151592
@property
14161593
def loc(self):
1417-
...
1594+
"""Unit-aware attribute for indexing
1595+
1596+
Only supports ``__getitem__``.
1597+
1598+
.. note::
1599+
Position based indexing (e.g. ``ds.loc[1, 2:]``) is not supported, yet
1600+
1601+
See Also
1602+
--------
1603+
xarray.Dataset.loc
1604+
"""
1605+
return DatasetLocIndexer(self.ds)
14181606

14191607
def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
14201608
"""unit-aware version of drop_sel

0 commit comments

Comments
 (0)