Skip to content

Commit 7a3a4c6

Browse files
kmuehlbauerpre-commit-ci[bot]spencerkclarkdcherian
authored
Improve handling of dtype and NaT when encoding/decoding masked and packaged datetimes and timedeltas (#10050)
* mask/scale datetimes/timedeltas only if they will be decoded, better handle partial coding * comments * typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typing and imports * refactor according to review concerns and suggestions * retain retain dtype for packed data in datetime/timedelta encoding * simplify code, add whats-new.rst entry * Update xarray/coding/variables.py Co-authored-by: Spencer Clark <[email protected]> * refactor common code into common.py to prevent circular imports when passing decode_times and decode_timedelta to CFMaskCoder and CFScaleOffsetCoder * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Spencer Clark <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 70d2f1d commit 7a3a4c6

File tree

7 files changed

+309
-163
lines changed

7 files changed

+309
-163
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ Bug fixes
7171
By `Benoit Bovy <https://github.com/benbovy>`_.
7272
- Fix dask tokenization when opening each node in :py:func:`xarray.open_datatree`
7373
(:issue:`10098`, :pull:`10100`). By `Sam Levang <https://github.com/slevang>`_.
74+
- Improve handling of dtype and NaT when encoding/decoding masked and packaged
75+
datetimes and timedeltas (:issue:`8957`, :pull:`10050`).
76+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
7477

7578
Documentation
7679
~~~~~~~~~~~~~

xarray/coding/common.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable, Hashable, MutableMapping
4+
from typing import TYPE_CHECKING, Any, Union
5+
6+
import numpy as np
7+
8+
from xarray.core import indexing
9+
from xarray.core.variable import Variable
10+
from xarray.namedarray.parallelcompat import get_chunked_array_type
11+
from xarray.namedarray.pycompat import is_chunked_array
12+
13+
if TYPE_CHECKING:
14+
T_VarTuple = tuple[tuple[Hashable, ...], Any, dict, dict]
15+
T_Name = Union[Hashable, None]
16+
17+
18+
class SerializationWarning(RuntimeWarning):
19+
"""Warnings about encoding/decoding issues in serialization."""
20+
21+
22+
class VariableCoder:
23+
"""Base class for encoding and decoding transformations on variables.
24+
25+
We use coders for transforming variables between xarray's data model and
26+
a format suitable for serialization. For example, coders apply CF
27+
conventions for how data should be represented in netCDF files.
28+
29+
Subclasses should implement encode() and decode(), which should satisfy
30+
the identity ``coder.decode(coder.encode(variable)) == variable``. If any
31+
options are necessary, they should be implemented as arguments to the
32+
__init__ method.
33+
34+
The optional name argument to encode() and decode() exists solely for the
35+
sake of better error messages, and should correspond to the name of
36+
variables in the underlying store.
37+
"""
38+
39+
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
40+
"""Convert an encoded variable to a decoded variable"""
41+
raise NotImplementedError()
42+
43+
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
44+
"""Convert a decoded variable to an encoded variable"""
45+
raise NotImplementedError()
46+
47+
48+
class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin):
49+
"""Lazily computed array holding values of elemwise-function.
50+
51+
Do not construct this object directly: call lazy_elemwise_func instead.
52+
53+
Values are computed upon indexing or coercion to a NumPy array.
54+
"""
55+
56+
def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike):
57+
assert not is_chunked_array(array)
58+
self.array = indexing.as_indexable(array)
59+
self.func = func
60+
self._dtype = dtype
61+
62+
@property
63+
def dtype(self) -> np.dtype:
64+
return np.dtype(self._dtype)
65+
66+
def _oindex_get(self, key):
67+
return type(self)(self.array.oindex[key], self.func, self.dtype)
68+
69+
def _vindex_get(self, key):
70+
return type(self)(self.array.vindex[key], self.func, self.dtype)
71+
72+
def __getitem__(self, key):
73+
return type(self)(self.array[key], self.func, self.dtype)
74+
75+
def get_duck_array(self):
76+
return self.func(self.array.get_duck_array())
77+
78+
def __repr__(self) -> str:
79+
return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})"
80+
81+
82+
def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike):
83+
"""Lazily apply an element-wise function to an array.
84+
Parameters
85+
----------
86+
array : any valid value of Variable._data
87+
func : callable
88+
Function to apply to indexed slices of an array. For use with dask,
89+
this should be a pickle-able object.
90+
dtype : coercible to np.dtype
91+
Dtype for the result of this function.
92+
93+
Returns
94+
-------
95+
Either a dask.array.Array or _ElementwiseFunctionArray.
96+
"""
97+
if is_chunked_array(array):
98+
chunkmanager = get_chunked_array_type(array)
99+
100+
return chunkmanager.map_blocks(func, array, dtype=dtype) # type: ignore[arg-type]
101+
else:
102+
return _ElementwiseFunctionArray(array, func, dtype)
103+
104+
105+
def safe_setitem(dest, key: Hashable, value, name: T_Name = None):
106+
if key in dest:
107+
var_str = f" on variable {name!r}" if name else ""
108+
raise ValueError(
109+
f"failed to prevent overwriting existing key {key} in attrs{var_str}. "
110+
"This is probably an encoding field used by xarray to describe "
111+
"how a variable is serialized. To proceed, remove this key from "
112+
"the variable's attributes manually."
113+
)
114+
dest[key] = value
115+
116+
117+
def pop_to(
118+
source: MutableMapping, dest: MutableMapping, key: Hashable, name: T_Name = None
119+
) -> Any:
120+
"""
121+
A convenience function which pops a key k from source to dest.
122+
None values are not passed on. If k already exists in dest an
123+
error is raised.
124+
"""
125+
value = source.pop(key, None)
126+
if value is not None:
127+
safe_setitem(dest, key, value, name=name)
128+
return value
129+
130+
131+
def unpack_for_encoding(var: Variable) -> T_VarTuple:
132+
return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
133+
134+
135+
def unpack_for_decoding(var: Variable) -> T_VarTuple:
136+
return var.dims, var._data, var.attrs.copy(), var.encoding.copy()

