From 87a7b9776e780522e67729443f779335b755bd39 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Thu, 14 Nov 2024 11:07:09 +0000 Subject: [PATCH] Feature/new checkpoints (#107) * add masks * save masks to checkpoint * name supporting_arrays * better support for cutout * force np.datetime64 is seconds --------- Co-authored-by: Florian Pinault --- .../datasets/create/functions/__init__.py | 2 + .../datasets/create/functions/sources/mars.py | 86 +++++++++++++++++-- src/anemoi/datasets/data/__init__.py | 1 + src/anemoi/datasets/data/dataset.py | 5 +- src/anemoi/datasets/data/merge.py | 12 ++- src/anemoi/datasets/data/stores.py | 3 +- 6 files changed, 95 insertions(+), 14 deletions(-) diff --git a/src/anemoi/datasets/create/functions/__init__.py b/src/anemoi/datasets/create/functions/__init__.py index af894893..fce2ad86 100644 --- a/src/anemoi/datasets/create/functions/__init__.py +++ b/src/anemoi/datasets/create/functions/__init__.py @@ -21,6 +21,8 @@ def assert_is_fieldlist(obj): def import_function(name, kind): + from anemoi.transforms import Transform as Transform + name = name.replace("-", "_") plugins = {} diff --git a/src/anemoi/datasets/create/functions/sources/mars.py b/src/anemoi/datasets/create/functions/sources/mars.py index 703c3708..ef72888c 100644 --- a/src/anemoi/datasets/create/functions/sources/mars.py +++ b/src/anemoi/datasets/create/functions/sources/mars.py @@ -8,6 +8,7 @@ # nor does it submit to any jurisdiction. import datetime +import re from anemoi.utils.humanize import did_you_mean from earthkit.data import from_source @@ -32,6 +33,25 @@ def _date_to_datetime(d): return datetime.datetime.fromisoformat(d) +def expand_to_by(x): + + if isinstance(x, (str, int)): + return expand_to_by(str(x).split("/")) + + if len(x) == 3 and x[1] == "to": + start = int(x[0]) + end = int(x[2]) + return list(range(start, end + 1)) + + if len(x) == 5 and x[1] == "to" and x[3] == "by": + start = int(x[0]) + end = int(x[2]) + by = int(x[4]) + return list(range(start, end + 1, by)) + + return x + + def normalise_time_delta(t): if isinstance(t, datetime.timedelta): assert t == datetime.timedelta(hours=t.hours), t @@ -43,25 +63,48 @@ def normalise_time_delta(t): return t +def _normalise_time(t): + t = int(t) + if t < 100: + t * 100 + return "{:04d}".format(t) + + def _expand_mars_request(request, date, request_already_using_valid_datetime=False, date_key="date"): requests = [] - step = to_list(request.get("step", [0])) - for s in step: + + user_step = to_list(expand_to_by(request.get("step", [0]))) + user_time = None + user_date = None + + if not request_already_using_valid_datetime: + user_time = request.get("time") + if user_time is not None: + user_time = to_list(user_time) + user_time = [_normalise_time(t) for t in user_time] + + user_date = request.get(date_key) + if user_date is not None: + assert isinstance(user_date, str), user_date + user_date = re.compile("^{}$".format(user_date.replace("-", "").replace("?", "."))) + + for step in user_step: r = request.copy() if not request_already_using_valid_datetime: - if isinstance(s, str) and "-" in s: - assert s.count("-") == 1, s + if isinstance(step, str) and "-" in step: + assert step.count("-") == 1, step + # this takes care of the cases where the step is a period such as 0-24 or 12-24 - hours = int(str(s).split("-")[-1]) + hours = int(str(step).split("-")[-1]) base = date - datetime.timedelta(hours=hours) r.update( { date_key: base.strftime("%Y%m%d"), "time": base.strftime("%H%M"), - "step": s, + "step": step, } ) @@ -70,12 +113,28 @@ def _expand_mars_request(request, date, request_already_using_valid_datetime=Fal if isinstance(r[pproc], (list, tuple)): r[pproc] = "/".join(str(x) for x in r[pproc]) + if user_date is not None: + if not user_date.match(r[date_key]): + continue + + if user_time is not None: + # It time is provided by the user, we only keep the requests that match the time + if r["time"] not in user_time: + continue + requests.append(r) + # assert requests, requests + return requests -def factorise_requests(dates, *requests, request_already_using_valid_datetime=False, date_key="date"): +def factorise_requests( + dates, + *requests, + request_already_using_valid_datetime=False, + date_key="date", +): updates = [] for req in requests: # req = normalise_request(req) @@ -88,6 +147,9 @@ def factorise_requests(dates, *requests, request_already_using_valid_datetime=Fa date_key=date_key, ) + if not updates: + return + compressed = Availability(updates) for r in compressed.iterate(): for k, v in r.items(): @@ -178,7 +240,15 @@ def use_grib_paramid(r): ] -def mars(context, dates, *requests, request_already_using_valid_datetime=False, date_key="date", **kwargs): +def mars( + context, + dates, + *requests, + request_already_using_valid_datetime=False, + date_key="date", + **kwargs, +): + if not requests: requests = [kwargs] diff --git a/src/anemoi/datasets/data/__init__.py b/src/anemoi/datasets/data/__init__.py index 8c0503ce..1b2a26c6 100644 --- a/src/anemoi/datasets/data/__init__.py +++ b/src/anemoi/datasets/data/__init__.py @@ -49,6 +49,7 @@ def _convert(x): def open_dataset(*args, **kwargs): # That will get rid of OmegaConf objects + args, kwargs = _convert(args), _convert(kwargs) ds = _open_dataset(*args, **kwargs) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index f9dda392..8ab06757 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -11,7 +11,6 @@ import datetime import json import logging -import os import pprint import warnings from functools import cached_property @@ -28,8 +27,6 @@ def _tidy(v): return [_tidy(i) for i in v] if isinstance(v, dict): return {k: _tidy(v) for k, v in v.items()} - if isinstance(v, str) and v.startswith("/"): - return os.path.basename(v) if isinstance(v, datetime.datetime): return v.isoformat() if isinstance(v, datetime.date): @@ -391,7 +388,7 @@ def _supporting_arrays_and_sources(self): # Arrays from the input sources for i, source in enumerate(self._input_sources()): - name = source.name if source.name is not None else i + name = source.name if source.name is not None else f"source{i}" src_arrays = source._supporting_arrays(name) source_to_arrays[id(source)] = sorted(src_arrays.keys()) diff --git a/src/anemoi/datasets/data/merge.py b/src/anemoi/datasets/data/merge.py index 31c08b27..ba038108 100644 --- a/src/anemoi/datasets/data/merge.py +++ b/src/anemoi/datasets/data/merge.py @@ -28,6 +28,13 @@ class Merge(Combined): + + # d0 d2 d4 d6 ... + # d1 d3 d5 d7 ... + + # gives + # d0 d1 d2 d3 ... + def __init__(self, datasets, allow_gaps_in_dates=False): super().__init__(datasets) @@ -71,7 +78,10 @@ def __init__(self, datasets, allow_gaps_in_dates=False): self._dates = np.array(_dates, dtype="datetime64[s]") self._indices = np.array(indices) - self._frequency = frequency.astype(object) + self._frequency = frequency # .astype(object) + + def __len__(self): + return len(self._dates) @property def dates(self): diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index f2c69ae1..c8340e6c 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -364,6 +364,7 @@ def metadata_specific(self): attrs=dict(self.z.attrs), chunks=self.chunks, dtype=str(self.dtype), + path=self.path, ) def source(self, index): @@ -396,7 +397,7 @@ def __init__(self, path): super().__init__(path) missing_dates = self.z.attrs.get("missing_dates", []) - missing_dates = set([np.datetime64(x) for x in missing_dates]) + missing_dates = set([np.datetime64(x, "s") for x in missing_dates]) self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates} self.missing = set(self.missing_to_dates)