diff --git a/docs/api.rst b/docs/api.rst index 12c42a6f..ca725e98 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,6 +19,7 @@ Dataset xarray.Dataset.pint.interp_like xarray.Dataset.pint.reindex xarray.Dataset.pint.reindex_like + xarray.Dataset.pint.drop_sel xarray.Dataset.pint.sel xarray.Dataset.pint.to @@ -43,6 +44,7 @@ DataArray xarray.DataArray.pint.interp_like xarray.DataArray.pint.reindex xarray.DataArray.pint.reindex_like + xarray.DataArray.pint.drop_sel xarray.DataArray.pint.sel xarray.DataArray.pint.to diff --git a/docs/whats-new.rst b/docs/whats-new.rst index eb3746f6..d44daa84 100644 --- a/docs/whats-new.rst +++ b/docs/whats-new.rst @@ -25,6 +25,8 @@ What's new By `Mika Pflüger <https://github.com/mikapfl>`_. - implement :py:meth:`Dataset.pint.sel` and :py:meth:`DataArray.pint.sel` (:pull:`60`). By `Justus Magin <https://github.com/keewis>`_. +- implement :py:meth:`Dataset.pint.drop_sel` and :py:meth:`DataArray.pint.drop_sel` (:pull:`73`). + By `Justus Magin <https://github.com/keewis>`_. - implement :py:meth:`Dataset.pint.reindex`, :py:meth:`Dataset.pint.reindex_like`, :py:meth:`DataArray.pint.reindex` and :py:meth:`DataArray.pint.reindex_like` (:pull:`69`). By `Justus Magin <https://github.com/keewis>`_. diff --git a/pint_xarray/accessors.py b/pint_xarray/accessors.py index 0e70acb6..4784cdd9 100644 --- a/pint_xarray/accessors.py +++ b/pint_xarray/accessors.py @@ -734,6 +734,66 @@ def sel( def loc(self): ... + def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): + """unit-aware version of drop_sel + + Just like :py:meth:`xarray.DataArray.drop_sel`, except the indexers are converted + to the units of the object's indexes first. 
+ + See Also + -------- + xarray.Dataset.pint.drop_sel + xarray.DataArray.drop_sel + xarray.Dataset.drop_sel + """ + indexers = either_dict_or_kwargs(labels, labels_kwargs, "drop_sel") + + indexer_units = { + name: conversion.extract_indexer_units(indexer) + for name, indexer in indexers.items() + } + + # make sure we only have compatible units + dims = self.da.dims + unit_attrs = conversion.extract_unit_attributes(self.da) + index_units = { + name: units for name, units in unit_attrs.items() if name in dims + } + + registry = get_registry(None, index_units, indexer_units) + + units = zip_mappings(indexer_units, index_units) + incompatible_units = [ + key + for key, (indexer_unit, index_unit) in units.items() + if ( + None not in (indexer_unit, index_unit) + and not registry.is_compatible_with(indexer_unit, index_unit) + ) + ] + if incompatible_units: + units1 = {key: indexer_units[key] for key in incompatible_units} + units2 = {key: index_units[key] for key in incompatible_units} + raise DimensionalityError(units1, units2) + + # convert the indexers to the indexes units + converted_indexers = { + name: conversion.convert_indexer_units(indexer, index_units[name]) + for name, indexer in indexers.items() + } + + # index + stripped_indexers = { + name: conversion.strip_indexer_units(indexer) + for name, indexer in converted_indexers.items() + } + indexed = self.da.drop_sel( + stripped_indexers, + errors=errors, + ) + + return indexed + @register_dataset_accessor("pint") class PintDatasetAccessor: @@ -1325,3 +1385,63 @@ def sel( @property def loc(self): ... + + def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): + """unit-aware version of drop_sel + + Just like :py:meth:`xarray.Dataset.drop_sel`, except the indexers are converted + to the units of the object's indexes first. 
+ + See Also + -------- + xarray.DataArray.pint.drop_sel + xarray.Dataset.drop_sel + xarray.DataArray.drop_sel + """ + indexers = either_dict_or_kwargs(labels, labels_kwargs, "drop_sel") + + indexer_units = { + name: conversion.extract_indexer_units(indexer) + for name, indexer in indexers.items() + } + + # make sure we only have compatible units + dims = self.ds.dims + unit_attrs = conversion.extract_unit_attributes(self.ds) + index_units = { + name: units for name, units in unit_attrs.items() if name in dims + } + + registry = get_registry(None, index_units, indexer_units) + + units = zip_mappings(indexer_units, index_units) + incompatible_units = [ + key + for key, (indexer_unit, index_unit) in units.items() + if ( + None not in (indexer_unit, index_unit) + and not registry.is_compatible_with(indexer_unit, index_unit) + ) + ] + if incompatible_units: + units1 = {key: indexer_units[key] for key in incompatible_units} + units2 = {key: index_units[key] for key in incompatible_units} + raise DimensionalityError(units1, units2) + + # convert the indexers to the indexes units + converted_indexers = { + name: conversion.convert_indexer_units(indexer, index_units[name]) + for name, indexer in indexers.items() + } + + # index + stripped_indexers = { + name: conversion.strip_indexer_units(indexer) + for name, indexer in converted_indexers.items() + } + indexed = self.ds.drop_sel( + stripped_indexers, + errors=errors, + ) + + return indexed diff --git a/pint_xarray/conversion.py b/pint_xarray/conversion.py index 4c12d0b1..6545c76b 100644 --- a/pint_xarray/conversion.py +++ b/pint_xarray/conversion.py @@ -4,6 +4,7 @@ from xarray import DataArray, Dataset, IndexVariable, Variable unit_attribute_name = "units" +slice_attributes = ("start", "stop", "step") def array_attach_units(data, unit): @@ -306,7 +307,7 @@ def strip_unit_attributes(obj, attr="units"): def slice_extract_units(indexer): - elements = {name: getattr(indexer, name) for name in ("start", "stop", "step")} + 
elements = {name: getattr(indexer, name) for name in slice_attributes} extracted_units = [ array_extract_units(value) for name, value in elements.items() @@ -333,6 +334,28 @@ def slice_extract_units(indexer): return registry.Quantity(1, units_).to_base_units().units +def convert_units_slice(indexer, units): + attrs = {name: getattr(indexer, name) for name in slice_attributes} + converted = { + name: array_convert_units(value, units) if value is not None else None + for name, value in attrs.items() + } + args = [converted[name] for name in slice_attributes] + + return slice(*args) + + +def convert_indexer_units(indexer, units): + if isinstance(indexer, slice): + return convert_units_slice(indexer, units) + elif isinstance(indexer, DataArray): + return convert_units(indexer, {None: units}) + elif isinstance(indexer, Variable): + return convert_units_variable(indexer, units) + else: + return array_convert_units(indexer, units) + + def extract_indexer_units(indexer): if isinstance(indexer, slice): return slice_extract_units(indexer) diff --git a/pint_xarray/tests/test_accessors.py b/pint_xarray/tests/test_accessors.py index 94b632f4..df4bd4d1 100644 --- a/pint_xarray/tests/test_accessors.py +++ b/pint_xarray/tests/test_accessors.py @@ -418,6 +418,149 @@ def test_sel(obj, indexers, expected, error): assert_identical(actual, expected) +@pytest.mark.parametrize( + ["obj", "indexers", "expected", "error"], + ( + pytest.param( + xr.Dataset( + { + "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + } + ), + {"x": Quantity([10, 30], "dm"), "y": Quantity([60], "s")}, + xr.Dataset( + { + "x": ("x", [20], {"units": unit_registry.Unit("dm")}), + "y": ("y", [120], {"units": unit_registry.Unit("s")}), + } + ), + None, + id="Dataset-identical units", + ), + pytest.param( + xr.Dataset( + { + "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": 
unit_registry.Unit("s")}), + } + ), + {"x": Quantity([1, 3], "m"), "y": Quantity([1], "min")}, + xr.Dataset( + { + "x": ("x", [20], {"units": unit_registry.Unit("dm")}), + "y": ("y", [120], {"units": unit_registry.Unit("s")}), + } + ), + None, + id="Dataset-compatible units", + ), + pytest.param( + xr.Dataset( + { + "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + } + ), + {"x": Quantity([1, 3], "s"), "y": Quantity([1], "m")}, + None, + DimensionalityError, + id="Dataset-incompatible units", + ), + pytest.param( + xr.Dataset( + { + "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + } + ), + {"x": Quantity([10, 30], "m"), "y": Quantity([60], "min")}, + None, + KeyError, + id="Dataset-compatible units-not found", + ), + pytest.param( + xr.DataArray( + [[0, 1], [2, 3], [4, 5]], + dims=("x", "y"), + coords={ + "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + }, + ), + {"x": Quantity([10, 30], "dm"), "y": Quantity([60], "s")}, + xr.DataArray( + [[3]], + dims=("x", "y"), + coords={ + "x": ("x", [20], {"units": unit_registry.Unit("dm")}), + "y": ("y", [120], {"units": unit_registry.Unit("s")}), + }, + ), + None, + id="DataArray-identical units", + ), + pytest.param( + xr.DataArray( + [[0, 1], [2, 3], [4, 5]], + dims=("x", "y"), + coords={ + "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + }, + ), + {"x": Quantity([1, 3], "m"), "y": Quantity([1], "min")}, + xr.DataArray( + [[3]], + dims=("x", "y"), + coords={ + "x": ("x", [20], {"units": unit_registry.Unit("dm")}), + "y": ("y", [120], {"units": unit_registry.Unit("s")}), + }, + ), + None, + id="DataArray-compatible units", + ), + pytest.param( + xr.DataArray( + [[0, 1], [2, 3], [4, 5]], + dims=("x", "y"), + coords={ + "x": 
("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + }, + ), + {"x": Quantity([10, 30], "s"), "y": Quantity([60], "m")}, + None, + DimensionalityError, + id="DataArray-incompatible units", + ), + pytest.param( + xr.DataArray( + [[0, 1], [2, 3], [4, 5]], + dims=("x", "y"), + coords={ + "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), + "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + }, + ), + {"x": Quantity([10, 30], "m"), "y": Quantity([60], "min")}, + None, + KeyError, + id="DataArray-compatible units-not found", + ), + ), +) +def test_drop_sel(obj, indexers, expected, error): + if error is not None: + with pytest.raises(error): + obj.pint.drop_sel(indexers) + else: + actual = obj.pint.drop_sel(indexers) + assert_units_equal(actual, expected) + assert_identical(actual, expected) + + @pytest.mark.parametrize( ["obj", "indexers", "expected", "error"], ( diff --git a/pint_xarray/tests/test_conversion.py b/pint_xarray/tests/test_conversion.py index b7cb782b..10922d6d 100644 --- a/pint_xarray/tests/test_conversion.py +++ b/pint_xarray/tests/test_conversion.py @@ -5,7 +5,13 @@ from pint_xarray import conversion -from .utils import assert_array_equal, assert_array_units_equal, assert_identical +from .utils import ( + assert_array_equal, + assert_array_units_equal, + assert_identical, + assert_indexer_equal, + assert_indexer_units_equal, +) unit_registry = pint.UnitRegistry() Quantity = unit_registry.Quantity @@ -535,6 +541,109 @@ def test_strip_unit_attributes(self, obj, expected): filter_none_values(conversion.extract_unit_attributes(actual)) == expected ) + +class TestIndexerFunctions: + @pytest.mark.parametrize( + ["indexer", "units", "expected", "error"], + ( + pytest.param(1, None, 1, None, id="scalar-no units"), + pytest.param( + 1, + "dimensionless", + Quantity(1, "dimensionless"), + ValueError, + id="scalar-dimensionless", + ), + pytest.param( + Quantity(1, "m"), + 
Unit("dm"), + Quantity(10, "dm"), + None, + id="scalar-units", + ), + pytest.param( + np.array([1, 2]), None, np.array([1, 2]), None, id="array-no units" + ), + pytest.param( + Quantity([1, 2], "m"), + Unit("dm"), + Quantity([10, 20], "dm"), + None, + id="array-units", + ), + pytest.param( + Variable("x", [1, 2]), + None, + Variable("x", [1, 2]), + None, + id="Variable-no units", + ), + pytest.param( + Variable("x", Quantity([1, 2], "m")), + Unit("dm"), + Variable("x", Quantity([10, 20], "dm")), + None, + id="Variable-units", + ), + pytest.param( + DataArray([1, 2], dims="x"), + None, + DataArray([1, 2], dims="x"), + None, + id="DataArray-no units", + ), + pytest.param( + DataArray(Quantity([1, 2], "m"), dims="x"), + Unit("dm"), + DataArray(Quantity([10, 20], "dm"), dims="x"), + None, + id="DataArray-units", + ), + pytest.param( + slice(None), None, slice(None), None, id="empty slice-no units" + ), + pytest.param( + slice(1, None), None, slice(1, None), None, id="slice-no units" + ), + pytest.param( + slice(Quantity(1, "m"), Quantity(2, "m")), + Unit("m"), + slice(Quantity(1, "m"), Quantity(2, "m")), + None, + id="slice-identical units", + ), + pytest.param( + slice(Quantity(1, "m"), Quantity(2000, "mm")), + Unit("dm"), + slice(Quantity(10, "dm"), Quantity(20, "dm")), + None, + id="slice-compatible units", + ), + pytest.param( + slice(Quantity(1, "m"), Quantity(2, "m")), + Unit("ms"), + None, + pint.DimensionalityError, + id="slice-incompatible units", + ), + pytest.param( + slice(1000, Quantity(2000, "ms")), + Unit("s"), + None, + pint.DimensionalityError, + id="slice-incompatible units-mixed", + ), + ), + ) + def test_convert_indexer_units(self, indexer, units, expected, error): + if error is not None: + with pytest.raises(error): + conversion.convert_indexer_units(indexer, units) + else: + actual = conversion.convert_indexer_units(indexer, units) + assert_indexer_equal(actual, expected) + assert_indexer_units_equal(actual, expected) + @pytest.mark.parametrize( 
["indexer", "expected"], ( diff --git a/pint_xarray/tests/utils.py b/pint_xarray/tests/utils.py index 2fda68a5..d51cb0c5 100644 --- a/pint_xarray/tests/utils.py +++ b/pint_xarray/tests/utils.py @@ -3,9 +3,17 @@ import numpy as np import pytest +from pint import Quantity +from xarray import DataArray, Variable from xarray.testing import assert_equal, assert_identical # noqa: F401 -from ..conversion import extract_units +from ..conversion import ( + array_strip_units, + extract_indexer_units, + extract_units, + strip_units, + strip_units_variable, +) @contextmanager @@ -38,6 +46,51 @@ def assert_array_equal(a, b): np.testing.assert_array_equal(a_, b_) +def assert_slice_equal(a, b): + attrs = ("start", "stop", "step") + values_a = tuple(getattr(a, name) for name in attrs) + values_b = tuple(getattr(b, name) for name in attrs) + stripped_a = tuple(array_strip_units(v) for v in values_a) + stripped_b = tuple(array_strip_units(v) for v in values_b) + + assert ( + stripped_a == stripped_b + ), f"different values: {stripped_a!r} ←→ {stripped_b!r}" + + +def assert_indexer_equal(a, b): + __tracebackhide__ = True + + assert type(a) == type(b) + if isinstance(a, slice): + assert_slice_equal(a, b) + elif isinstance(a, DataArray): + stripped_a = strip_units(a) + stripped_b = strip_units(b) + + assert_equal(stripped_a, stripped_b) + elif isinstance(a, Variable): + stripped_a = strip_units_variable(a) + stripped_b = strip_units_variable(b) + + assert_equal(stripped_a, stripped_b) + elif isinstance(a, (Quantity, np.ndarray)): + assert_array_equal(a, b) + else: + a_ = array_strip_units(a) + b_ = array_strip_units(b) + assert a_ == b_, f"different values: {a_!r} ←→ {b_!r}" + + +def assert_indexer_units_equal(a, b): + __tracebackhide__ = True + + units_a = extract_indexer_units(a) + units_b = extract_indexer_units(b) + + assert units_a == units_b, f"different units: {units_a!r} ←→ {units_b!r}" + + def assert_units_equal(a, b): __tracebackhide__ = True assert extract_units(a) == 
extract_units(b)