Skip to content

Commit 6cccf2e

Browse files
authored
Feature/xy zip (#85)
Update xy/zip
1 parent 90ccc40 commit 6cccf2e

File tree

5 files changed

+179
-4
lines changed

5 files changed

+179
-4
lines changed

Diff for: CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Keep it human-readable, your future self will thank you!
1414
### Added
1515

1616
- Add anemoi-transform link to documentation
17+
- Control compatibility check in xy/zip
18+
- Add `merge` feature
1719

1820
### Changed
1921

Diff for: src/anemoi/datasets/data/concat.py

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def concat_factory(args, kwargs):
148148

149149
datasets = kwargs.pop("concat")
150150
fill_missing_gaps = kwargs.pop("fill_missing_gaps", False)
151+
151152
assert isinstance(datasets, (list, tuple))
152153
assert len(args) == 0
153154

Diff for: src/anemoi/datasets/data/merge.py

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2+
# This software is licensed under the terms of the Apache Licence Version 2.0
3+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4+
# In applying this licence, ECMWF does not waive the privileges and immunities
5+
# granted to it by virtue of its status as an intergovernmental organisation
6+
# nor does it submit to any jurisdiction.
7+
8+
import logging
9+
from functools import cached_property
10+
11+
import numpy as np
12+
13+
from . import MissingDateError
14+
from .debug import Node
15+
from .debug import debug_indexing
16+
from .forwards import Combined
17+
from .indexing import apply_index_to_slices_changes
18+
from .indexing import expand_list_indexing
19+
from .indexing import index_to_slices
20+
from .indexing import update_tuple
21+
from .misc import _auto_adjust
22+
from .misc import _open
23+
24+
LOG = logging.getLogger(__name__)
25+
26+
27+
class Merge(Combined):
    """Combine several datasets along the dates (first) axis.

    Each input dataset must cover a disjoint set of dates. The merged
    dataset exposes the union of all dates laid out on a regular grid
    whose frequency is the smallest gap between two consecutive dates.
    Dates not covered by any dataset are either reported as missing
    (``allow_gaps_in_dates=True``) or rejected with a ``ValueError``.
    """

    def __init__(self, datasets, allow_gaps_in_dates=False):
        """Build the date index mapping merged rows to (dataset, row) pairs.

        Parameters
        ----------
        datasets : sequence
            Datasets to merge; their dates must not overlap.
        allow_gaps_in_dates : bool
            If True, dates covered by no dataset are recorded as missing
            instead of raising.

        Raises
        ------
        ValueError
            If a date appears in two datasets, if fewer than two distinct
            dates are available, or if a gap is found while
            ``allow_gaps_in_dates`` is False.
        """
        super().__init__(datasets)

        self.allow_gaps_in_dates = allow_gaps_in_dates

        # Map each date to (dataset index, row index within that dataset).
        dates = {}

        for i, d in enumerate(datasets):
            for j, date in enumerate(d.dates):
                # numpy datetime64 -> datetime.datetime so that dates from
                # different datasets compare and subtract consistently.
                date = date.astype(object)
                if date in dates:
                    d1 = datasets[dates[date][0]]
                    d2 = datasets[i]
                    raise ValueError(f"Duplicate date {date} found in datasets {d1} and {d2}")
                dates[date] = (i, j)

        # Guard: an empty date set would otherwise surface as a bare
        # IndexError on all_dates[0] below.
        if not dates:
            raise ValueError("merge: no dates found in the given datasets")

        all_dates = sorted(dates)
        start = all_dates[0]
        end = all_dates[-1]

        # Guard: min() below needs at least one consecutive pair; with a
        # single date it would raise an obscure "min() arg is an empty
        # sequence" error.
        if len(all_dates) < 2:
            raise ValueError("merge: at least two distinct dates are required to infer the frequency")

        # The merged frequency is the smallest gap between consecutive dates.
        frequency = min(d2 - d1 for d1, d2 in zip(all_dates[:-1], all_dates[1:]))

        date = start
        indices = []
        _dates = []

        # Sentinel dataset index marking dates covered by no dataset.
        self._missing_index = len(datasets)

        while date <= end:
            if date not in dates:
                if self.allow_gaps_in_dates:
                    dates[date] = (self._missing_index, -1)
                else:
                    raise ValueError(
                        f"merge: date {date} not covered by dataset. Start={start}, end={end}, frequency={frequency}"
                    )

            indices.append(dates[date])
            _dates.append(date)
            date += frequency

        self._dates = np.array(_dates, dtype="datetime64[s]")
        self._indices = np.array(indices)
        self._frequency = frequency

    @property
    def dates(self):
        """Dates of the merged dataset as ``datetime64[s]``."""
        return self._dates

    @property
    def frequency(self):
        """Time step between two consecutive dates."""
        return self._frequency

    @cached_property
    def missing(self):
        """Set of merged row indices that are missing.

        A row is missing either because its date is covered by no dataset
        (gap) or because the owning dataset itself reports it missing.
        """
        # TODO: optimize
        result = set()

        for i, (dataset, row) in enumerate(self._indices):
            if dataset == self._missing_index:
                result.add(i)
                continue

            if row in self.datasets[dataset].missing:
                result.add(i)

        return result

    def check_same_lengths(self, d1, d2):
        # Turned off because we are concatenating along the first axis
        pass

    def check_same_dates(self, d1, d2):
        # Turned off because we are concatenating along the dates axis
        pass

    def check_compatibility(self, d1, d2):
        """Check non-date compatibility; shapes must match except on axis 0."""
        super().check_compatibility(d1, d2)
        self.check_same_sub_shapes(d1, d2, drop_axis=0)

    def tree(self):
        """Return the debug tree node for this combination."""
        return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates)

    @debug_indexing
    def __getitem__(self, n):
        """Return row *n*, dispatching tuple/slice indexing to helpers.

        Raises
        ------
        MissingDateError
            If row *n* corresponds to a missing date.
        """
        if isinstance(n, tuple):
            return self._get_tuple(n)

        if isinstance(n, slice):
            return self._get_slice(n)

        dataset, row = self._indices[n]

        if dataset == self._missing_index:
            raise MissingDateError(f"Date {self.dates[n]} is missing (index={n})")

        return self.datasets[dataset][int(row)]

    @debug_indexing
    @expand_list_indexing
    def _get_tuple(self, index):
        # Normalize the tuple index to slices, fetch the date-axis slice,
        # then re-apply the remaining axes on the stacked result.
        index, changes = index_to_slices(index, self.shape)
        index, previous = update_tuple(index, 0, slice(None))
        result = self._get_slice(previous)
        return apply_index_to_slices_changes(result[index], changes)

    def _get_slice(self, s):
        # Stack individual rows; delegates per-row lookup to __getitem__.
        return np.stack([self[i] for i in range(*s.indices(self._len))])
