From 94a89e039297b30ff31de277c5eb20a63d33c255 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:34:32 +0000 Subject: [PATCH] Feature/masks (#104) * add masks Co-authored-by: Florian Pinault --- CHANGELOG.md | 2 + src/anemoi/datasets/data/__init__.py | 23 +++++ src/anemoi/datasets/data/dataset.py | 139 ++++++++++++++++++++++----- src/anemoi/datasets/data/forwards.py | 19 ++++ src/anemoi/datasets/data/grids.py | 21 +++- src/anemoi/datasets/data/masked.py | 8 ++ src/anemoi/datasets/data/misc.py | 1 + src/anemoi/datasets/data/stores.py | 6 ++ src/anemoi/datasets/data/subset.py | 2 + src/anemoi/datasets/grids.py | 51 ++++++++++ 10 files changed, 247 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0be682e8..b45f4fbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ Keep it human-readable, your future self will thank you! - Various bug fixes - Control compatibility check in xy/zip - Add `merge` feature +- Add support for storing `supporting_arrays` in checkpoint files +- Allow naming of datasets components - Contributors file (#105) ### Changed diff --git a/src/anemoi/datasets/data/__init__.py b/src/anemoi/datasets/data/__init__.py index 244842f7..3c399da6 100644 --- a/src/anemoi/datasets/data/__init__.py +++ b/src/anemoi/datasets/data/__init__.py @@ -25,7 +25,30 @@ class MissingDateError(Exception): pass +def _convert(x): + + if isinstance(x, list): + return [_convert(a) for a in x] + + if isinstance(x, tuple): + return tuple(_convert(a) for a in x) + + if isinstance(x, dict): + return {k: _convert(v) for k, v in x.items()} + + if x.__class__.__name__ in ("DictConfig", "ListConfig"): + from omegaconf import OmegaConf + + return OmegaConf.to_container(x, resolve=True) + + return x + + def open_dataset(*args, **kwargs): + + # That will get rid of OmegaConf objects + args, kwargs = _convert(args), _convert(kwargs) + ds = _open_dataset(*args, **kwargs) ds = ds.mutate() ds.arguments = {"args": args, "kwargs": kwargs} diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 93139ce6..f9dda392 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -23,8 +23,34 @@ LOG = logging.getLogger(__name__) +def _tidy(v): + if isinstance(v, (list, tuple, set)): + return [_tidy(i) for i in v] + if isinstance(v, dict): + return {k: _tidy(v) for k, v in v.items()} + if isinstance(v, str) and v.startswith("/"): + return os.path.basename(v) + if isinstance(v, datetime.datetime): + return v.isoformat() + if isinstance(v, datetime.date): + return v.isoformat() + if isinstance(v, datetime.timedelta): + return frequency_to_string(v) + + if isinstance(v, Dataset): + # That can happen in the `arguments` + # if a dataset is passed as an argument + return repr(v) + + if isinstance(v, slice): + return (v.start, v.stop, v.step) + + return v + + class Dataset: arguments = {} + _name = None def mutate(self) -> "Dataset": """Give an opportunity to a subclass to return a new Dataset @@ -41,6 +67,21 @@ def _len(self): return len(self) def _subset(self, **kwargs): + + if not kwargs: + return self.mutate() + + name = kwargs.pop("name", None) + result = self.__subset(**kwargs) + result._name = name + + return result + + @property + def name(self): + return self._name + + def __subset(self, **kwargs): if not kwargs: return self.mutate() @@ -254,41 +295,32 @@ def typed_variables(self): return result + def _input_sources(self): + sources = [] + self.collect_input_sources(sources) + return sources + def metadata(self): import anemoi - def tidy(v): - if isinstance(v, (list, tuple, set)): - return [tidy(i) for i in v] - if isinstance(v, dict): - return {k: tidy(v) for k, v in v.items()} - if isinstance(v, str) and v.startswith("/"): - return os.path.basename(v) - if isinstance(v, datetime.datetime): - return v.isoformat() - if isinstance(v, datetime.date): - return v.isoformat() - if isinstance(v, datetime.timedelta): - return frequency_to_string(v) - - if isinstance(v, Dataset): - # That can happen in the `arguments` - # if a dataset is passed as an argument - return repr(v) - - if isinstance(v, slice): - return (v.start, v.stop, v.step) - - return v + _, source_to_arrays = self._supporting_arrays_and_sources() + + sources = [] + for i, source in enumerate(self._input_sources()): + source_metadata = source.dataset_metadata().copy() + source_metadata["supporting_arrays"] = source_to_arrays[id(source)] + sources.append(source_metadata) md = dict( version=anemoi.datasets.__version__, arguments=self.arguments, **self.dataset_metadata(), + sources=sources, + supporting_arrays=source_to_arrays[id(self)], ) try: - return json.loads(json.dumps(tidy(md))) + return json.loads(json.dumps(_tidy(md))) except Exception: LOG.exception("Failed to serialize metadata") pprint.pprint(md) @@ -313,8 +345,67 @@ def dataset_metadata(self): dtype=str(self.dtype), start_date=self.start_date.astype(str), end_date=self.end_date.astype(str), + name=self.name, ) + def _supporting_arrays(self, *path): + + import numpy as np + + def _path(path, name): + return "/".join(str(_) for _ in [*path, name]) + + result = { + _path(path, "latitudes"): self.latitudes, + _path(path, "longitudes"): self.longitudes, + } + collected = [] + + self.collect_supporting_arrays(collected, *path) + + for path, name, array in collected: + assert isinstance(path, tuple) and isinstance(name, str) + assert isinstance(array, np.ndarray) + + name = _path(path, name) + + if name in result: + raise ValueError(f"Duplicate key {name}") + + result[name] = array + + return result + + def supporting_arrays(self): + """Arrays to be saved in the checkpoints""" + arrays, _ = self._supporting_arrays_and_sources() + return arrays + + def _supporting_arrays_and_sources(self): + + source_to_arrays = {} + + # Top levels arrays + result = self._supporting_arrays() + source_to_arrays[id(self)] = sorted(result.keys()) + + # Arrays from the input sources + for i, source in enumerate(self._input_sources()): + name = source.name if source.name is not None else i + src_arrays = source._supporting_arrays(name) + source_to_arrays[id(source)] = sorted(src_arrays.keys()) + + for k in src_arrays: + assert k not in result + + result.update(src_arrays) + + return result, source_to_arrays + + def collect_supporting_arrays(self, collected, *path): + # Override this method to add more arrays + pass + def metadata_specific(self, **kwargs): action = self.__class__.__name__.lower() # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action) diff --git a/src/anemoi/datasets/data/forwards.py b/src/anemoi/datasets/data/forwards.py index 9f0cd003..e2c77b42 100644 --- a/src/anemoi/datasets/data/forwards.py +++ b/src/anemoi/datasets/data/forwards.py @@ -9,6 +9,7 @@ import logging +import warnings from functools import cached_property import numpy as np @@ -34,6 +35,12 @@ def __len__(self): def __getitem__(self, n): return self.forward[n] + @property + def name(self): + if self._name is not None: + return self._name + return self.forward.name + @property def dates(self): return self.forward.dates @@ -102,6 +109,12 @@ def metadata_specific(self, **kwargs): **kwargs, ) + def collect_supporting_arrays(self, collected, *path): + self.forward.collect_supporting_arrays(collected, *path) + + def collect_input_sources(self, collected): + self.forward.collect_input_sources(collected) + def source(self, index): return self.forward.source(index) @@ -197,6 +210,12 @@ def metadata_specific(self, **kwargs): **kwargs, ) + def collect_supporting_arrays(self, collected, *path): + warnings.warn(f"The behaviour of {self.__class__.__name__}.collect_supporting_arrays() is not well defined") + for i, d in enumerate(self.datasets): + name = d.name if d.name is not None else i + d.collect_supporting_arrays(collected, *path, name) + @property def missing(self): raise NotImplementedError("missing() not implemented for Combined") diff --git a/src/anemoi/datasets/data/grids.py b/src/anemoi/datasets/data/grids.py index 23f32b82..e8859cdb 100644 --- a/src/anemoi/datasets/data/grids.py +++ b/src/anemoi/datasets/data/grids.py @@ -108,6 +108,17 @@ def check_same_resolution(self, d1, d2): # We don't check the resolution, because we want to be able to combine pass + def metadata_specific(self): + return super().metadata_specific( + multi_grids=True, + ) + + def collect_input_sources(self, collected): + # We assume that,because they have different grids, they have different input sources + for d in self.datasets: + collected.append(d) + d.collect_input_sources(collected) + class Grids(GridsBase): # TODO: select the statistics of the most global grid? @@ -157,6 +168,9 @@ def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, self.globe.shape[3], ) + def collect_supporting_arrays(self, collected, *path): + collected.append((path, "cutout_mask", self.mask)) + @cached_property def shape(self): shape = self.lam.shape @@ -212,6 +226,11 @@ def grids(self): def tree(self): return Node(self, [d.tree() for d in self.datasets]) + # def metadata_specific(self): + # return super().metadata_specific( + # mask=serialise_mask(self.mask), + # ) + def grids_factory(args, kwargs): if "ensemble" in kwargs: @@ -241,7 +260,7 @@ def cutout_factory(args, kwargs): neighbours = kwargs.pop("neighbours", 5) assert len(args) == 0 - assert isinstance(cutout, (list, tuple)) + assert isinstance(cutout, (list, tuple)), "cutout must be a list or tuple" datasets = [_open(e) for e in cutout] datasets, kwargs = _auto_adjust(datasets, kwargs) diff --git a/src/anemoi/datasets/data/masked.py b/src/anemoi/datasets/data/masked.py index 2b786cb0..a8cb784b 100644 --- a/src/anemoi/datasets/data/masked.py +++ b/src/anemoi/datasets/data/masked.py @@ -33,6 +33,8 @@ def __init__(self, forward, mask): self.mask = mask self.axis = 3 + self.mask_name = f"{self.__class__.__name__.lower()}_mask" + @cached_property def shape(self): return self.forward.shape[:-1] + (np.count_nonzero(self.mask),) @@ -67,8 +69,13 @@ def _get_tuple(self, index): result = apply_index_to_slices_changes(result, changes) return result + def collect_supporting_arrays(self, collected, *path): + super().collect_supporting_arrays(collected, *path) + collected.append((path, self.mask_name, self.mask)) + class Thinning(Masked): + def __init__(self, forward, thinning, method): self.thinning = thinning self.method = method @@ -110,6 +117,7 @@ def subclass_metadata_specific(self): class Cropping(Masked): + def __init__(self, forward, area): from ..data import open_dataset diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index 89ddb030..aad751f0 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -270,6 +270,7 @@ def _auto_adjust(datasets, kwargs): def _open_dataset(*args, **kwargs): + sets = [] for a in args: sets.append(_open(a)) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 78647f2a..0484e822 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -344,6 +344,12 @@ def get_dataset_names(self, names): name, _ = os.path.splitext(os.path.basename(self.path)) names.add(name) + def collect_supporting_arrays(self, collected, *path): + pass + + def collect_input_sources(self, collected): + pass + class ZarrWithMissingDates(Zarr): """A zarr dataset with missing dates.""" diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index e31f25d8..8ac60502 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -135,6 +135,8 @@ def dates(self): @cached_property def frequency(self): dates = self.dates + if len(dates) < 2: + raise ValueError(f"Cannot determine frequency of a subset with less than two dates ({self.dates}).") return frequency_to_timedelta(dates[1].astype(object) - dates[0].astype(object)) def source(self, index): diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index 09857ad2..9e3a10b6 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. +import base64 import logging import numpy as np @@ -328,6 +329,56 @@ def outline(lats, lons, neighbours=5): return outside +def deserialise_mask(encoded): + import pickle + import zlib + + packed = pickle.loads(zlib.decompress(base64.b64decode(encoded))) + + mask = [] + value = False + for count in packed: + mask.extend([value] * count) + value = not value + return np.array(mask, dtype=bool) + + +def _serialise_mask(mask): + import pickle + import zlib + + assert len(mask.shape) == 1 + assert len(mask) + + packed = [] + last = mask[0] + count = 1 + + for value in mask[1:]: + if value == last: + count += 1 + else: + packed.append(count) + last = value + count = 1 + + packed.append(count) + + # We always start with an 'off' value + # So if the first value is 'on', we need to add a zero + if mask[0]: + packed.insert(0, 0) + + return base64.b64encode(zlib.compress(pickle.dumps(packed))).decode("utf-8") + + +def serialise_mask(mask): + result = _serialise_mask(mask) + # Make sure we can deserialise it + assert np.all(mask == deserialise_mask(result)) + return result + + if __name__ == "__main__": global_lats, global_lons = np.meshgrid( np.linspace(90, -90, 90),