xarray/coding/times.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pandas as pd
1212
from pandas.errors import OutOfBoundsDatetime, OutOfBoundsTimedelta
1313

14-
from xarray.coding.variables import (
14+
from xarray.coding.common import (
1515
SerializationWarning,
1616
VariableCoder,
1717
lazy_elemwise_func,
@@ -1328,9 +1328,20 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
13281328

13291329
units = encoding.pop("units", None)
13301330
calendar = encoding.pop("calendar", None)
1331-
dtype = encoding.get("dtype", None)
1331+
dtype = encoding.pop("dtype", None)
1332+
1333+
# in the case of packed data we need to encode into
1334+
# float first, the correct dtype will be established
1335+
# via CFScaleOffsetCoder/CFMaskCoder
1336+
set_dtype_encoding = None
1337+
if "add_offset" in encoding or "scale_factor" in encoding:
1338+
set_dtype_encoding = dtype
1339+
dtype = data.dtype if data.dtype.kind == "f" else "float64"
13321340
(data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)
13331341

1342+
# retain dtype for packed data
1343+
if set_dtype_encoding is not None:
1344+
safe_setitem(encoding, "dtype", set_dtype_encoding, name=name)
13341345
safe_setitem(attrs, "units", units, name=name)
13351346
safe_setitem(attrs, "calendar", calendar, name=name)
13361347

@@ -1382,9 +1393,22 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
13821393
if np.issubdtype(variable.data.dtype, np.timedelta64):
13831394
dims, data, attrs, encoding = unpack_for_encoding(variable)
13841395

1385-
data, units = encode_cf_timedelta(
1386-
data, encoding.pop("units", None), encoding.get("dtype", None)
1387-
)
1396+
dtype = encoding.pop("dtype", None)
1397+
1398+
# in the case of packed data we need to encode into
1399+
# float first, the correct dtype will be established
1400+
# via CFScaleOffsetCoder/CFMaskCoder
1401+
set_dtype_encoding = None
1402+
if "add_offset" in encoding or "scale_factor" in encoding:
1403+
set_dtype_encoding = dtype
1404+
dtype = data.dtype if data.dtype.kind == "f" else "float64"
1405+
1406+
data, units = encode_cf_timedelta(data, encoding.pop("units", None), dtype)
1407+
1408+
# retain dtype for packed data
1409+
if set_dtype_encoding is not None:
1410+
safe_setitem(encoding, "dtype", set_dtype_encoding, name=name)
1411+
13881412
safe_setitem(attrs, "units", units, name=name)
13891413

13901414
return Variable(dims, data, attrs, encoding, fastpath=True)

0 commit comments

Comments
 (0)