136+
137+
138+
def merge_factory(args, kwargs):
    """Build a merged dataset from the ``merge`` entry in *kwargs*.

    Pops ``merge`` (a list/tuple of dataset specs), opens each entry,
    and returns either the single dataset subsetted, or a ``Merge`` of
    all of them after auto-adjustment.
    """

    to_merge = kwargs.pop("merge")

    assert isinstance(to_merge, (list, tuple))
    assert len(args) == 0

    opened = [_open(entry) for entry in to_merge]

    # A single dataset needs no merging: subset it directly.
    if len(opened) == 1:
        return opened[0]._subset(**kwargs)

    opened, kwargs = _auto_adjust(opened, kwargs)

    allow_gaps_in_dates = kwargs.pop("allow_gaps_in_dates", False)

    return Merge(opened, allow_gaps_in_dates=allow_gaps_in_dates)._subset(**kwargs)

Diff for: src/anemoi/datasets/data/misc.py

+6
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,12 @@ def _open_dataset(*args, **kwargs):
302302
assert not sets, sets
303303
return concat_factory(args, kwargs).mutate()
304304

305+
if "merge" in kwargs:
306+
from .merge import merge_factory
307+
308+
assert not sets, sets
309+
return merge_factory(args, kwargs).mutate()
310+
305311
if "ensemble" in kwargs:
306312
from .ensemble import ensemble_factory
307313

Diff for: src/anemoi/datasets/data/xy.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,19 @@
1818

1919
class ZipBase(Combined):
2020

21+
def __init__(self, datasets, check_compatibility=True):
22+
self._check_compatibility = check_compatibility
23+
super().__init__(datasets)
24+
2125
def swap_with_parent(self, parent):
2226
new_parents = [parent.clone(ds) for ds in self.datasets]
2327
return self.clone(new_parents)
2428

2529
def clone(self, datasets):
26-
return self.__class__(datasets)
30+
return self.__class__(datasets, check_compatibility=self._check_compatibility)
2731

2832
def tree(self):
29-
return Node(self, [d.tree() for d in self.datasets])
33+
return Node(self, [d.tree() for d in self.datasets], check_compatibility=self._check_compatibility)
3034

3135
def __len__(self):
3236
return min(len(d) for d in self.datasets)
@@ -86,6 +90,10 @@ def resolution(self):
8690
def name_to_index(self):
8791
return tuple(d.name_to_index for d in self.datasets)
8892

93+
def check_compatibility(self, d1, d2):
94+
if self._check_compatibility:
95+
super().check_compatibility(d1, d2)
96+
8997

9098
class Zip(ZipBase):
9199
pass
@@ -110,7 +118,9 @@ def xy_factory(args, kwargs):
110118

111119
assert len(datasets) == 2
112120

113-
return XY(datasets)._subset(**kwargs)
121+
check_compatibility = kwargs.pop("check_compatibility", True)
122+
123+
return XY(datasets, check_compatibility=check_compatibility)._subset(**kwargs)
114124

115125

116126
def zip_factory(args, kwargs):
@@ -122,4 +132,6 @@ def zip_factory(args, kwargs):
122132
datasets = [_open(e) for e in zip]
123133
datasets, kwargs = _auto_adjust(datasets, kwargs)
124134

125-
return Zip(datasets)._subset(**kwargs)
135+
check_compatibility = kwargs.pop("check_compatibility", True)
136+
137+
return Zip(datasets, check_compatibility=check_compatibility)._subset(**kwargs)

0 commit comments

Comments
 (0)