diff --git a/docs/examples/plotting.ipynb b/docs/examples/plotting.ipynb index 5a4a0c7e..3955fb0f 100644 --- a/docs/examples/plotting.ipynb +++ b/docs/examples/plotting.ipynb @@ -108,7 +108,7 @@ "metadata": {}, "outputs": [], "source": [ - "monthly_means.pint.sel(\n", + "monthly_means.sel(\n", " lat=ureg.Quantity(4350, \"angular_minute\"),\n", " lon=ureg.Quantity(12000, \"angular_minute\"),\n", ")" diff --git a/docs/whats-new.rst b/docs/whats-new.rst index a90a8423..babbf532 100644 --- a/docs/whats-new.rst +++ b/docs/whats-new.rst @@ -6,6 +6,8 @@ What's new ------------------ - drop support for python 3.9 (:pull:`266`) By `Justus Magin `_. +- create a `PintIndex` to allow units on indexed coordinates (:pull:`163`, :issue:`162`) + By `Justus Magin `_ and `Benoit Bovy `_. 0.4 (23 Jun 2024) ----------------- diff --git a/pint_xarray/__init__.py b/pint_xarray/__init__.py index 3ce42d86..9991351b 100644 --- a/pint_xarray/__init__.py +++ b/pint_xarray/__init__.py @@ -5,6 +5,7 @@ from . import accessors, formatting, testing # noqa: F401 from .accessors import default_registry as unit_registry from .accessors import setup_registry +from .index import PintIndex try: __version__ = version("pint-xarray") @@ -21,4 +22,5 @@ "testing", "unit_registry", "setup_registry", + "PintIndex", ] diff --git a/pint_xarray/accessors.py b/pint_xarray/accessors.py index 2ce576d0..2104021b 100644 --- a/pint_xarray/accessors.py +++ b/pint_xarray/accessors.py @@ -152,10 +152,9 @@ def __getitem__(self, indexers): raise NotImplementedError("pandas-style indexing is not supported, yet") dims = self.ds.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # convert the indexes to the indexer's units @@ -165,11 +164,13 @@ def __getitem__(self, indexers): raise KeyError(*e.args) from e # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } - return converted.loc[stripped_indexers] + stripped_indexers = conversion.strip_indexer_units(indexers) + + stripped = conversion.strip_units(converted) + converted_units = conversion.extract_units(converted) + indexed = stripped.loc[stripped_indexers] + + return conversion.attach_units(indexed, converted_units) class DataArrayLocIndexer: @@ -183,10 +184,9 @@ def __getitem__(self, indexers): raise NotImplementedError("pandas-style indexing is not supported, yet") dims = self.da.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # convert the indexes to the indexer's units @@ -196,11 +196,13 @@ def __getitem__(self, indexers): raise KeyError(*e.args) from e # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } - return converted.loc[stripped_indexers] + stripped_indexers = conversion.strip_indexer_units(indexers) + + stripped = conversion.strip_units(converted) + converted_units = conversion.extract_units(converted) + indexed = stripped.loc[stripped_indexers] + + return conversion.attach_units(indexed, converted_units) def __setitem__(self, indexers, values): if not is_dict_like(indexers): @@ -219,10 +221,7 @@ def __setitem__(self, indexers, values): raise KeyError(*e.args) from e # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in converted.items() - } + stripped_indexers = conversion.strip_indexer_units(converted) self.da.loc[stripped_indexers] = values @@ -252,12 +251,6 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs): the data into memory. To avoid that, consider converting to ``dask`` first (e.g. using ``chunk``). - .. warning:: - - As units in dimension coordinates are not supported until - ``xarray`` changes the way it implements indexes, these - units will be set as attributes. - .. note:: Also note that datetime units (i.e. ones that match ``{units} since {date}``) in unit attributes will be @@ -648,10 +641,9 @@ def reindex( indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") dims = self.da.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # TODO: handle tolerance @@ -659,20 +651,19 @@ def reindex( # convert the indexes to the indexer's units converted = conversion.convert_units(self.da, indexer_units) + converted_units = conversion.extract_units(converted) + stripped = conversion.strip_units(converted) # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } - indexed = converted.reindex( + stripped_indexers = conversion.strip_indexer_units(indexers) + indexed = stripped.reindex( stripped_indexers, method=method, tolerance=tolerance, copy=copy, fill_value=fill_value, ) - return indexed + return conversion.attach_units(indexed, converted_units) def reindex_like( self, other, method=None, tolerance=None, copy=True, fill_value=NA @@ -692,19 +683,24 @@ def reindex_like( xarray.DataArray.pint.reindex xarray.DataArray.reindex_like """ - indexer_units = conversion.extract_unit_attributes(other) + indexer_units = conversion.extract_units(other) + + converted = conversion.convert_units(self.da, indexer_units) + units = conversion.extract_units(converted) + stripped = conversion.strip_units(converted) + stripped_other = conversion.strip_units(other) # TODO: handle tolerance # TODO: handle fill_value - converted = conversion.convert_units(self.da, indexer_units) - return converted.reindex_like( - other, + reindexed = stripped.reindex_like( + stripped_other, method=method, tolerance=tolerance, copy=copy, fill_value=fill_value, ) + return conversion.attach_units(reindexed, units) def interp( self, @@ -731,10 +727,9 @@ def interp( indexers = either_dict_or_kwargs(coords, coords_kwargs, "interp") dims = self.da.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # convert the indexes to the indexer's units @@ -743,10 +738,7 @@ def interp( stripped = conversion.strip_units(converted) # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } + stripped_indexers = conversion.strip_indexer_units(indexers) interpolated = stripped.interp( stripped_indexers, method=method, @@ -770,13 +762,14 @@ def interp_like(self, other, method="linear", assume_sorted=False, kwargs=None): xarray.DataArray.pint.interp xarray.DataArray.interp_like """ - indexer_units = conversion.extract_unit_attributes(other) + indexer_units = conversion.extract_units(other) converted = conversion.convert_units(self.da, indexer_units) units = conversion.extract_units(converted) stripped = conversion.strip_units(converted) + stripped_other = conversion.strip_units(other) interpolated = stripped.interp_like( - other, + stripped_other, method=method, assume_sorted=assume_sorted, kwargs=kwargs, @@ -804,10 +797,9 @@ def sel( indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") dims = self.da.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # TODO: handle tolerance @@ -819,18 +811,18 @@ def sel( raise KeyError(*e.args) from e # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } - indexed = converted.sel( + stripped_indexers = conversion.strip_indexer_units(indexers) + + stripped = conversion.strip_units(converted) + converted_units = conversion.extract_units(converted) + indexed = stripped.sel( stripped_indexers, method=method, tolerance=tolerance, drop=drop, ) - return indexed + return conversion.attach_units(indexed, converted_units) @property def loc(self): @@ -872,10 +864,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): raise KeyError(*e.args) from e # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in converted_indexers.items() - } + stripped_indexers = conversion.strip_indexer_units(converted_indexers) indexed = self.da.drop_sel( stripped_indexers, errors=errors, @@ -984,12 +973,6 @@ def quantify(self, units=_default, unit_registry=None, **unit_kwargs): the data into memory. To avoid that, consider converting to ``dask`` first (e.g. using ``chunk``). - .. warning:: - - As units in dimension coordinates are not supported until - ``xarray`` changes the way it implements indexes, these - units will be set as attributes. - .. note:: Also note that datetime units (i.e. ones that match ``{units} since {date}``) in unit attributes will be @@ -1425,10 +1408,9 @@ def reindex( indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") dims = self.ds.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # TODO: handle tolerance @@ -1436,20 +1418,19 @@ def reindex( # convert the indexes to the indexer's units converted = conversion.convert_units(self.ds, indexer_units) + converted_units = conversion.extract_units(converted) + stripped = conversion.strip_units(converted) # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } - indexed = converted.reindex( + stripped_indexers = conversion.strip_indexer_units(indexers) + indexed = stripped.reindex( stripped_indexers, method=method, tolerance=tolerance, copy=copy, fill_value=fill_value, ) - return indexed + return conversion.attach_units(indexed, converted_units) def reindex_like( self, other, method=None, tolerance=None, copy=True, fill_value=NA @@ -1469,19 +1450,24 @@ def reindex_like( xarray.Dataset.pint.reindex xarray.Dataset.reindex_like """ - indexer_units = conversion.extract_unit_attributes(other) + indexer_units = conversion.extract_units(other) + + converted = conversion.convert_units(self.ds, indexer_units) + units = conversion.extract_units(converted) + stripped = conversion.strip_units(converted) + stripped_other = conversion.strip_units(other) # TODO: handle tolerance # TODO: handle fill_value - converted = conversion.convert_units(self.ds, indexer_units) - return converted.reindex_like( - other, + reindexed = stripped.reindex_like( + stripped_other, method=method, tolerance=tolerance, copy=copy, fill_value=fill_value, ) + return conversion.attach_units(reindexed, units) def interp( self, @@ -1508,10 +1494,9 @@ def interp( indexers = either_dict_or_kwargs(coords, coords_kwargs, "interp") dims = self.ds.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # convert the indexes to the indexer's units @@ -1520,10 +1505,7 @@ def interp( stripped = conversion.strip_units(converted) # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } + stripped_indexers = conversion.strip_indexer_units(indexers) interpolated = stripped.interp( stripped_indexers, method=method, @@ -1547,13 +1529,14 @@ def interp_like(self, other, method="linear", assume_sorted=False, kwargs=None): xarray.Dataset.pint.interp xarray.Dataset.interp_like """ - indexer_units = conversion.extract_unit_attributes(other) + indexer_units = conversion.extract_units(other) converted = conversion.convert_units(self.ds, indexer_units) units = conversion.extract_units(converted) stripped = conversion.strip_units(converted) + stripped_other = conversion.strip_units(other) interpolated = stripped.interp_like( - other, + stripped_other, method=method, assume_sorted=assume_sorted, kwargs=kwargs, @@ -1581,10 +1564,9 @@ def sel( indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") dims = self.ds.dims + indexer_units = conversion.extract_indexer_units(indexers) indexer_units = { - name: conversion.extract_indexer_units(indexer) - for name, indexer in indexers.items() - if name in dims + name: indexer for name, indexer in indexer_units.items() if name in dims } # TODO: handle tolerance @@ -1596,18 +1578,18 @@ def sel( raise KeyError(*e.args) from e # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in indexers.items() - } - indexed = converted.sel( + stripped_indexers = conversion.strip_indexer_units(indexers) + + stripped = conversion.strip_units(converted) + converted_units = conversion.extract_units(converted) + indexed = stripped.sel( stripped_indexers, method=method, tolerance=tolerance, drop=drop, ) - return indexed + return conversion.attach_units(indexed, converted_units) @property def loc(self): @@ -1651,10 +1633,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): raise KeyError(*e.args) from e # index - stripped_indexers = { - name: conversion.strip_indexer_units(indexer) - for name, indexer in converted_indexers.items() - } + stripped_indexers = conversion.strip_indexer_units(converted_indexers) indexed = self.ds.drop_sel( stripped_indexers, errors=errors, diff --git a/pint_xarray/conversion.py b/pint_xarray/conversion.py index b7ac292c..5b801fb4 100644 --- a/pint_xarray/conversion.py +++ b/pint_xarray/conversion.py @@ -2,10 +2,11 @@ import re import pint -from xarray import DataArray, Dataset, IndexVariable, Variable +from xarray import Coordinates, DataArray, Dataset, IndexVariable, Variable from .compat import call_on_dataset from .errors import format_error_message +from .index import PintIndex no_unit_values = ("none", None) unit_attribute_name = "units" @@ -121,17 +122,38 @@ def attach_units_variable(variable, units): return new_obj -def dataset_from_variables(variables, coords, attrs): - data_vars = {name: var for name, var in variables.items() if name not in coords} - coords = {name: var for name, var in variables.items() if name in coords} +def dataset_from_variables(variables, coordinate_names, indexes, attrs): + data_vars = { + name: var for name, var in variables.items() if name not in coordinate_names + } + coords = {name: var for name, var in variables.items() if name in coordinate_names} + + new_coords = Coordinates(coords, indexes=indexes) + return Dataset(data_vars=data_vars, coords=new_coords, attrs=attrs) + + +def attach_units_index(index, index_vars, units): + if all(unit is None for unit in units.values()): + # skip non-quantity indexed variables + return index + + if isinstance(index, PintIndex) and index.units != units: + raise ValueError( + f"cannot attach units to quantified index: {index.units} != {units}" + ) - return Dataset(data_vars=data_vars, coords=coords, attrs=attrs) + return PintIndex(index=index, units=units) def attach_units_dataset(obj, units): attached = {} rejected_vars = {} + + indexed_variables = obj.xindexes.variables for name, var in obj.variables.items(): + if name in indexed_variables: + continue + unit = units.get(name) try: converted = attach_units_variable(var, unit) @@ -139,10 +161,23 @@ def attach_units_dataset(obj, units): except ValueError as e: rejected_vars[name] = (unit, e) + indexes, index_vars = obj.xindexes.copy_indexes() + for idx, idx_vars in obj.xindexes.group_by_index(): + idx_units = {name: units.get(name) for name in idx_vars.keys()} + try: + attached_idx = attach_units_index(idx, idx_vars, idx_units) + indexes.update({k: attached_idx for k in idx_vars}) + index_vars.update(attached_idx.create_variables(idx_vars)) + except ValueError as e: + rejected_vars[name] = (units, e) + + attached.update(index_vars) + if rejected_vars: raise ValueError(rejected_vars) - return dataset_from_variables(attached, obj._coord_names, obj.attrs) + reordered = {name: attached[name] for name in obj.variables.keys()} + return dataset_from_variables(reordered, obj._coord_names, indexes, obj.attrs) def attach_units(obj, units): @@ -215,20 +250,64 @@ def convert_units_variable(variable, units): return new_obj +def convert_units_index(index, index_vars, units): + if not isinstance(index, PintIndex): + raise ValueError("cannot convert non-quantified index") + + converted_vars = {} + failed = {} + for name, var in index_vars.items(): + unit = units.get(name) + try: + converted = convert_units_variable(var, unit) + converted_vars[name] = strip_units_variable(converted) + except (ValueError, pint.errors.PintTypeError) as e: + failed[name] = e + + if failed: + # raise exception group + raise ValueError("failed to convert index variables:", failed) + + # TODO: figure out how to pull out `options` + converted_index = index.index.from_variables(converted_vars, options={}) + return PintIndex(index=converted_index, units=units) + + def convert_units_dataset(obj, units): converted = {} failed = {} + indexed_variables = obj.xindexes.variables for name, var in obj.variables.items(): + if name in indexed_variables: + continue + unit = units.get(name) try: converted[name] = convert_units_variable(var, unit) except (ValueError, pint.errors.PintTypeError) as e: failed[name] = e + indexes, index_vars = obj.xindexes.copy_indexes() + for idx, idx_vars in obj.xindexes.group_by_index(): + idx_units = {name: units.get(name) for name in idx_vars.keys()} + if all(unit is None for unit in idx_units.values()): + continue + + try: + converted_index = convert_units_index(idx, idx_vars, idx_units) + indexes.update({k: converted_index for k in idx_vars}) + index_vars.update(converted_index.create_variables()) + except (ValueError, pint.errors.PintTypeError) as e: + names = tuple(idx_vars) + failed[names] = e + + converted.update(index_vars) + if failed: raise ValueError(failed) - return dataset_from_variables(converted, obj._coord_names, obj.attrs) + reordered = {name: converted[name] for name in obj.variables.keys()} + return dataset_from_variables(reordered, obj._coord_names, indexes, obj.attrs) def convert_units(obj, units): @@ -308,7 +387,12 @@ def strip_units_variable(var): def strip_units_dataset(obj): variables = {name: strip_units_variable(var) for name, var in obj.variables.items()} - return dataset_from_variables(variables, obj._coord_names, obj.attrs) + indexes = { + name: (index.index if isinstance(index, PintIndex) else index) + for name, index in obj.xindexes.items() + } + + return dataset_from_variables(variables, obj._coord_names, indexes, obj.attrs) def strip_units(obj): @@ -403,25 +487,31 @@ def convert(indexer, units): return converted -def extract_indexer_units(indexer): - if isinstance(indexer, slice): - return slice_extract_units(indexer) - elif isinstance(indexer, (DataArray, Variable)): - return array_extract_units(indexer.data) - else: - return array_extract_units(indexer) +def extract_indexer_units(indexers): + def extract(indexer): + if isinstance(indexer, slice): + return slice_extract_units(indexer) + elif isinstance(indexer, (DataArray, Variable)): + return array_extract_units(indexer.data) + else: + return array_extract_units(indexer) + return {name: extract(indexer) for name, indexer in indexers.items()} -def strip_indexer_units(indexer): - if isinstance(indexer, slice): - return slice( - array_strip_units(indexer.start), - array_strip_units(indexer.stop), - array_strip_units(indexer.step), - ) - elif isinstance(indexer, DataArray): - return strip_units(indexer) - elif isinstance(indexer, Variable): - return strip_units_variable(indexer) - else: - return array_strip_units(indexer) + +def strip_indexer_units(indexers): + def strip(indexer): + if isinstance(indexer, slice): + return slice( + array_strip_units(indexer.start), + array_strip_units(indexer.stop), + array_strip_units(indexer.step), + ) + elif isinstance(indexer, DataArray): + return strip_units(indexer) + elif isinstance(indexer, Variable): + return strip_units_variable(indexer) + else: + return array_strip_units(indexer) + + return {name: strip(indexer) for name, indexer in indexers.items()} diff --git a/pint_xarray/index.py b/pint_xarray/index.py new file mode 100644 index 00000000..5d13ff1f --- /dev/null +++ b/pint_xarray/index.py @@ -0,0 +1,95 @@ +from xarray import Variable +from xarray.core.indexes import Index, PandasIndex + +from . import conversion + + +class PintIndex(Index): + def __init__(self, *, index, units): + """create a unit-aware MetaIndex + + Parameters + ---------- + index : xarray.Index + The wrapped index object. + units : mapping of hashable to unit-like + The units of the indexed coordinates + """ + self.index = index + self.units = units + + def _replace(self, new_index): + return self.__class__(index=new_index, units=self.units) + + def create_variables(self, variables=None): + index_vars = self.index.create_variables(variables) + + index_vars_units = {} + for name, var in index_vars.items(): + data = conversion.array_attach_units(var.data, self.units[name]) + var_units = Variable(var.dims, data, attrs=var.attrs, encoding=var.encoding) + index_vars_units[name] = var_units + + return index_vars_units + + @classmethod + def from_variables(cls, variables, options): + if len(variables) != 1: + raise ValueError("can only create a default index from single variables") + + units = options.pop("units", None) + index = PandasIndex.from_variables(variables, options=options) + return cls(index=index, units={index.index.name: units}) + + @classmethod + def concat(cls, indexes, dim, positions): + raise NotImplementedError() + + @classmethod + def stack(cls, variables, dim): + raise NotImplementedError() + + def unstack(self): + raise NotImplementedError() + + def sel(self, labels): + converted_labels = conversion.convert_indexer_units(labels, self.units) + stripped_labels = conversion.strip_indexer_units(converted_labels) + + return self.index.sel(stripped_labels) + + def isel(self, indexers): + subset = self.index.isel(indexers) + if subset is None: + return None + + return self._replace(subset) + + def join(self, other, how="inner"): + raise NotImplementedError() + + def reindex_like(self, other): + raise NotImplementedError() + + def equals(self, other): + if not isinstance(other, PintIndex): + return False + + # for now we require exactly matching units to avoid the potentially expensive conversion + if self.units != other.units: + return False + + # last to avoid the potentially expensive comparison + return self.index.equals(other.index) + + def roll(self, shifts): + return self._replace(self.index.roll(shifts)) + + def rename(self, name_dict, dims_dict): + return self._replace(self.index.rename(name_dict, dims_dict)) + + def __getitem__(self, indexer): + return self._replace(self.index[indexer]) + + def _repr_inline_(self, max_width): + return f"{self.__class__.__name__}({self.index.__class__.__name__})" diff --git a/pint_xarray/tests/test_accessors.py b/pint_xarray/tests/test_accessors.py index 46d4ed61..b33c7866 100644 --- a/pint_xarray/tests/test_accessors.py +++ b/pint_xarray/tests/test_accessors.py @@ -7,6 +7,7 @@ from pint import Unit, UnitRegistry from .. import accessors, conversion +from ..index import PintIndex from .utils import ( assert_equal, assert_identical, @@ -22,7 +23,8 @@ # make sure scalars are converted to 0d arrays so quantities can # always be treated like ndarrays -unit_registry = UnitRegistry(force_ndarray=True) +from pint_xarray import unit_registry + Quantity = unit_registry.Quantity nan = np.nan @@ -159,7 +161,17 @@ def test_dimension_coordinate_array_already_quantified(self): arr.pint.quantify({"x": "s"}) def test_dimension_coordinate_array_already_quantified_same_units(self): - ds = xr.Dataset(coords={"x": ("x", [10], {"units": unit_registry.Unit("m")})}) + x = unit_registry.Quantity([10], "m") + coords = xr.Coordinates( + {"x": x}, + indexes={ + "x": PintIndex.from_variables( + {"x": xr.Variable("x", x.magnitude)}, + options={"units": x.units}, + ), + }, + ) + ds = xr.Dataset(coords=coords) arr = ds.x quantified = arr.pint.quantify({"x": "m"}) @@ -482,18 +494,20 @@ def test_roundtrip_data(self, example_unitless_ds): id="Dataset-incompatible units-data", ), pytest.param( - xr.Dataset(coords={"x": ("x", [2, 4], {"units": Unit("s")})}), + xr.Dataset(coords=xr.Coordinates({"x": Quantity([2, 4], "s")}, indexes={})), {"x": "ms"}, - xr.Dataset(coords={"x": ("x", [2000, 4000], {"units": Unit("ms")})}), + xr.Dataset( + coords=xr.Coordinates({"x": Quantity([2000, 4000], "ms")}, indexes={}) + ), None, - id="Dataset-compatible units-dims", + id="Dataset-compatible units-dims-no index", ), pytest.param( - xr.Dataset(coords={"x": ("x", [2, 4], {"units": Unit("s")})}), + xr.Dataset(coords=xr.Coordinates({"x": Quantity([2, 4], "s")}, indexes={})), {"x": "mm"}, None, ValueError, - id="Dataset-incompatible units-dims", + id="Dataset-incompatible units-dims-no index", ), pytest.param( xr.DataArray(Quantity([0, 1], "m"), dims="x"), @@ -525,25 +539,29 @@ def test_roundtrip_data(self, example_unitless_ds): ), pytest.param( xr.DataArray( - [0, 1], dims="x", coords={"x": ("x", [2, 4], {"units": Unit("s")})} + [0, 1], + dims="x", + coords=xr.Coordinates({"x": Quantity([2, 4], "s")}, indexes={}), ), {"x": "ms"}, xr.DataArray( [0, 1], dims="x", - coords={"x": ("x", [2000, 4000], {"units": Unit("ms")})}, + coords=xr.Coordinates({"x": Quantity([2000, 4000], "ms")}, indexes={}), ), None, - id="DataArray-compatible units-dims", + id="DataArray-compatible units-dims-no index", ), pytest.param( xr.DataArray( - [0, 1], dims="x", coords={"x": ("x", [2, 4], {"units": Unit("s")})} + [0, 1], + dims="x", + coords=xr.Coordinates({"x": Quantity([2, 4], "s")}, indexes={}), ), {"x": "mm"}, None, ValueError, - id="DataArray-incompatible units-dims", + id="DataArray-incompatible units-dims-no index", ), ), ) @@ -666,13 +684,17 @@ def test_to(obj, units, expected, error): ), ) def test_sel(obj, indexers, expected, error): + obj_ = obj.pint.quantify() + if error is not None: with pytest.raises(error): - obj.pint.sel(indexers) + obj_.pint.sel(indexers) else: - actual = obj.pint.sel(indexers) - assert_units_equal(actual, expected) - assert_identical(actual, expected) + expected_ = expected.pint.quantify() + + actual = obj_.pint.sel(indexers) + assert_units_equal(actual, expected_) + assert_identical(actual, expected_) @pytest.mark.parametrize( @@ -783,13 +805,17 @@ def test_sel(obj, indexers, expected, error): ), ) def test_loc(obj, indexers, expected, error): + obj_ = obj.pint.quantify() + if error is not None: with pytest.raises(error): - obj.pint.loc[indexers] + obj_.pint.loc[indexers] else: - actual = obj.pint.loc[indexers] - assert_units_equal(actual, expected) - assert_identical(actual, expected) + expected_ = expected.pint.quantify() + + actual = obj_.pint.loc[indexers] + assert_units_equal(actual, expected_) + assert_identical(actual, expected_) @pytest.mark.parametrize( @@ -1136,72 +1162,73 @@ def test_chunk(obj): @pytest.mark.parametrize( - ["obj", "indexers", "expected", "error"], + ["obj", "units", "indexers", "expected", "expected_units", "error"], ( pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, {"x": Quantity([10, 30, 50], "dm"), "y": Quantity([0, 120, 240], "s")}, - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 120, 240], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 120, 240])}), + {"x": "dm", "y": "s"}, None, id="Dataset-identical units", ), pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, {"x": Quantity([0, 1, 3, 5], "m"), "y": Quantity([0, 2, 4], "min")}, - xr.Dataset( - { - "x": ("x", [0, 1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2, 4], {"units": unit_registry.Unit("min")}), - } - ), + xr.Dataset({"x": ("x", [0, 1, 3, 5]), "y": ("y", [0, 2, 4])}), + {"x": "m", "y": "min"}, None, id="Dataset-compatible units", ), + pytest.param( + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, + {"x": Quantity([1, 3], "s"), "y": Quantity([1], "m")}, + None, + {}, + ValueError, + id="Dataset-incompatible units", + ), pytest.param( xr.Dataset( { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), + "a": (("x", "y"), np.array([[0, 1], [2, 3], [4, 5]])), + "x": [10, 20, 30], + "y": [60, 120], } ), - {"x": Quantity([1, 3], "s"), "y": Quantity([1], "m")}, + {"a": "kg"}, + { + "x": [15, 25], + "y": [75, 105], + }, + xr.Dataset( + { + "a": (("x", "y"), np.array([[np.nan, np.nan], [np.nan, np.nan]])), + "x": [15, 25], + "y": [75, 105], + } + ), + {"a": "kg"}, None, - ValueError, - id="Dataset-incompatible units", + id="Dataset-data units", ), pytest.param( xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, {"x": Quantity([10, 30, 50], "dm"), "y": Quantity([0, 240], "s")}, xr.DataArray( [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 240], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 30, 50]), "y": ("y", [0, 240])}, ), + {"x": "dm", "y": "s"}, None, id="DataArray-identical units", ), @@ -1209,20 +1236,16 @@ def test_chunk(obj): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, {"x": Quantity([1, 3, 5], "m"), "y": Quantity([0, 2], "min")}, xr.DataArray( [[np.nan, 1], [np.nan, 5], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2], {"units": unit_registry.Unit("min")}), - }, + coords={"x": ("x", [1, 3, 5]), "y": ("y", [0, 2])}, ), + {"x": "m", "y": "min"}, None, id="DataArray-compatible units", ), @@ -1230,89 +1253,78 @@ def test_chunk(obj): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, {"x": Quantity([10, 30], "s"), "y": Quantity([60], "m")}, None, + {}, ValueError, id="DataArray-incompatible units", ), + pytest.param( + xr.DataArray( + np.array([[0, 1], [2, 3], [4, 5]]), + dims=("x", "y"), + coords={"x": [10, 20, 30], "y": [60, 120]}, + ), + {None: "kg"}, + {"x": [15, 25], "y": [75, 105]}, + xr.DataArray( + [[np.nan, np.nan], [np.nan, np.nan]], + dims=("x", "y"), + coords={"x": [15, 25], "y": [75, 105]}, + ), + {None: "kg"}, + None, + id="DataArray-data units", + ), ), ) -def test_reindex(obj, indexers, expected, error): +def test_reindex(obj, units, indexers, expected, expected_units, error): + obj_ = obj.pint.quantify(units) + if error is not None: with pytest.raises(error): obj.pint.reindex(indexers) else: - actual = obj.pint.reindex(indexers) - assert_units_equal(actual, expected) - assert_identical(actual, expected) + expected_ = expected.pint.quantify(expected_units) + + actual = obj_.pint.reindex(indexers) + assert_units_equal(actual, expected_) + assert_identical(actual, expected_) @pytest.mark.parametrize( - ["obj", "other", "expected", "error"], + ["obj", "units", "other", "other_units", "expected", "expected_units", "error"], ( pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 120, 240], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 120, 240], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 120, 240])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 120, 240])}), + {"x": "dm", "y": "s"}, None, id="Dataset-identical units", ), pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [0, 1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2, 4], {"units": unit_registry.Unit("min")}), - } - ), - xr.Dataset( - { - "x": ("x", [0, 1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2, 4], {"units": unit_registry.Unit("min")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [0, 1, 3, 5]), "y": ("y", [0, 2, 4])}), + {"x": "m", "y": "min"}, + xr.Dataset({"x": ("x", [0, 1, 3, 5]), "y": ("y", [0, 2, 4])}), + {"x": "m", "y": "min"}, None, id="Dataset-compatible units", ), pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [1, 3], {"units": unit_registry.Unit("s")}), - "y": ("y", [1], {"units": unit_registry.Unit("m")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [1, 3]), "y": ("y", [1])}), + {"x": "s", "y": "m"}, None, + {}, ValueError, id="Dataset-incompatible units", ), @@ -1320,51 +1332,57 @@ def test_reindex(obj, indexers, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, - ), - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 240], {"units": unit_registry.Unit("s")}), - } + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 240])}), + {"x": "dm", "y": "s"}, xr.DataArray( [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 240], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 30, 50]), "y": ("y", [0, 240])}, ), + {"x": "dm", "y": "s"}, None, id="DataArray-identical units", ), pytest.param( - xr.DataArray( - [[0, 1], [2, 3], [4, 5]], - dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, + xr.Dataset( + { + "a": (("x", "y"), [[0, 1], [2, 3], [4, 5]]), + "x": [10, 20, 30], + "y": [60, 120], + } ), + {"a": "kg"}, + xr.Dataset({"x": [15, 25], "y": [75, 105]}), + {}, xr.Dataset( { - "x": ("x", [1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2], {"units": unit_registry.Unit("min")}), + "a": (("x", "y"), [[np.nan, np.nan], [np.nan, np.nan]]), + "x": [15, 25], + "y": [75, 105], } ), + {"a": "kg"}, + None, + id="Dataset-data units", + ), + pytest.param( + xr.DataArray( + [[0, 1], [2, 3], [4, 5]], + dims=("x", "y"), + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, + ), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [1, 3, 5]), "y": ("y", [0, 2])}), + {"x": "m", "y": "min"}, xr.DataArray( [[np.nan, 1], [np.nan, 5], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2], {"units": unit_registry.Unit("min")}), - }, + coords={"x": ("x", [1, 3, 5]), "y": ("y", [0, 2])}, ), + {"x": "m", "y": "min"}, None, id="DataArray-compatible units", ), @@ -1372,102 +1390,103 @@ def test_reindex(obj, indexers, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, - ), - xr.Dataset( - { - "x": ("x", [10, 30], {"units": unit_registry.Unit("s")}), - "y": ("y", [60], {"units": unit_registry.Unit("m")}), - } + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30]), "y": ("y", [60])}), + {"x": "s", "y": "m"}, None, + {}, ValueError, id="DataArray-incompatible units", ), + pytest.param( + xr.DataArray( + [[0, 1], [2, 3], [4, 5]], + dims=("x", "y"), + coords={"x": [10, 20, 30], "y": [60, 120]}, + ), + {"a": "kg"}, + xr.Dataset({"x": [15, 25], "y": [75, 105]}), + {}, + xr.DataArray( + [[np.nan, np.nan], [np.nan, np.nan]], + dims=("x", "y"), + coords={"x": [15, 25], "y": [75, 105]}, + ), + {"a": "kg"}, + None, + id="DataArray-data units", + ), ), ) -def test_reindex_like(obj, other, expected, error): +def test_reindex_like(obj, units, other, other_units, expected, expected_units, error): + obj_ = obj.pint.quantify(units) + other_ = other.pint.quantify(other_units) + if error is not None: with pytest.raises(error): - obj.pint.reindex_like(other) + obj_.pint.reindex_like(other_) else: - actual = obj.pint.reindex_like(other) - assert_units_equal(actual, expected) - assert_identical(actual, expected) + expected_ = expected.pint.quantify(expected_units) + + actual = obj_.pint.reindex_like(other_) + assert_units_equal(actual, expected_) + assert_identical(actual, expected_) @requires_scipy @pytest.mark.parametrize( - ["obj", "indexers", "expected", "error"], + ["obj", "units", "indexers", "expected", "expected_units", "error"], ( pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, {"x": Quantity([10, 30, 50], "dm"), "y": Quantity([0, 120, 240], "s")}, - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 120, 240], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 120, 240])}), + {"x": "dm", "y": "s"}, None, id="Dataset-identical units", ), pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, {"x": Quantity([0, 1, 3, 5], "m"), "y": Quantity([0, 2, 4], "min")}, - xr.Dataset( - { - "x": ("x", [0, 1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2, 4], {"units": unit_registry.Unit("min")}), - } - ), + xr.Dataset({"x": ("x", [0, 1, 3, 5]), "y": ("y", [0, 2, 4])}), + {"x": "m", "y": "min"}, None, id="Dataset-compatible units", ), pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, {"x": Quantity([1, 3], "s"), "y": Quantity([1], "m")}, None, + {}, ValueError, id="Dataset-incompatible units", ), pytest.param( xr.Dataset( { - "a": (("x", "y"), Quantity([[0, 1], [2, 3], [4, 5]], "kg")), + "a": (("x", "y"), np.array([[0, 1], [2, 3], [4, 5]])), "x": [10, 20, 30], "y": [60, 120], } ), + {"a": "kg"}, { "x": [15, 25], "y": [75, 105], }, xr.Dataset( { - "a": (("x", "y"), Quantity([[1.25, 1.75], [3.25, 3.75]], "kg")), + "a": (("x", "y"), np.array([[1.25, 1.75], [3.25, 3.75]])), "x": [15, 25], "y": [75, 105], } ), + {"a": "kg"}, None, id="Dataset-data units", ), @@ -1475,20 +1494,16 @@ def test_reindex_like(obj, other, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, {"x": Quantity([10, 30, 50], "dm"), "y": Quantity([0, 240], "s")}, xr.DataArray( [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 240], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 30, 50]), "y": ("y", [0, 240])}, ), + {"x": "dm", "y": "s"}, None, id="DataArray-identical units", ), @@ -1496,20 +1511,16 @@ def test_reindex_like(obj, other, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, {"x": Quantity([1, 3, 5], "m"), "y": Quantity([0, 2], "min")}, xr.DataArray( [[np.nan, 1], [np.nan, 5], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2], {"units": unit_registry.Unit("min")}), - }, + coords={"x": ("x", [1, 3, 5]), "y": ("y", [0, 2])}, ), + {"x": "m", "y": "min"}, None, id="DataArray-compatible units", ), @@ -1517,114 +1528,79 @@ def test_reindex_like(obj, other, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, {"x": Quantity([10, 30], "s"), "y": Quantity([60], "m")}, None, + {}, ValueError, id="DataArray-incompatible units", ), pytest.param( xr.DataArray( - Quantity([[0, 1], [2, 3], [4, 5]], "kg"), + np.array([[0, 1], [2, 3], [4, 5]]), dims=("x", "y"), - coords={ - "x": [10, 20, 30], - "y": [60, 120], - }, + coords={"x": [10, 20, 30], "y": [60, 120]}, ), - { - "x": [15, 25], - "y": [75, 105], - }, + {None: "kg"}, + {"x": [15, 25], "y": [75, 105]}, xr.DataArray( - Quantity([[1.25, 1.75], [3.25, 3.75]], "kg"), + [[1.25, 1.75], [3.25, 3.75]], dims=("x", "y"), - coords={ - "x": [15, 25], - "y": [75, 105], - }, + coords={"x": [15, 25], "y": [75, 105]}, ), + {None: "kg"}, None, id="DataArray-data units", ), ), ) -def test_interp(obj, indexers, expected, error): +def test_interp(obj, units, indexers, expected, expected_units, error): + obj_ = obj.pint.quantify(units) + if error is not None: with pytest.raises(error): obj.pint.interp(indexers) else: - actual = obj.pint.interp(indexers) - assert_units_equal(actual, expected) - assert_identical(actual, expected) + expected_ = expected.pint.quantify(expected_units) + + actual = obj_.pint.interp(indexers) + assert_units_equal(actual, expected_) + assert_identical(actual, expected_) @requires_scipy @pytest.mark.parametrize( - ["obj", "other", "expected", "error"], + ["obj", "units", "other", "other_units", "expected", "expected_units", "error"], ( pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 120, 240], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 120, 240], {"units": unit_registry.Unit("s")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 120, 240])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 120, 240])}), + {"x": "dm", "y": "s"}, None, id="Dataset-identical units", ), pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [0, 1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2, 4], {"units": unit_registry.Unit("min")}), - } - ), - xr.Dataset( - { - "x": ("x", [0, 1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2, 4], {"units": unit_registry.Unit("min")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [0, 1, 3, 5]), "y": ("y", [0, 2, 4])}), + {"x": "m", "y": "min"}, + xr.Dataset({"x": ("x", [0, 1, 3, 5]), "y": ("y", [0, 2, 4])}), + {"x": "m", "y": "min"}, None, id="Dataset-compatible units", ), pytest.param( - xr.Dataset( - { - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - } - ), - xr.Dataset( - { - "x": ("x", [1, 3], {"units": unit_registry.Unit("s")}), - "y": ("y", [1], {"units": unit_registry.Unit("m")}), - } - ), + xr.Dataset({"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [1, 3]), "y": ("y", [1])}), + {"x": "s", "y": "m"}, None, + {}, ValueError, id="Dataset-incompatible units", ), @@ -1632,49 +1608,39 @@ def test_interp(obj, indexers, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, - ), - xr.Dataset( - { - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 240], {"units": unit_registry.Unit("s")}), - } + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30, 50]), "y": ("y", [0, 240])}), + {"x": "dm", "y": "s"}, xr.DataArray( [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [10, 30, 50], {"units": unit_registry.Unit("dm")}), - "y": ("y", [0, 240], {"units": unit_registry.Unit("s")}), - }, + coords={"x": ("x", [10, 30, 50]), "y": ("y", [0, 240])}, ), + {"x": "dm", "y": "s"}, None, id="DataArray-identical units", ), pytest.param( xr.Dataset( { - "a": (("x", "y"), Quantity([[0, 1], [2, 3], [4, 5]], "kg")), + "a": (("x", "y"), [[0, 1], [2, 3], [4, 5]]), "x": [10, 20, 30], "y": [60, 120], } ), + {"a": "kg"}, + xr.Dataset({"x": [15, 25], "y": [75, 105]}), + {}, xr.Dataset( { + "a": (("x", "y"), [[1.25, 1.75], [3.25, 3.75]]), "x": [15, 25], "y": [75, 105], } ), - xr.Dataset( - { - "a": (("x", "y"), Quantity([[1.25, 1.75], [3.25, 3.75]], "kg")), - "x": [15, 25], - "y": [75, 105], - } - ), + {"a": "kg"}, None, id="Dataset-data units", ), @@ -1682,25 +1648,17 @@ def test_interp(obj, indexers, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, - ), - xr.Dataset( - { - "x": ("x", [1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2], {"units": unit_registry.Unit("min")}), - } + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [1, 3, 5]), "y": ("y", [0, 2])}), + {"x": "m", "y": "min"}, xr.DataArray( [[np.nan, 1], [np.nan, 5], [np.nan, np.nan]], dims=("x", "y"), - coords={ - "x": ("x", [1, 3, 5], {"units": unit_registry.Unit("m")}), - "y": ("y", [0, 2], {"units": unit_registry.Unit("min")}), - }, + coords={"x": ("x", [1, 3, 5]), "y": ("y", [0, 2])}, ), + {"x": "m", "y": "min"}, None, id="DataArray-compatible units", ), @@ -1708,57 +1666,49 @@ def test_interp(obj, indexers, expected, error): xr.DataArray( [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": ("x", [10, 20, 30], {"units": unit_registry.Unit("dm")}), - "y": ("y", [60, 120], {"units": unit_registry.Unit("s")}), - }, - ), - xr.Dataset( - { - "x": ("x", [10, 30], {"units": unit_registry.Unit("s")}), - "y": ("y", [60], {"units": unit_registry.Unit("m")}), - } + coords={"x": ("x", [10, 20, 30]), "y": ("y", [60, 120])}, ), + {"x": "dm", "y": "s"}, + xr.Dataset({"x": ("x", [10, 30]), "y": ("y", [60])}), + {"x": "s", "y": "m"}, None, + {}, ValueError, id="DataArray-incompatible units", ), pytest.param( xr.DataArray( - Quantity([[0, 1], [2, 3], [4, 5]], "kg"), + [[0, 1], [2, 3], [4, 5]], dims=("x", "y"), - coords={ - "x": [10, 20, 30], - "y": [60, 120], - }, - ), - xr.Dataset( - { - "x": [15, 25], - "y": [75, 105], - } + coords={"x": [10, 20, 30], "y": [60, 120]}, ), + {"a": "kg"}, + xr.Dataset({"x": [15, 25], "y": [75, 105]}), + {}, xr.DataArray( - Quantity([[1.25, 1.75], [3.25, 3.75]], "kg"), + [[1.25, 1.75], [3.25, 3.75]], dims=("x", "y"), - coords={ - "x": [15, 25], - "y": [75, 105], - }, + coords={"x": [15, 25], "y": [75, 105]}, ), + {"a": "kg"}, None, id="DataArray-data units", ), ), ) -def test_interp_like(obj, other, expected, error): +def test_interp_like(obj, units, other, other_units, expected, expected_units, error): + obj_ = obj.pint.quantify(units) + other_ = other.pint.quantify(other_units) + if error is not None: with pytest.raises(error): - obj.pint.interp_like(other) + obj_.pint.interp_like(other_) else: - actual = obj.pint.interp_like(other) - assert_units_equal(actual, expected) - assert_identical(actual, expected) + expected_ = expected.pint.quantify(expected_units) + + actual = obj_.pint.interp_like(other_) + assert_units_equal(actual, expected_) + assert_identical(actual, expected_) @requires_bottleneck diff --git a/pint_xarray/tests/test_conversion.py b/pint_xarray/tests/test_conversion.py index c4dd4585..5101eac1 100644 --- a/pint_xarray/tests/test_conversion.py +++ b/pint_xarray/tests/test_conversion.py @@ -1,16 +1,19 @@ import numpy as np +import pandas as pd import pint import pytest -from xarray import DataArray, Dataset, Variable +from xarray import Coordinates, DataArray, Dataset, Variable +from xarray.core.indexes import PandasIndex from pint_xarray import conversion +from pint_xarray.index import PintIndex from .utils import ( assert_array_equal, assert_array_units_equal, assert_identical, - assert_indexer_equal, assert_indexer_units_equal, + assert_indexers_equal, ) unit_registry = pint.UnitRegistry() @@ -245,17 +248,22 @@ def test_attach_units(self, type, units): q_a = to_quantity(a, units.get("a")) q_b = to_quantity(b, units.get("b")) + q_x = to_quantity(x, units.get("x")) q_u = to_quantity(u, units.get("u")) - units_x = units.get("x") + index = PandasIndex(x, dim="x") + if units.get("x") is not None: + index = PintIndex(index=index, units=units.get("x")) obj = Dataset({"a": ("x", a), "b": ("x", b)}, coords={"u": ("x", u), "x": x}) + coords = Coordinates( + coords={"u": Variable("x", q_u), "x": Variable("x", q_x)}, + indexes={"x": index}, + ) expected = Dataset( {"a": ("x", q_a), "b": ("x", q_b)}, - coords={"u": ("x", q_u), "x": x}, + coords=coords, ) - if units_x is not None: - expected.x.attrs["units"] = units_x if type == "DataArray": obj = obj["a"] @@ -264,6 +272,12 @@ def test_attach_units(self, type, units): actual = conversion.attach_units(obj, units) assert_identical(actual, expected) + if units.get("x") is None: + assert not isinstance(actual.xindexes["x"], PintIndex) + else: + assert isinstance(actual.xindexes["x"], PintIndex) + assert actual.xindexes["x"].units == {"x": units.get("x")} + @pytest.mark.parametrize("type", ("DataArray", "Dataset")) def test_attach_unit_attributes(self, type): units = {"a": "K", "b": "hPa", "u": "m", "x": "s"} @@ -372,15 +386,19 @@ def test_convert_units(self, type, variant, units, error, match): q_u = to_quantity(u, original_units.get("u")) q_x = to_quantity(x, original_units.get("x")) + x_index = PandasIndex(pd.Index(x), "x") + if original_units.get("x") is not None: + x_index = PintIndex(index=x_index, units={"x": original_units.get("x")}) + obj = Dataset( { "a": ("x", q_a), "b": ("x", q_b), }, - coords={ - "u": ("x", q_u), - "x": ("x", x, {"units": original_units.get("x")}), - }, + coords=Coordinates( + {"u": ("x", q_u), "x": ("x", q_x)}, + indexes={"x": x_index}, + ), ) if type == "DataArray": obj = obj["a"] @@ -394,20 +412,22 @@ def test_convert_units(self, type, variant, units, error, match): expected_a = convert_quantity(q_a, units.get("a", original_units.get("a"))) expected_b = convert_quantity(q_b, units.get("b", original_units.get("b"))) expected_u = convert_quantity(q_u, units.get("u", original_units.get("u"))) - expected_x = strip_quantity(convert_quantity(q_x, units.get("x"))) + expected_x = convert_quantity(q_x, units.get("x")) + expected_index = PandasIndex(pd.Index(strip_quantity(expected_x)), "x") + if units.get("x") is not None: + expected_index = PintIndex( + index=expected_index, units={"x": units.get("x")} + ) + expected = Dataset( { "a": ("x", expected_a), "b": ("x", expected_b), }, - coords={ - "u": ("x", expected_u), - "x": ( - "x", - expected_x, - {"units": units.get("x", original_units.get("x"))}, - ), - }, + coords=Coordinates( + {"u": ("x", expected_u), "x": ("x", expected_x)}, + indexes={"x": expected_index}, + ), ) if type == "DataArray": @@ -416,7 +436,7 @@ def test_convert_units(self, type, variant, units, error, match): actual = conversion.convert_units(obj, units) assert conversion.extract_units(actual) == conversion.extract_units(expected) - assert_identical(expected, actual) + assert_identical(actual, expected) @pytest.mark.parametrize( "units", @@ -436,15 +456,22 @@ def test_extract_units(self, type, units): u = np.linspace(0, 100, 2) x = np.arange(2) + index = PandasIndex(x, "x") + if units.get("x") is not None: + index = PintIndex(index=index, units={"x": units.get("x")}) + obj = Dataset( { "a": ("x", to_quantity(a, units.get("a"))), "b": ("x", to_quantity(b, units.get("b"))), }, - coords={ - "u": ("x", to_quantity(u, units.get("u"))), - "x": ("x", x, {"units": units.get("x")}), - }, + coords=Coordinates( + { + "u": ("x", to_quantity(u, units.get("u"))), + "x": ("x", to_quantity(x, units.get("x"))), + }, + indexes={"x": index}, + ), ) if type == "DataArray": obj = obj["a"] @@ -499,21 +526,33 @@ def test_extract_unit_attributes(self, obj, expected): pytest.param( DataArray( dims="x", - data=[0, 4, 3] * unit_registry.m, - coords={"u": ("x", [2, 3, 4] * unit_registry.s)}, + data=Quantity([0, 4, 3], "kg"), + coords=Coordinates( + { + "u": ("x", Quantity([2, 3, 4], "s")), + "x": Quantity([0, 1, 2], "m"), + }, + indexes={}, + ), ), - {None: None, "u": None}, + {None: None, "u": None, "x": None}, id="DataArray", ), pytest.param( Dataset( data_vars={ - "a": ("x", [3, 2, 5] * unit_registry.Pa), - "b": ("x", [0, 2, -1] * unit_registry.kg), + "a": ("x", Quantity([3, 2, 5], "Pa")), + "b": ("x", Quantity([0, 2, -1], "kg")), }, - coords={"u": ("x", [2, 3, 4] * unit_registry.s)}, + coords=Coordinates( + { + "u": ("x", Quantity([2, 3, 4], "s")), + "x": Quantity([0, 1, 2], "m"), + }, + indexes={}, + ), ), - {"a": None, "b": None, "u": None}, + {"a": None, "b": None, "u": None, "x": None}, id="Dataset", ), ), @@ -694,100 +733,118 @@ def test_convert_indexer_units(self, indexers, units, expected, error, match): conversion.convert_indexer_units(indexers, units) else: actual = conversion.convert_indexer_units(indexers, units) - assert_indexer_equal(actual["x"], expected["x"]) - assert_indexer_units_equal(actual["x"], expected["x"]) + assert_indexers_equal(actual, expected) + assert_indexer_units_equal(actual, expected) @pytest.mark.parametrize( - ["indexer", "expected"], + ["indexers", "expected"], ( - pytest.param(1, None, id="scalar-no units"), - pytest.param(Quantity(1, "m"), Unit("m"), id="scalar-units"), - pytest.param(np.array([1, 2]), None, id="array-no units"), - pytest.param(Quantity([1, 2], "s"), Unit("s"), id="array-units"), - pytest.param(Variable("x", [1, 2]), None, id="Variable-no units"), + pytest.param({"x": 1}, {"x": None}, id="scalar-no units"), + pytest.param({"x": Quantity(1, "m")}, {"x": Unit("m")}, id="scalar-units"), + pytest.param({"x": np.array([1, 2])}, {"x": None}, id="array-no units"), + pytest.param( + {"x": Quantity([1, 2], "s")}, {"x": Unit("s")}, id="array-units" + ), pytest.param( - Variable("x", Quantity([1, 2], "m")), Unit("m"), id="Variable-units" + {"x": Variable("x", [1, 2])}, {"x": None}, id="Variable-no units" ), - pytest.param(DataArray([1, 2], dims="x"), None, id="DataArray-no units"), pytest.param( - DataArray(Quantity([1, 2], "s"), dims="x"), - Unit("s"), + {"x": Variable("x", Quantity([1, 2], "m"))}, + {"x": Unit("m")}, + id="Variable-units", + ), + pytest.param( + {"x": DataArray([1, 2], dims="x")}, {"x": None}, id="DataArray-no units" + ), + pytest.param( + {"x": DataArray(Quantity([1, 2], "s"), dims="x")}, + {"x": Unit("s")}, id="DataArray-units", ), - pytest.param(slice(None), None, id="empty slice-no units"), - pytest.param(slice(1, None), None, id="slice-no units"), + pytest.param({"x": slice(None)}, {"x": None}, id="empty slice-no units"), + pytest.param({"x": slice(1, None)}, {"x": None}, id="slice-no units"), pytest.param( - slice(Quantity(1, "m"), Quantity(2, "m")), - Unit("m"), + {"x": slice(Quantity(1, "m"), Quantity(2, "m"))}, + {"x": Unit("m")}, id="slice-identical units", ), pytest.param( - slice(Quantity(1, "m"), Quantity(2000, "mm")), - Unit("m"), + {"x": slice(Quantity(1, "m"), Quantity(2000, "mm"))}, + {"x": Unit("m")}, id="slice-compatible units", ), pytest.param( - slice(Quantity(1, "m"), Quantity(2, "ms")), + {"x": slice(Quantity(1, "m"), Quantity(2, "ms"))}, ValueError, id="slice-incompatible units", ), pytest.param( - slice(1, Quantity(2, "ms")), + {"x": slice(1, Quantity(2, "ms"))}, ValueError, id="slice-incompatible units-mixed", ), pytest.param( - slice(1, Quantity(2, "rad")), - Unit("rad"), + {"x": slice(1, Quantity(2, "rad"))}, + {"x": Unit("rad")}, id="slice-incompatible units-mixed-dimensionless", ), ), ) - def test_extract_indexer_units(self, indexer, expected): - if expected is not None and not isinstance(expected, Unit): + def test_extract_indexer_units(self, indexers, expected): + if isinstance(expected, type) and issubclass(expected, Exception): with pytest.raises(expected): - conversion.extract_indexer_units(indexer) + conversion.extract_indexer_units(indexers) else: - actual = conversion.extract_indexer_units(indexer) + actual = conversion.extract_indexer_units(indexers) assert actual == expected @pytest.mark.parametrize( - ["indexer", "expected"], + ["indexers", "expected"], ( - pytest.param(1, 1, id="scalar-no units"), - pytest.param(Quantity(1, "m"), 1, id="scalar-units"), - pytest.param(np.array([1, 2]), np.array([1, 2]), id="array-no units"), - pytest.param(Quantity([1, 2], "s"), np.array([1, 2]), id="array-units"), + pytest.param({"x": 1}, {"x": 1}, id="scalar-no units"), + pytest.param({"x": Quantity(1, "m")}, {"x": 1}, id="scalar-units"), pytest.param( - Variable("x", [1, 2]), Variable("x", [1, 2]), id="Variable-no units" + {"x": np.array([1, 2])}, + {"x": np.array([1, 2])}, + id="array-no units", + ), + pytest.param( + {"x": Quantity([1, 2], "s")}, {"x": np.array([1, 2])}, id="array-units" + ), + pytest.param( + {"x": Variable("x", [1, 2])}, + {"x": Variable("x", [1, 2])}, + id="Variable-no units", ), pytest.param( - Variable("x", Quantity([1, 2], "m")), - Variable("x", [1, 2]), + {"x": Variable("x", Quantity([1, 2], "m"))}, + {"x": Variable("x", [1, 2])}, id="Variable-units", ), pytest.param( - DataArray([1, 2], dims="x"), - DataArray([1, 2], dims="x"), + {"x": DataArray([1, 2], dims="x")}, + {"x": DataArray([1, 2], dims="x")}, id="DataArray-no units", ), pytest.param( - DataArray(Quantity([1, 2], "s"), dims="x"), - DataArray([1, 2], dims="x"), + {"x": DataArray(Quantity([1, 2], "s"), dims="x")}, + {"x": DataArray([1, 2], dims="x")}, id="DataArray-units", ), - pytest.param(slice(None), slice(None), id="empty slice-no units"), - pytest.param(slice(1, None), slice(1, None), id="slice-no units"), pytest.param( - slice(Quantity(1, "m"), Quantity(2, "m")), - slice(1, 2), + {"x": slice(None)}, {"x": slice(None)}, id="empty slice-no units" + ), + pytest.param( + {"x": slice(1, None)}, {"x": slice(1, None)}, id="slice-no units" + ), + pytest.param( + {"x": slice(Quantity(1, "m"), Quantity(2, "m"))}, + {"x": slice(1, 2)}, id="slice-units", ), ), ) - def test_strip_indexer_units(self, indexer, expected): - actual = conversion.strip_indexer_units(indexer) - if isinstance(indexer, DataArray): - assert_identical(actual, expected) - else: - assert_array_equal(actual, expected) + def test_strip_indexer_units(self, indexers, expected): + actual = conversion.strip_indexer_units(indexers) + + assert_indexers_equal(actual, expected) diff --git a/pint_xarray/tests/test_index.py b/pint_xarray/tests/test_index.py new file mode 100644 index 00000000..13f49064 --- /dev/null +++ b/pint_xarray/tests/test_index.py @@ -0,0 +1,227 @@ +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from xarray.core.indexes import IndexSelResult, PandasIndex + +from pint_xarray import unit_registry as ureg +from pint_xarray.index import PintIndex + + +def indexer_equal(first, second): + if type(first) is not type(second): + return False + + if isinstance(first, np.ndarray): + return np.all(first == second) + else: + return first == second + + +@pytest.mark.parametrize( + "base_index", + [ + PandasIndex(pd.Index([1, 2, 3]), dim="x"), + PandasIndex(pd.Index([0.1, 0.2, 0.3]), dim="x"), + PandasIndex(pd.Index([1j, 2j, 3j]), dim="y"), + ], +) +@pytest.mark.parametrize("units", [ureg.Unit("m"), ureg.Unit("s")]) +def test_init(base_index, units): + index = PintIndex(index=base_index, units=units) + + assert index.index.equals(base_index) + assert index.units == units + + +def test_replace(): + old_index = PandasIndex([1, 2, 3], dim="y") + new_index = PandasIndex([0.1, 0.2, 0.3], dim="x") + + old = PintIndex(index=old_index, units=ureg.Unit("m")) + new = old._replace(new_index) + + assert new.index.equals(new_index) + assert new.units == old.units + # no mutation + assert old.index.equals(old_index) + + +@pytest.mark.parametrize( + ["wrapped_index", "units", "expected"], + ( + pytest.param( + PandasIndex(pd.Index([1, 2, 3]), dim="x"), + {"x": ureg.Unit("m")}, + {"x": xr.Variable("x", ureg.Quantity([1, 2, 3], "m"))}, + ), + pytest.param( + PandasIndex(pd.Index([1j, 2j, 3j]), dim="y"), + {"y": ureg.Unit("ms")}, + {"y": xr.Variable("y", ureg.Quantity([1j, 2j, 3j], "ms"))}, + ), + ), +) +def test_create_variables(wrapped_index, units, expected): + index = PintIndex(index=wrapped_index, units=units) + + actual = index.create_variables() + + assert list(actual.keys()) == list(expected.keys()) + assert all([actual[k].equals(expected[k]) for k in expected.keys()]) + + +@pytest.mark.parametrize( + ["labels", "expected"], + ( + ({"x": ureg.Quantity(1, "m")}, IndexSelResult(dim_indexers={"x": 0})), + ({"x": ureg.Quantity(3000, "mm")}, IndexSelResult(dim_indexers={"x": 2})), + ({"x": ureg.Quantity(0.002, "km")}, IndexSelResult(dim_indexers={"x": 1})), + ( + {"x": ureg.Quantity([0.002, 0.004], "km")}, + IndexSelResult(dim_indexers={"x": np.array([1, 3])}), + ), + ( + {"x": slice(ureg.Quantity(2, "m"), ureg.Quantity(3000, "mm"))}, + IndexSelResult(dim_indexers={"x": slice(1, 3)}), + ), + ), +) +def test_sel(labels, expected): + index = PintIndex( + index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), units={"x": ureg.Unit("m")} + ) + + actual = index.sel(labels) + + assert isinstance(actual, IndexSelResult) + assert list(actual.dim_indexers.keys()) == list(expected.dim_indexers.keys()) + assert all( + [ + indexer_equal(actual.dim_indexers[k], expected.dim_indexers[k]) + for k in expected.dim_indexers.keys() + ] + ) + + +@pytest.mark.parametrize( + "indexers", + ({"y": 0}, {"y": [1, 2]}, {"y": slice(0, None, 2)}, {"y": xr.Variable("y", [1])}), +) +def test_isel(indexers): + wrapped_index = PandasIndex(pd.Index([1, 2, 3, 4]), dim="y") + index = PintIndex(index=wrapped_index, units={"y": ureg.Unit("s")}) + + actual = index.isel(indexers) + + wrapped_ = wrapped_index.isel(indexers) + if wrapped_ is not None: + expected = PintIndex( + index=wrapped_index.isel(indexers), units={"y": ureg.Unit("s")} + ) + else: + expected = None + + assert (actual is None and expected is None) or actual.equals(expected) + + +@pytest.mark.parametrize( + ["other", "expected"], + ( + ( + PintIndex( + index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), + units={"x": ureg.Unit("cm")}, + ), + True, + ), + (PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), False), + ( + PintIndex( + index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), + units={"x": ureg.Unit("m")}, + ), + False, + ), + ( + PintIndex( + index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="y"), + units={"y": ureg.Unit("cm")}, + ), + False, + ), + ( + PintIndex( + index=PandasIndex(pd.Index([1, 3, 3, 4]), dim="x"), + units={"x": ureg.Unit("cm")}, + ), + False, + ), + ), +) +def test_equals(other, expected): + index = PintIndex( + index=PandasIndex(pd.Index([1, 2, 3, 4]), dim="x"), units={"x": ureg.Unit("cm")} + ) + + actual = index.equals(other) + + assert actual == expected + + +@pytest.mark.parametrize( + ["shifts", "expected_index"], + ( + ({"x": 0}, PandasIndex(pd.Index([-2, -1, 0, 1, 2]), dim="x")), + ({"x": 1}, PandasIndex(pd.Index([2, -2, -1, 0, 1]), dim="x")), + ({"x": 2}, PandasIndex(pd.Index([1, 2, -2, -1, 0]), dim="x")), + ({"x": -1}, PandasIndex(pd.Index([-1, 0, 1, 2, -2]), dim="x")), + ({"x": -2}, PandasIndex(pd.Index([0, 1, 2, -2, -1]), dim="x")), + ), +) +def test_roll(shifts, expected_index): + index = PintIndex( + index=PandasIndex(pd.Index([-2, -1, 0, 1, 2]), dim="x"), + units={"x": ureg.Unit("m")}, + ) + + actual = index.roll(shifts) + expected = index._replace(expected_index) + + assert actual.equals(expected) + + +@pytest.mark.parametrize("dims_dict", ({"y": "x"}, {"y": "z"})) +@pytest.mark.parametrize("name_dict", ({"y2": "y3"}, {"y2": "y1"})) +def test_rename(name_dict, dims_dict): + wrapped_index = PandasIndex(pd.Index([1, 2], name="y2"), dim="y") + index = PintIndex(index=wrapped_index, units={"y": ureg.Unit("m")}) + + actual = index.rename(name_dict, dims_dict) + expected = PintIndex( + index=wrapped_index.rename(name_dict, dims_dict), units=index.units + ) + + assert actual.equals(expected) + + +@pytest.mark.parametrize("indexer", ([0], slice(0, 2))) +def test_getitem(indexer): + wrapped_index = PandasIndex(pd.Index([1, 2], name="y2"), dim="y") + index = PintIndex(index=wrapped_index, units={"y": ureg.Unit("m")}) + + actual = index[indexer] + expected = PintIndex(index=wrapped_index[indexer], units=index.units) + + assert actual.equals(expected) + + +@pytest.mark.parametrize("wrapped_index", (PandasIndex(pd.Index([1, 2]), dim="x"),)) +def test_repr_inline(wrapped_index): + index = PintIndex(index=wrapped_index, units=ureg.Unit("m")) + + # TODO: parametrize + actual = index._repr_inline_(90) + + assert "PintIndex" in actual + assert wrapped_index.__class__.__name__ in actual diff --git a/pint_xarray/tests/utils.py b/pint_xarray/tests/utils.py index 4da6b0dc..a9d66006 100644 --- a/pint_xarray/tests/utils.py +++ b/pint_xarray/tests/utils.py @@ -1,5 +1,6 @@ import re from contextlib import contextmanager +from textwrap import indent import numpy as np import pytest @@ -97,6 +98,33 @@ def assert_indexer_equal(a, b): assert a_ == b_, f"different values: {a_!r} ←→ {b_!r}" +def assert_indexers_equal(first, second): + __tracebackhide__ = True + # same keys + assert first.keys() == second.keys(), "different keys" + + errors = {} + for name in first: + first_value = first[name] + second_value = second[name] + + try: + assert_indexer_equal(first_value, second_value) + except AssertionError as e: + errors[name] = e + + if errors: + message = "\n".join( + ["indexers are not equal:"] + + [ + f" - {name}:\n{indent(str(error), ' ' * 4)}" + for name, error in errors.items() + ] + ) + + raise AssertionError(message) + + def assert_indexer_units_equal(a, b): __tracebackhide__ = True diff --git a/pyproject.toml b/pyproject.toml index abfd5c7b..727095c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,3 +52,11 @@ skip_gitignore = "true" force_to_top = "true" default_section = "THIRDPARTY" known_first_party = "pint_xarray" + +[tool.coverage.run] +source = ["pint_xarray"] +branch = true + +[tool.coverage.report] +show_missing = true +exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"]