Skip to content

Commit

Permalink
Feature/new checkpoints (#107)
Browse files Browse the repository at this point in the history
* add masks
* save masks to checkpoint
* name supporting_arrays
* better support for cutout
* force np.datetime64 is seconds
---------

Co-authored-by: Florian Pinault <[email protected]>
  • Loading branch information
b8raoult and floriankrb authored Nov 14, 2024
1 parent 4c0213e commit 87a7b97
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/anemoi/datasets/create/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def assert_is_fieldlist(obj):

def import_function(name, kind):

from anemoi.transforms import Transform as Transform

name = name.replace("-", "_")

plugins = {}
Expand Down
86 changes: 78 additions & 8 deletions src/anemoi/datasets/create/functions/sources/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# nor does it submit to any jurisdiction.

import datetime
import re

from anemoi.utils.humanize import did_you_mean
from earthkit.data import from_source
Expand All @@ -32,6 +33,25 @@ def _date_to_datetime(d):
return datetime.datetime.fromisoformat(d)


def expand_to_by(x):

if isinstance(x, (str, int)):
return expand_to_by(str(x).split("/"))

if len(x) == 3 and x[1] == "to":
start = int(x[0])
end = int(x[2])
return list(range(start, end + 1))

if len(x) == 5 and x[1] == "to" and x[3] == "by":
start = int(x[0])
end = int(x[2])
by = int(x[4])
return list(range(start, end + 1, by))

return x


def normalise_time_delta(t):
if isinstance(t, datetime.timedelta):
assert t == datetime.timedelta(hours=t.hours), t
Expand All @@ -43,25 +63,48 @@ def normalise_time_delta(t):
return t


def _normalise_time(t):
t = int(t)
if t < 100:
t * 100
return "{:04d}".format(t)


def _expand_mars_request(request, date, request_already_using_valid_datetime=False, date_key="date"):
requests = []
step = to_list(request.get("step", [0]))
for s in step:

user_step = to_list(expand_to_by(request.get("step", [0])))
user_time = None
user_date = None

if not request_already_using_valid_datetime:
user_time = request.get("time")
if user_time is not None:
user_time = to_list(user_time)
user_time = [_normalise_time(t) for t in user_time]

user_date = request.get(date_key)
if user_date is not None:
assert isinstance(user_date, str), user_date
user_date = re.compile("^{}$".format(user_date.replace("-", "").replace("?", ".")))

for step in user_step:
r = request.copy()

if not request_already_using_valid_datetime:

if isinstance(s, str) and "-" in s:
assert s.count("-") == 1, s
if isinstance(step, str) and "-" in step:
assert step.count("-") == 1, step

# this takes care of the cases where the step is a period such as 0-24 or 12-24
hours = int(str(s).split("-")[-1])
hours = int(str(step).split("-")[-1])

base = date - datetime.timedelta(hours=hours)
r.update(
{
date_key: base.strftime("%Y%m%d"),
"time": base.strftime("%H%M"),
"step": s,
"step": step,
}
)

Expand All @@ -70,12 +113,28 @@ def _expand_mars_request(request, date, request_already_using_valid_datetime=Fal
if isinstance(r[pproc], (list, tuple)):
r[pproc] = "/".join(str(x) for x in r[pproc])

if user_date is not None:
if not user_date.match(r[date_key]):
continue

if user_time is not None:
# It time is provided by the user, we only keep the requests that match the time
if r["time"] not in user_time:
continue

requests.append(r)

# assert requests, requests

return requests


def factorise_requests(dates, *requests, request_already_using_valid_datetime=False, date_key="date"):
def factorise_requests(
dates,
*requests,
request_already_using_valid_datetime=False,
date_key="date",
):
updates = []
for req in requests:
# req = normalise_request(req)
Expand All @@ -88,6 +147,9 @@ def factorise_requests(dates, *requests, request_already_using_valid_datetime=Fa
date_key=date_key,
)

if not updates:
return

compressed = Availability(updates)
for r in compressed.iterate():
for k, v in r.items():
Expand Down Expand Up @@ -178,7 +240,15 @@ def use_grib_paramid(r):
]


def mars(context, dates, *requests, request_already_using_valid_datetime=False, date_key="date", **kwargs):
def mars(
context,
dates,
*requests,
request_already_using_valid_datetime=False,
date_key="date",
**kwargs,
):

if not requests:
requests = [kwargs]

Expand Down
1 change: 1 addition & 0 deletions src/anemoi/datasets/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def _convert(x):
def open_dataset(*args, **kwargs):

# That will get rid of OmegaConf objects

args, kwargs = _convert(args), _convert(kwargs)

ds = _open_dataset(*args, **kwargs)
Expand Down
5 changes: 1 addition & 4 deletions src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import datetime
import json
import logging
import os
import pprint
import warnings
from functools import cached_property
Expand All @@ -28,8 +27,6 @@ def _tidy(v):
return [_tidy(i) for i in v]
if isinstance(v, dict):
return {k: _tidy(v) for k, v in v.items()}
if isinstance(v, str) and v.startswith("/"):
return os.path.basename(v)
if isinstance(v, datetime.datetime):
return v.isoformat()
if isinstance(v, datetime.date):
Expand Down Expand Up @@ -391,7 +388,7 @@ def _supporting_arrays_and_sources(self):

# Arrays from the input sources
for i, source in enumerate(self._input_sources()):
name = source.name if source.name is not None else i
name = source.name if source.name is not None else f"source{i}"
src_arrays = source._supporting_arrays(name)
source_to_arrays[id(source)] = sorted(src_arrays.keys())

Expand Down
12 changes: 11 additions & 1 deletion src/anemoi/datasets/data/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@


class Merge(Combined):

# d0 d2 d4 d6 ...
# d1 d3 d5 d7 ...

# gives
# d0 d1 d2 d3 ...

def __init__(self, datasets, allow_gaps_in_dates=False):
super().__init__(datasets)

Expand Down Expand Up @@ -71,7 +78,10 @@ def __init__(self, datasets, allow_gaps_in_dates=False):

self._dates = np.array(_dates, dtype="datetime64[s]")
self._indices = np.array(indices)
self._frequency = frequency.astype(object)
self._frequency = frequency # .astype(object)

def __len__(self):
return len(self._dates)

@property
def dates(self):
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/datasets/data/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ def metadata_specific(self):
attrs=dict(self.z.attrs),
chunks=self.chunks,
dtype=str(self.dtype),
path=self.path,
)

def source(self, index):
Expand Down Expand Up @@ -396,7 +397,7 @@ def __init__(self, path):
super().__init__(path)

missing_dates = self.z.attrs.get("missing_dates", [])
missing_dates = set([np.datetime64(x) for x in missing_dates])
missing_dates = set([np.datetime64(x, "s") for x in missing_dates])
self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates}
self.missing = set(self.missing_to_dates)

Expand Down

0 comments on commit 87a7b97

Please sign in to comment.