Skip to content

Commit fe61406

Browse files
authored
implement drop_sel (#73)
* add indexer comparison functions
* move the indexer conversion tests into a new test group
* add support for comparing DataArray and Variable indexers
* add a function to convert indexers
* implement drop_sel
* add drop_sel to api.rst
* update whats-new.rst [skip-ci]
* move drop_sel below loc
* fix the docstrings of drop_sel
1 parent 7ea9935 commit fe61406

File tree

7 files changed

+455
-3
lines changed

7 files changed

+455
-3
lines changed

docs/api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Dataset
1919
xarray.Dataset.pint.interp_like
2020
xarray.Dataset.pint.reindex
2121
xarray.Dataset.pint.reindex_like
22+
xarray.Dataset.pint.drop_sel
2223
xarray.Dataset.pint.sel
2324
xarray.Dataset.pint.to
2425

@@ -43,6 +44,7 @@ DataArray
4344
xarray.DataArray.pint.interp_like
4445
xarray.DataArray.pint.reindex
4546
xarray.DataArray.pint.reindex_like
47+
xarray.DataArray.pint.drop_sel
4648
xarray.DataArray.pint.sel
4749
xarray.DataArray.pint.to
4850

docs/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ What's new
2525
By `Mika Pflüger <https://github.com/mikapfl>`_.
2626
- implement :py:meth:`Dataset.pint.sel` and :py:meth:`DataArray.pint.sel` (:pull:`60`).
2727
By `Justus Magin <https://github.com/keewis>`_.
28+
- implement :py:meth:`Dataset.pint.drop_sel` and :py:meth:`DataArray.pint.drop_sel` (:pull:`73`).
29+
By `Justus Magin <https://github.com/keewis>`_.
2830
- implement :py:meth:`Dataset.pint.reindex`, :py:meth:`Dataset.pint.reindex_like`,
2931
:py:meth:`DataArray.pint.reindex` and :py:meth:`DataArray.pint.reindex_like` (:pull:`69`).
3032
By `Justus Magin <https://github.com/keewis>`_.

pint_xarray/accessors.py

+120
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,66 @@ def sel(
732732
def loc(self):
733733
...
734734

735+
def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
    """Unit-aware counterpart of :py:meth:`xarray.DataArray.drop_sel`.

    The labels are checked against the units of the wrapped object's
    indexes, converted to those units, passed through
    ``conversion.strip_indexer_units`` and only then forwarded to the
    wrapped object's own ``drop_sel``.

    Parameters
    ----------
    labels : mapping, optional
        Labels to drop, keyed by name. Combined with ``labels_kwargs``
        by ``either_dict_or_kwargs``.
    errors : str, default: "raise"
        Forwarded unchanged to the wrapped object's ``drop_sel``.
    **labels_kwargs
        Keyword form of ``labels``.

    Returns
    -------
    dropped
        The result of calling ``drop_sel`` on the wrapped object with
        the converted, stripped labels.

    Raises
    ------
    DimensionalityError
        If a label and the index of the same name both have units and
        the registry reports them as incompatible.

    See Also
    --------
    xarray.Dataset.pint.drop_sel
    xarray.DataArray.drop_sel
    xarray.Dataset.drop_sel
    """
    label_map = either_dict_or_kwargs(labels, labels_kwargs, "drop_sel")

    # units carried by the labels themselves
    label_units = {}
    for dim, label in label_map.items():
        label_units[dim] = conversion.extract_indexer_units(label)

    # units of the object, restricted to names that are dimensions
    dim_names = self.da.dims
    attached_units = conversion.extract_unit_attributes(self.da)
    index_units = {}
    for dim, unit in attached_units.items():
        if dim in dim_names:
            index_units[dim] = unit

    registry = get_registry(None, index_units, label_units)

    # collect every name whose label units cannot be converted to the
    # index units; pairs where either side is None are not checked
    paired_units = zip_mappings(label_units, index_units)
    mismatched = []
    for dim, (label_unit, index_unit) in paired_units.items():
        if None in (label_unit, index_unit):
            continue
        if not registry.is_compatible_with(label_unit, index_unit):
            mismatched.append(dim)

    if mismatched:
        raise DimensionalityError(
            {dim: label_units[dim] for dim in mismatched},
            {dim: index_units[dim] for dim in mismatched},
        )

    # express the labels in the units of the indexes ...
    converted = {}
    for dim, label in label_map.items():
        converted[dim] = conversion.convert_indexer_units(label, index_units[dim])

    # ... and strip the units before handing the labels to xarray
    stripped = {}
    for dim, label in converted.items():
        stripped[dim] = conversion.strip_indexer_units(label)

    return self.da.drop_sel(stripped, errors=errors)
794+
735795

736796
@register_dataset_accessor("pint")
737797
class PintDatasetAccessor:
@@ -1321,3 +1381,63 @@ def sel(
13211381
@property
13221382
def loc(self):
13231383
...
1384+
1385+
def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs):
    """Unit-aware counterpart of :py:meth:`xarray.Dataset.drop_sel`.

    The labels are checked against the units of the wrapped object's
    indexes, converted to those units, passed through
    ``conversion.strip_indexer_units`` and only then forwarded to the
    wrapped object's own ``drop_sel``.

    Parameters
    ----------
    labels : mapping, optional
        Labels to drop, keyed by name. Combined with ``labels_kwargs``
        by ``either_dict_or_kwargs``.
    errors : str, default: "raise"
        Forwarded unchanged to the wrapped object's ``drop_sel``.
    **labels_kwargs
        Keyword form of ``labels``.

    Returns
    -------
    dropped
        The result of calling ``drop_sel`` on the wrapped object with
        the converted, stripped labels.

    Raises
    ------
    DimensionalityError
        If a label and the index of the same name both have units and
        the registry reports them as incompatible.

    See Also
    --------
    xarray.DataArray.pint.drop_sel
    xarray.Dataset.drop_sel
    xarray.DataArray.drop_sel
    """
    label_map = either_dict_or_kwargs(labels, labels_kwargs, "drop_sel")

    # units carried by the labels themselves
    label_units = {}
    for dim, label in label_map.items():
        label_units[dim] = conversion.extract_indexer_units(label)

    # units of the object, restricted to names that are dimensions
    dim_names = self.ds.dims
    attached_units = conversion.extract_unit_attributes(self.ds)
    index_units = {}
    for dim, unit in attached_units.items():
        if dim in dim_names:
            index_units[dim] = unit

    registry = get_registry(None, index_units, label_units)

    # collect every name whose label units cannot be converted to the
    # index units; pairs where either side is None are not checked
    paired_units = zip_mappings(label_units, index_units)
    mismatched = []
    for dim, (label_unit, index_unit) in paired_units.items():
        if None in (label_unit, index_unit):
            continue
        if not registry.is_compatible_with(label_unit, index_unit):
            mismatched.append(dim)

    if mismatched:
        raise DimensionalityError(
            {dim: label_units[dim] for dim in mismatched},
            {dim: index_units[dim] for dim in mismatched},
        )

    # express the labels in the units of the indexes ...
    converted = {}
    for dim, label in label_map.items():
        converted[dim] = conversion.convert_indexer_units(label, index_units[dim])

    # ... and strip the units before handing the labels to xarray
    stripped = {}
    for dim, label in converted.items():
        stripped[dim] = conversion.strip_indexer_units(label)

    return self.ds.drop_sel(stripped, errors=errors)

pint_xarray/conversion.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from xarray import DataArray, Dataset, IndexVariable, Variable
55

66
unit_attribute_name = "units"
7+
slice_attributes = ("start", "stop", "step")
78

89

910
def array_attach_units(data, unit):
@@ -306,7 +307,7 @@ def strip_unit_attributes(obj, attr="units"):
306307

307308

308309
def slice_extract_units(indexer):
309-
elements = {name: getattr(indexer, name) for name in ("start", "stop", "step")}
310+
elements = {name: getattr(indexer, name) for name in slice_attributes}
310311
extracted_units = [
311312
array_extract_units(value)
312313
for name, value in elements.items()
@@ -333,6 +334,28 @@ def slice_extract_units(indexer):
333334
return registry.Quantity(1, units_).to_base_units().units
334335

335336

337+
def convert_units_slice(indexer, units):
    """Build a new slice with each component converted to ``units``.

    The components named in ``slice_attributes`` are read from
    ``indexer``; every component that is not ``None`` is passed through
    ``array_convert_units``, while ``None`` components are kept as they
    are.

    Parameters
    ----------
    indexer : slice
        The slice whose components should be converted.
    units
        Target units, forwarded to ``array_convert_units``.

    Returns
    -------
    slice
        A slice built from the converted components, in the order given
        by ``slice_attributes``.
    """
    # read all components first, then convert the ones that are set
    raw_parts = [getattr(indexer, attr) for attr in slice_attributes]
    new_parts = [
        part if part is None else array_convert_units(part, units)
        for part in raw_parts
    ]

    return slice(*new_parts)
346+
347+
348+
def convert_indexer_units(indexer, units):
    """Convert a single indexer to ``units``, dispatching on its type.

    Parameters
    ----------
    indexer : slice, DataArray, Variable or other
        The indexer to convert. Anything that is not a ``slice``,
        ``DataArray`` or ``Variable`` is handed to
        ``array_convert_units``.
    units
        Target units.

    Returns
    -------
    converted
        The value returned by the type-specific conversion function.
    """
    if isinstance(indexer, slice):
        return convert_units_slice(indexer, units)

    if isinstance(indexer, DataArray):
        # NOTE(review): the ``None`` key presumably addresses the
        # DataArray's own data -- confirm against ``convert_units``
        return convert_units(indexer, {None: units})

    if isinstance(indexer, Variable):
        return convert_units_variable(indexer, units)

    return array_convert_units(indexer, units)
357+
358+
336359
def extract_indexer_units(indexer):
337360
if isinstance(indexer, slice):
338361
return slice_extract_units(indexer)

pint_xarray/tests/test_accessors.py

+143
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,149 @@ def test_sel(obj, indexers, expected, error):
418418
assert_identical(actual, expected)
419419

420420

421+
def _drop_sel_coords(x, y):
    """Build the ``x`` (in dm) and ``y`` (in s) index definitions."""
    return {
        "x": ("x", x, {"units": unit_registry.Unit("dm")}),
        "y": ("y", y, {"units": unit_registry.Unit("s")}),
    }


def _drop_sel_dataset(x, y):
    """Build a fresh Dataset that only contains the two indexes."""
    return xr.Dataset(_drop_sel_coords(x, y))


def _drop_sel_dataarray(data, x, y):
    """Build a fresh 2D DataArray over ``("x", "y")`` with both indexes."""
    return xr.DataArray(data, dims=("x", "y"), coords=_drop_sel_coords(x, y))


@pytest.mark.parametrize(
    ["obj", "indexers", "expected", "error"],
    (
        # --- Dataset ---
        pytest.param(
            _drop_sel_dataset([10, 20, 30], [60, 120]),
            {"x": Quantity([10, 30], "dm"), "y": Quantity([60], "s")},
            _drop_sel_dataset([20], [120]),
            None,
            id="Dataset-identical units",
        ),
        pytest.param(
            _drop_sel_dataset([10, 20, 30], [60, 120]),
            {"x": Quantity([1, 3], "m"), "y": Quantity([1], "min")},
            _drop_sel_dataset([20], [120]),
            None,
            id="Dataset-compatible units",
        ),
        pytest.param(
            _drop_sel_dataset([10, 20, 30], [60, 120]),
            {"x": Quantity([1, 3], "s"), "y": Quantity([1], "m")},
            None,
            DimensionalityError,
            id="Dataset-incompatible units",
        ),
        pytest.param(
            _drop_sel_dataset([10, 20, 30], [60, 120]),
            {"x": Quantity([10, 30], "m"), "y": Quantity([60], "min")},
            None,
            KeyError,
            id="Dataset-compatible units-not found",
        ),
        # --- DataArray ---
        pytest.param(
            _drop_sel_dataarray([[0, 1], [2, 3], [4, 5]], [10, 20, 30], [60, 120]),
            {"x": Quantity([10, 30], "dm"), "y": Quantity([60], "s")},
            _drop_sel_dataarray([[3]], [20], [120]),
            None,
            id="DataArray-identical units",
        ),
        pytest.param(
            _drop_sel_dataarray([[0, 1], [2, 3], [4, 5]], [10, 20, 30], [60, 120]),
            {"x": Quantity([1, 3], "m"), "y": Quantity([1], "min")},
            _drop_sel_dataarray([[3]], [20], [120]),
            None,
            id="DataArray-compatible units",
        ),
        pytest.param(
            _drop_sel_dataarray([[0, 1], [2, 3], [4, 5]], [10, 20, 30], [60, 120]),
            {"x": Quantity([10, 30], "s"), "y": Quantity([60], "m")},
            None,
            DimensionalityError,
            id="DataArray-incompatible units",
        ),
        pytest.param(
            _drop_sel_dataarray([[0, 1], [2, 3], [4, 5]], [10, 20, 30], [60, 120]),
            {"x": Quantity([10, 30], "m"), "y": Quantity([60], "min")},
            None,
            KeyError,
            id="DataArray-compatible units-not found",
        ),
    ),
)
def test_drop_sel(obj, indexers, expected, error):
    """Check ``.pint.drop_sel`` against the expected result or error."""
    if error is None:
        # success path: both the units and the values have to match
        actual = obj.pint.drop_sel(indexers)
        assert_units_equal(actual, expected)
        assert_identical(actual, expected)
        return

    with pytest.raises(error):
        obj.pint.drop_sel(indexers)
562+
563+
421564
@pytest.mark.parametrize(
422565
["obj", "indexers", "expected", "error"],
423566
(

0 commit comments

Comments
 (0)