Skip to content

Improve handling of dtype and NaT when encoding/decoding masked and packed datetimes and timedeltas #10050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ Bug fixes
Haacker <https://github.com/j-haacker>`_.
- Fix ``isel`` for multi-coordinate Xarray indexes (:issue:`10063`, :pull:`10066`).
By `Benoit Bovy <https://github.com/benbovy>`_.
- Improve handling of dtype and NaT when encoding/decoding masked and packed datetimes and timedeltas (:issue:`8957`, :pull:`10050`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.


Documentation
Expand Down
32 changes: 28 additions & 4 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,9 +1315,20 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:

units = encoding.pop("units", None)
calendar = encoding.pop("calendar", None)
dtype = encoding.get("dtype", None)
dtype = encoding.pop("dtype", None)

# in the case of packed data we need to encode into
# float first, the correct dtype will be established
# via CFScaleOffsetCoder/CFMaskCoder
set_dtype_encoding = None
if "add_offset" in encoding or "scale_factor" in encoding:
set_dtype_encoding = dtype
dtype = data.dtype if data.dtype.kind == "f" else "float64"
(data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)

# retain dtype for packed data
if set_dtype_encoding is not None:
safe_setitem(encoding, "dtype", set_dtype_encoding, name=name)
safe_setitem(attrs, "units", units, name=name)
safe_setitem(attrs, "calendar", calendar, name=name)

Expand Down Expand Up @@ -1369,9 +1380,22 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
dims, data, attrs, encoding = unpack_for_encoding(variable)

data, units = encode_cf_timedelta(
data, encoding.pop("units", None), encoding.get("dtype", None)
)
dtype = encoding.pop("dtype", None)

# in the case of packed data we need to encode into
# float first, the correct dtype will be established
# via CFScaleOffsetCoder/CFMaskCoder
set_dtype_encoding = None
if "add_offset" in encoding or "scale_factor" in encoding:
set_dtype_encoding = dtype
dtype = data.dtype if data.dtype.kind == "f" else "float64"

data, units = encode_cf_timedelta(data, encoding.pop("units", None), dtype)

# retain dtype for packed data
if set_dtype_encoding is not None:
safe_setitem(encoding, "dtype", set_dtype_encoding, name=name)

safe_setitem(attrs, "units", units, name=name)

return Variable(dims, data, attrs, encoding, fastpath=True)
Expand Down
85 changes: 67 additions & 18 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def _apply_mask(

def _is_time_like(units):
# test for time-like
# return "datetime" for datetime-like
# return "timedelta" for timedelta-like
if units is None:
return False
time_strings = [
Expand All @@ -255,9 +257,9 @@ def _is_time_like(units):
_unpack_netcdf_time_units(units)
except ValueError:
return False
return True
return "datetime"
else:
return any(tstr == units for tstr in time_strings)
return "timedelta" if any(tstr == units for tstr in time_strings) else False


def _check_fill_values(attrs, name, dtype):
Expand Down Expand Up @@ -367,6 +369,14 @@ def _encode_unsigned_fill_value(
class CFMaskCoder(VariableCoder):
"""Mask or unmask fill values according to CF conventions."""

def __init__(
    self,
    decode_times: bool = False,
    decode_timedelta: bool = False,
) -> None:
    """Store flags telling the coder whether datetime-like and
    timedelta-like values will be decoded downstream, so NaT/fill
    handling can be special-cased accordingly.
    """
    self.decode_timedelta = decode_timedelta
    self.decode_times = decode_times

def encode(self, variable: Variable, name: T_Name = None):
dims, data, attrs, encoding = unpack_for_encoding(variable)

Expand All @@ -393,33 +403,50 @@ def encode(self, variable: Variable, name: T_Name = None):

if fv_exists:
# Ensure _FillValue is cast to same dtype as data's
# but not for packed data
encoding["_FillValue"] = (
_encode_unsigned_fill_value(name, fv, dtype)
if has_unsigned
else dtype.type(fv)
if "add_offset" not in encoding and "scale_factor" not in encoding
else fv
)
fill_value = pop_to(encoding, attrs, "_FillValue", name=name)

if mv_exists:
# try to use _FillValue, if it exists to align both values
# or use missing_value and ensure it's cast to same dtype as data's
# but not for packed data
encoding["missing_value"] = attrs.get(
"_FillValue",
(
_encode_unsigned_fill_value(name, mv, dtype)
if has_unsigned
else dtype.type(mv)
if "add_offset" not in encoding and "scale_factor" not in encoding
else mv
),
)
fill_value = pop_to(encoding, attrs, "missing_value", name=name)

# apply fillna
if fill_value is not None and not pd.isnull(fill_value):
# special case DateTime to properly handle NaT
if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu":
data = duck_array_ops.where(
data != np.iinfo(np.int64).min, data, fill_value
)
if _is_time_like(attrs.get("units")):
if data.dtype.kind in "iu":
data = duck_array_ops.where(
data != np.iinfo(np.int64).min, data, fill_value
)
else:
# if we have float data (data was packed prior masking)
# we just fillna
data = duck_array_ops.fillna(data, fill_value)
# but if the fill_value is of integer type
# we need to round and cast
if np.array(fill_value).dtype.kind in "iu":
data = duck_array_ops.astype(
duck_array_ops.around(data), type(fill_value)
)
else:
data = duck_array_ops.fillna(data, fill_value)

Expand Down Expand Up @@ -457,19 +484,28 @@ def decode(self, variable: Variable, name: T_Name = None):
)

if encoded_fill_values:
# special case DateTime to properly handle NaT
dtype: np.typing.DTypeLike
decoded_fill_value: Any
if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu":
dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min
# in case of packed data we have to decode into float
# in any case
if "scale_factor" in attrs or "add_offset" in attrs:
dtype, decoded_fill_value = (
_choose_float_dtype(data.dtype, attrs),
np.nan,
)
else:
if "scale_factor" not in attrs and "add_offset" not in attrs:
dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)
# in case of no-packing special case DateTime/Timedelta to properly
# handle NaT, we need to check if time-like will be decoded
# or not in further processing
is_time_like = _is_time_like(attrs.get("units"))
if (
(is_time_like == "datetime" and self.decode_times)
or (is_time_like == "timedelta" and self.decode_timedelta)
) and data.dtype.kind in "iu":
dtype = np.int64
decoded_fill_value = np.iinfo(np.int64).min
else:
dtype, decoded_fill_value = (
_choose_float_dtype(data.dtype, attrs),
np.nan,
)
dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)

transform = partial(
_apply_mask,
Expand Down Expand Up @@ -549,6 +585,14 @@ class CFScaleOffsetCoder(VariableCoder):
decode_values = encoded_values * scale_factor + add_offset
"""

def __init__(
    self,
    decode_times: bool = False,
    decode_timedelta: bool = False,
) -> None:
    """Store flags telling the coder whether datetime-like and
    timedelta-like values will be decoded downstream, so dtype
    selection can be special-cased accordingly.
    """
    self.decode_timedelta = decode_timedelta
    self.decode_times = decode_times

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
dims, data, attrs, encoding = unpack_for_encoding(variable)

Expand Down Expand Up @@ -578,11 +622,16 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
scale_factor = np.asarray(scale_factor).item()
if np.ndim(add_offset) > 0:
add_offset = np.asarray(add_offset).item()
# if we have a _FillValue/masked_value we already have the wanted
# if we have a _FillValue/missing_value in encoding we already have the wanted
# floating point dtype here (via CFMaskCoder), so no check is necessary
# only check in other cases
# only check in other cases and for time-like
dtype = data.dtype
if "_FillValue" not in encoding and "missing_value" not in encoding:
is_time_like = _is_time_like(attrs.get("units"))
if (
("_FillValue" not in encoding and "missing_value" not in encoding)
or (is_time_like == "datetime" and self.decode_times)
or (is_time_like == "timedelta" and self.decode_timedelta)
):
dtype = _choose_float_dtype(dtype, encoding)

transform = partial(
Expand Down
10 changes: 8 additions & 2 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,15 @@ def decode_cf_variable(
var = variables.Numpy2StringDTypeCoder().decode(var)

if mask_and_scale:
dec_times = True if decode_times else False
dec_timedelta = True if decode_timedelta else False
for coder in [
variables.CFMaskCoder(),
variables.CFScaleOffsetCoder(),
variables.CFMaskCoder(
decode_times=dec_times, decode_timedelta=dec_timedelta
),
variables.CFScaleOffsetCoder(
decode_times=dec_times, decode_timedelta=dec_timedelta
),
]:
var = coder.decode(var, name=name)

Expand Down
57 changes: 49 additions & 8 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,26 @@ def test_decoded_cf_datetime_array_2d(time_unit: PDDatetimeUnitOptions) -> None:
assert_array_equal(np.asarray(result), expected)


@pytest.mark.parametrize("decode_times", [True, False])
@pytest.mark.parametrize("mask_and_scale", [True, False])
def test_decode_datetime_mask_and_scale(
    decode_times: bool, mask_and_scale: bool
) -> None:
    # Round-trip a masked + packed datetime variable through decode/encode
    # and verify the original attrs and on-disk dtype are fully restored.
    packed_attrs = {
        "units": "nanoseconds since 1970-01-01",
        "calendar": "proleptic_gregorian",
        "_FillValue": np.int16(-1),
        "add_offset": 100000.0,
    }
    original = Variable(["time"], np.array([0, -1, 1], "int16"), attrs=packed_attrs)
    decoded_var = conventions.decode_cf_variable(
        "foo", original, mask_and_scale=mask_and_scale, decode_times=decode_times
    )
    reencoded = conventions.encode_cf_variable(decoded_var, name="foo")
    assert_identical(original, reencoded)
    assert original.dtype == reencoded.dtype


FREQUENCIES_TO_ENCODING_UNITS = {
"ns": "nanoseconds",
"us": "microseconds",
Expand Down Expand Up @@ -637,7 +657,9 @@ def test_cf_timedelta_2d() -> None:


@pytest.mark.parametrize("encoding_unit", FREQUENCIES_TO_ENCODING_UNITS.values())
def test_decode_cf_timedelta_time_unit(time_unit, encoding_unit) -> None:
def test_decode_cf_timedelta_time_unit(
time_unit: PDDatetimeUnitOptions, encoding_unit
) -> None:
encoded = 1
encoding_unit_as_numpy = _netcdf_to_numpy_timeunit(encoding_unit)
if np.timedelta64(1, time_unit) > np.timedelta64(1, encoding_unit_as_numpy):
Expand All @@ -651,7 +673,9 @@ def test_decode_cf_timedelta_time_unit(time_unit, encoding_unit) -> None:
assert result.dtype == expected.dtype


def test_decode_cf_timedelta_time_unit_out_of_bounds(time_unit) -> None:
def test_decode_cf_timedelta_time_unit_out_of_bounds(
time_unit: PDDatetimeUnitOptions,
) -> None:
# Define a scale factor that will guarantee overflow with the given
# time_unit.
scale_factor = np.timedelta64(1, time_unit) // np.timedelta64(1, "ns")
Expand All @@ -660,7 +684,7 @@ def test_decode_cf_timedelta_time_unit_out_of_bounds(time_unit) -> None:
decode_cf_timedelta(encoded, "days", time_unit)


def test_cf_timedelta_roundtrip_large_value(time_unit) -> None:
def test_cf_timedelta_roundtrip_large_value(time_unit: PDDatetimeUnitOptions) -> None:
value = np.timedelta64(np.iinfo(np.int64).max, time_unit)
encoded, units = encode_cf_timedelta(value)
decoded = decode_cf_timedelta(encoded, units, time_unit=time_unit)
Expand Down Expand Up @@ -982,7 +1006,7 @@ def test_use_cftime_default_standard_calendar_out_of_range(
@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS)
@pytest.mark.parametrize("units_year", [1500, 2000, 2500])
def test_use_cftime_default_non_standard_calendar(
calendar, units_year, time_unit
calendar, units_year, time_unit: PDDatetimeUnitOptions
) -> None:
from cftime import num2date

Expand Down Expand Up @@ -1433,9 +1457,9 @@ def test_roundtrip_datetime64_nanosecond_precision_warning(
) -> None:
# test warning if times can't be serialized faithfully
times = [
np.datetime64("1970-01-01T00:01:00", "ns"),
np.datetime64("NaT"),
np.datetime64("1970-01-02T00:01:00", "ns"),
np.datetime64("1970-01-01T00:01:00", time_unit),
np.datetime64("NaT", time_unit),
np.datetime64("1970-01-02T00:01:00", time_unit),
]
units = "days since 1970-01-10T01:01:00"
needed_units = "hours"
Expand Down Expand Up @@ -1624,7 +1648,9 @@ def test_roundtrip_float_times(fill_value, times, units, encoded_values) -> None
_ENCODE_DATETIME64_VIA_DASK_TESTS.values(),
ids=_ENCODE_DATETIME64_VIA_DASK_TESTS.keys(),
)
def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype, time_unit) -> None:
def test_encode_cf_datetime_datetime64_via_dask(
freq, units, dtype, time_unit: PDDatetimeUnitOptions
) -> None:
import dask.array

times_pd = pd.date_range(start="1700", freq=freq, periods=3, unit=time_unit)
Expand Down Expand Up @@ -1907,3 +1933,18 @@ def test_lazy_decode_timedelta_error() -> None:
)
with pytest.raises(OutOfBoundsTimedelta, match="overflow"):
decoded.load()


@pytest.mark.parametrize("decode_timedelta", [True, False])
@pytest.mark.parametrize("mask_and_scale", [True, False])
def test_decode_timedelta_mask_and_scale(
    decode_timedelta: bool, mask_and_scale: bool
) -> None:
    # Round-trip a masked + packed timedelta variable through decode/encode
    # and verify the original attrs and on-disk dtype are fully restored.
    packed_attrs = {
        "units": "nanoseconds",
        "_FillValue": np.int16(-1),
        "add_offset": 100000.0,
    }
    original = Variable(["time"], np.array([0, -1, 1], "int16"), attrs=packed_attrs)
    decoded_var = conventions.decode_cf_variable(
        "foo", original, mask_and_scale=mask_and_scale, decode_timedelta=decode_timedelta
    )
    reencoded = conventions.encode_cf_variable(decoded_var, name="foo")
    assert_identical(original, reencoded)
    assert original.dtype == reencoded.dtype
5 changes: 1 addition & 4 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,16 +511,13 @@ def test_decode_dask_times(self) -> None:

@pytest.mark.parametrize("time_unit", ["s", "ms", "us", "ns"])
def test_decode_cf_time_kwargs(self, time_unit) -> None:
# todo: if we set timedelta attrs "units": "days"
# this errors on the last decode_cf wrt to the lazy_elemwise_func
# trying to convert twice
ds = Dataset.from_dict(
{
"coords": {
"timedelta": {
"data": np.array([1, 2, 3], dtype="int64"),
"dims": "timedelta",
"attrs": {"units": "seconds"},
"attrs": {"units": "days"},
},
"time": {
"data": np.array([1, 2, 3], dtype="int64"),
Expand Down
Loading