Commit 94a89e0

Feature/masks (#104)

Authored by b8raoult and floriankrb

* add masks

Co-authored-by: Florian Pinault <[email protected]>

1 parent 3620e8d · commit 94a89e0
File tree

10 files changed: +247 -25 lines changed

CHANGELOG.md (+2)

@@ -25,6 +25,8 @@ Keep it human-readable, your future self will thank you!
 - Various bug fixes
 - Control compatibility check in xy/zip
 - Add `merge` feature
+- Add support for storing `supporting_arrays` in checkpoint files
+- Allow naming of datasets components
 - Contributors file (#105)

 ### Changed

src/anemoi/datasets/data/__init__.py (+23)

@@ -25,7 +25,30 @@ class MissingDateError(Exception):
     pass


+def _convert(x):
+
+    if isinstance(x, list):
+        return [_convert(a) for a in x]
+
+    if isinstance(x, tuple):
+        return tuple(_convert(a) for a in x)
+
+    if isinstance(x, dict):
+        return {k: _convert(v) for k, v in x.items()}
+
+    if x.__class__.__name__ in ("DictConfig", "ListConfig"):
+        from omegaconf import OmegaConf
+
+        return OmegaConf.to_container(x, resolve=True)
+
+    return x
+
+
 def open_dataset(*args, **kwargs):
+
+    # That will get rid of OmegaConf objects
+    args, kwargs = _convert(args), _convert(kwargs)
+
     ds = _open_dataset(*args, **kwargs)
     ds = ds.mutate()
     ds.arguments = {"args": args, "kwargs": kwargs}
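
The conversion matters when callers build their dataset specification with OmegaConf. A minimal usage sketch, assuming an installed anemoi-datasets; the store path and the selected variables are placeholders, not part of this commit:

# Hypothetical usage: the path and variable names are placeholders.
from omegaconf import OmegaConf

from anemoi.datasets import open_dataset

cfg = OmegaConf.create({"select": ["2t", "10u", "10v"]})

# cfg.select is a ListConfig; _convert() turns it into a plain list before the
# arguments are stored, so Dataset.metadata() stays JSON-serializable.
ds = open_dataset("/path/to/dataset.zarr", select=cfg.select)
assert isinstance(ds.arguments["kwargs"]["select"], list)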

src/anemoi/datasets/data/dataset.py (+115 -24)

@@ -23,8 +23,34 @@
 LOG = logging.getLogger(__name__)


+def _tidy(v):
+    if isinstance(v, (list, tuple, set)):
+        return [_tidy(i) for i in v]
+    if isinstance(v, dict):
+        return {k: _tidy(v) for k, v in v.items()}
+    if isinstance(v, str) and v.startswith("/"):
+        return os.path.basename(v)
+    if isinstance(v, datetime.datetime):
+        return v.isoformat()
+    if isinstance(v, datetime.date):
+        return v.isoformat()
+    if isinstance(v, datetime.timedelta):
+        return frequency_to_string(v)
+
+    if isinstance(v, Dataset):
+        # That can happen in the `arguments`
+        # if a dataset is passed as an argument
+        return repr(v)
+
+    if isinstance(v, slice):
+        return (v.start, v.stop, v.step)
+
+    return v
+
+
 class Dataset:
     arguments = {}
+    _name = None

     def mutate(self) -> "Dataset":
         """Give an opportunity to a subclass to return a new Dataset

@@ -41,6 +67,21 @@ def _len(self):
         return len(self)

     def _subset(self, **kwargs):
+
+        if not kwargs:
+            return self.mutate()
+
+        name = kwargs.pop("name", None)
+        result = self.__subset(**kwargs)
+        result._name = name
+
+        return result
+
+    @property
+    def name(self):
+        return self._name
+
+    def __subset(self, **kwargs):
         if not kwargs:
             return self.mutate()

@@ -254,41 +295,32 @@ def typed_variables(self):

         return result

+    def _input_sources(self):
+        sources = []
+        self.collect_input_sources(sources)
+        return sources
+
     def metadata(self):
         import anemoi

-        def tidy(v):
-            if isinstance(v, (list, tuple, set)):
-                return [tidy(i) for i in v]
-            if isinstance(v, dict):
-                return {k: tidy(v) for k, v in v.items()}
-            if isinstance(v, str) and v.startswith("/"):
-                return os.path.basename(v)
-            if isinstance(v, datetime.datetime):
-                return v.isoformat()
-            if isinstance(v, datetime.date):
-                return v.isoformat()
-            if isinstance(v, datetime.timedelta):
-                return frequency_to_string(v)
-
-            if isinstance(v, Dataset):
-                # That can happen in the `arguments`
-                # if a dataset is passed as an argument
-                return repr(v)
-
-            if isinstance(v, slice):
-                return (v.start, v.stop, v.step)
-
-            return v
+        _, source_to_arrays = self._supporting_arrays_and_sources()
+
+        sources = []
+        for i, source in enumerate(self._input_sources()):
+            source_metadata = source.dataset_metadata().copy()
+            source_metadata["supporting_arrays"] = source_to_arrays[id(source)]
+            sources.append(source_metadata)

         md = dict(
             version=anemoi.datasets.__version__,
             arguments=self.arguments,
             **self.dataset_metadata(),
+            sources=sources,
+            supporting_arrays=source_to_arrays[id(self)],
         )

         try:
-            return json.loads(json.dumps(tidy(md)))
+            return json.loads(json.dumps(_tidy(md)))
         except Exception:
             LOG.exception("Failed to serialize metadata")
             pprint.pprint(md)

@@ -313,8 +345,67 @@ def dataset_metadata(self):
             dtype=str(self.dtype),
             start_date=self.start_date.astype(str),
             end_date=self.end_date.astype(str),
+            name=self.name,
         )

+    def _supporting_arrays(self, *path):
+
+        import numpy as np
+
+        def _path(path, name):
+            return "/".join(str(_) for _ in [*path, name])
+
+        result = {
+            _path(path, "latitudes"): self.latitudes,
+            _path(path, "longitudes"): self.longitudes,
+        }
+        collected = []
+
+        self.collect_supporting_arrays(collected, *path)
+
+        for path, name, array in collected:
+            assert isinstance(path, tuple) and isinstance(name, str)
+            assert isinstance(array, np.ndarray)
+
+            name = _path(path, name)
+
+            if name in result:
+                raise ValueError(f"Duplicate key {name}")
+
+            result[name] = array
+
+        return result
+
+    def supporting_arrays(self):
+        """Arrays to be saved in the checkpoints"""
+        arrays, _ = self._supporting_arrays_and_sources()
+        return arrays
+
+    def _supporting_arrays_and_sources(self):
+
+        source_to_arrays = {}
+
+        # Top levels arrays
+        result = self._supporting_arrays()
+        source_to_arrays[id(self)] = sorted(result.keys())
+
+        # Arrays from the input sources
+        for i, source in enumerate(self._input_sources()):
+            name = source.name if source.name is not None else i
+            src_arrays = source._supporting_arrays(name)
+            source_to_arrays[id(source)] = sorted(src_arrays.keys())
+
+            for k in src_arrays:
+                assert k not in result
+
+            result.update(src_arrays)
+
+        return result, source_to_arrays
+
+    def collect_supporting_arrays(self, collected, *path):
+        # Override this method to add more arrays
+        pass
+
     def metadata_specific(self, **kwargs):
         action = self.__class__.__name__.lower()
         # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
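
Taken together, these methods let a training script pull out the arrays that should travel with a checkpoint. A hedged sketch of the intended call pattern, assuming the `name` keyword reaches `_subset` as shown above; the path is a placeholder, and `sources` is simply empty for a single Zarr store:

# Illustrative only: the path is a placeholder and the exact keys depend on
# how the dataset is combined.
import numpy as np

from anemoi.datasets import open_dataset

ds = open_dataset("/path/to/dataset.zarr", name="era5")

arrays = ds.supporting_arrays()  # keyed by "<source path>/<name>", e.g. "latitudes"
assert isinstance(arrays["latitudes"], np.ndarray)
assert isinstance(arrays["longitudes"], np.ndarray)

md = ds.metadata()
# Top-level keys plus, per input source, the keys that source contributed.
print(md["supporting_arrays"], [s["supporting_arrays"] for s in md["sources"]])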

src/anemoi/datasets/data/forwards.py (+19)

@@ -9,6 +9,7 @@


 import logging
+import warnings
 from functools import cached_property

 import numpy as np

@@ -34,6 +35,12 @@ def __len__(self):
     def __getitem__(self, n):
         return self.forward[n]

+    @property
+    def name(self):
+        if self._name is not None:
+            return self._name
+        return self.forward.name
+
     @property
     def dates(self):
         return self.forward.dates

@@ -102,6 +109,12 @@ def metadata_specific(self, **kwargs):
             **kwargs,
         )

+    def collect_supporting_arrays(self, collected, *path):
+        self.forward.collect_supporting_arrays(collected, *path)
+
+    def collect_input_sources(self, collected):
+        self.forward.collect_input_sources(collected)
+
     def source(self, index):
         return self.forward.source(index)

@@ -197,6 +210,12 @@ def metadata_specific(self, **kwargs):
             **kwargs,
         )

+    def collect_supporting_arrays(self, collected, *path):
+        warnings.warn(f"The behaviour of {self.__class__.__name__}.collect_supporting_arrays() is not well defined")
+        for i, d in enumerate(self.datasets):
+            name = d.name if d.name is not None else i
+            d.collect_supporting_arrays(collected, *path, name)
+
     @property
     def missing(self):
         raise NotImplementedError("missing() not implemented for Combined")
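
The pattern here is simple: a single-forward wrapper delegates collection unchanged, while a combiner prefixes each child with its name, or with its index when unnamed. A self-contained sketch with hypothetical stand-in classes, not the library's own:

# Stand-in classes for illustration only; the real ones are Forwards and Combined.
import numpy as np


class Leaf:
    name = None

    def collect_supporting_arrays(self, collected, *path):
        collected.append((path, "mask", np.ones(3, dtype=bool)))


class Wrapper:  # mirrors Forwards: delegate with the same path
    def __init__(self, forward):
        self.forward = forward
        self.name = forward.name

    def collect_supporting_arrays(self, collected, *path):
        self.forward.collect_supporting_arrays(collected, *path)


class Combiner:  # mirrors Combined: namespace each child by name or index
    def __init__(self, datasets):
        self.datasets = datasets

    def collect_supporting_arrays(self, collected, *path):
        for i, d in enumerate(self.datasets):
            name = d.name if d.name is not None else i
            d.collect_supporting_arrays(collected, *path, name)


collected = []
Combiner([Wrapper(Leaf()), Wrapper(Leaf())]).collect_supporting_arrays(collected)
print([(p, k) for p, k, _ in collected])  # [((0,), 'mask'), ((1,), 'mask')]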

src/anemoi/datasets/data/grids.py (+20 -1)

@@ -108,6 +108,17 @@ def check_same_resolution(self, d1, d2):
         # We don't check the resolution, because we want to be able to combine
         pass

+    def metadata_specific(self):
+        return super().metadata_specific(
+            multi_grids=True,
+        )
+
+    def collect_input_sources(self, collected):
+        # We assume that,because they have different grids, they have different input sources
+        for d in self.datasets:
+            collected.append(d)
+            d.collect_input_sources(collected)
+

 class Grids(GridsBase):
     # TODO: select the statistics of the most global grid?

@@ -157,6 +168,9 @@ def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0,
             self.globe.shape[3],
         )

+    def collect_supporting_arrays(self, collected, *path):
+        collected.append((path, "cutout_mask", self.mask))
+
     @cached_property
     def shape(self):
         shape = self.lam.shape

@@ -212,6 +226,11 @@ def grids(self):
     def tree(self):
         return Node(self, [d.tree() for d in self.datasets])

+    # def metadata_specific(self):
+    #     return super().metadata_specific(
+    #         mask=serialise_mask(self.mask),
+    #     )
+

 def grids_factory(args, kwargs):
     if "ensemble" in kwargs:

@@ -241,7 +260,7 @@ def cutout_factory(args, kwargs):
     neighbours = kwargs.pop("neighbours", 5)

     assert len(args) == 0
-    assert isinstance(cutout, (list, tuple))
+    assert isinstance(cutout, (list, tuple)), "cutout must be a list or tuple"

     datasets = [_open(e) for e in cutout]
     datasets, kwargs = _auto_adjust(datasets, kwargs)
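
With this change the cutout mask travels with the other supporting arrays. A hedged sketch of opening a cutout combination; the paths are placeholders, and the exact key depends on where the Cutout node sits in the dataset tree:

# Placeholder paths; assumes two compatible stores as in the cutout documentation.
from anemoi.datasets import open_dataset

ds = open_dataset(cutout=["/path/to/lam.zarr", "/path/to/global.zarr"])

arrays = ds.supporting_arrays()
# Besides latitudes/longitudes, the boolean cutout mask is now included,
# under a key such as "cutout_mask" at the top level.
print([key for key in arrays if key.endswith("cutout_mask")])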

src/anemoi/datasets/data/masked.py (+8)

@@ -33,6 +33,8 @@ def __init__(self, forward, mask):
         self.mask = mask
         self.axis = 3

+        self.mask_name = f"{self.__class__.__name__.lower()}_mask"
+
     @cached_property
     def shape(self):
         return self.forward.shape[:-1] + (np.count_nonzero(self.mask),)

@@ -67,8 +69,13 @@ def _get_tuple(self, index):
         result = apply_index_to_slices_changes(result, changes)
         return result

+    def collect_supporting_arrays(self, collected, *path):
+        super().collect_supporting_arrays(collected, *path)
+        collected.append((path, self.mask_name, self.mask))
+

 class Thinning(Masked):
+
     def __init__(self, forward, thinning, method):
         self.thinning = thinning
         self.method = method

@@ -110,6 +117,7 @@ def subclass_metadata_specific(self):


 class Cropping(Masked):
+
     def __init__(self, forward, area):
         from ..data import open_dataset
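
Because the key is derived from the class name, each Masked subclass contributes a distinctly named array. A small sketch, assuming the existing `thinning` option behaves as usual; the path is a placeholder:

# Placeholder path; the thinning option subsamples the grid (see the Thinning class).
from anemoi.datasets import open_dataset

ds = open_dataset("/path/to/dataset.zarr", thinning=4)

collected = []
ds.collect_supporting_arrays(collected)
# Thinning contributes "thinning_mask"; a Cropping dataset would contribute
# "cropping_mask" the same way.
print([key for _, key, _ in collected])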

src/anemoi/datasets/data/misc.py (+1)

@@ -270,6 +270,7 @@ def _auto_adjust(datasets, kwargs):


 def _open_dataset(*args, **kwargs):
+
     sets = []
     for a in args:
         sets.append(_open(a))

src/anemoi/datasets/data/stores.py (+6)

@@ -344,6 +344,12 @@ def get_dataset_names(self, names):
         name, _ = os.path.splitext(os.path.basename(self.path))
         names.add(name)

+    def collect_supporting_arrays(self, collected, *path):
+        pass
+
+    def collect_input_sources(self, collected):
+        pass
+

 class ZarrWithMissingDates(Zarr):
     """A zarr dataset with missing dates."""

src/anemoi/datasets/data/subset.py (+2)

@@ -135,6 +135,8 @@ def dates(self):
     @cached_property
     def frequency(self):
         dates = self.dates
+        if len(dates) < 2:
+            raise ValueError(f"Cannot determine frequency of a subset with less than two dates ({self.dates}).")
         return frequency_to_timedelta(dates[1].astype(object) - dates[0].astype(object))

     def source(self, index):
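
The new guard protects the frequency computation, which infers the time step from the gap between the first two dates. A small numpy-only illustration of that arithmetic, not the library's API:

# Standalone arithmetic check, independent of anemoi-datasets.
import datetime

import numpy as np

dates = np.array(["2020-01-01T00:00", "2020-01-01T06:00"], dtype="datetime64[s]")

if len(dates) < 2:
    raise ValueError("Cannot determine frequency of a subset with less than two dates.")

# astype(object) converts numpy datetime64 values to datetime.datetime,
# so the difference is a plain datetime.timedelta.
frequency = dates[1].astype(object) - dates[0].astype(object)
assert frequency == datetime.timedelta(hours=6)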
