Skip to content

Commit 71a539b

Browse files
committed
wip
1 parent 771629b commit 71a539b

File tree

5 files changed

+58
-15
lines changed

5 files changed

+58
-15
lines changed

src/anemoi/datasets/create/sources/accumulate.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
LOG = logging.getLogger(__name__)
3131

3232

33-
def _prep_request(request: dict[str, Any], interval_class: type[itv.IntervalsCollection]) -> dict[str, Any]:
33+
def _prep_request(request: dict[str, Any]) -> dict[str, Any]:
3434
request = deepcopy(request)
3535

3636
param = request.pop("param")
@@ -61,7 +61,7 @@ class Accumulator:
6161

6262
def __init__(
6363
self,
64-
interval_class: type[itv.IntervalsCollection],
64+
source_request: dict[str, Any],
6565
valid_date: datetime.datetime,
6666
user_accumulation_period: datetime.timedelta,
6767
data_accumulation_period: datetime.timedelta,
@@ -98,9 +98,10 @@ def __init__(
9898
self.key = {k: v for k, v in kwargs.items() if k in ["param", "level", "levelist", "number"]}
9999

100100
# instantiate IntervalsCollection object
101-
#if interval_class != itv.DefaultIntervalsCollection:
101+
# if interval_class != itv.DefaultIntervalsCollection:
102102
# LOG.warning("Non-default data IntervalsCollection (e.g MARS): ignoring data_accumulation_period")
103103
# data_accumulation_period = frequency_to_timedelta("1h") # only to ensure compatibility
104+
interval_class = itv.find_IntervalsCollection_class(source_request)
104105
self.interval_coll = interval_class(
105106
self.valid_date, user_accumulation_period, data_accumulation_period, **kwargs
106107
)
@@ -239,10 +240,9 @@ def _compute_accumulations(
239240
240241
"""
241242

242-
# interval_coll class depends on data source ; split between Era-like interval collections and Default ones
243-
interval_class = itv.find_IntervalsCollection_class(source_request)
243+
print("💬 source_request:", source_request)
244244

245-
main_request, param, number, additional = _prep_request(source_request, interval_class)
245+
main_request, param, number, additional = _prep_request(source_request)
246246

247247
# building accumulators
248248
accumulators = []
@@ -255,7 +255,7 @@ def _compute_accumulations(
255255
for n in number:
256256
accumulators.append(
257257
Accumulator(
258-
interval_class,
258+
source_request,
259259
valid_date,
260260
user_accumulation_period=user_accumulation_period,
261261
data_accumulation_period=data_accumulation_period,
@@ -274,7 +274,9 @@ def _compute_accumulations(
274274
for r in a.requests:
275275
r = {**main_request, **r}
276276
requests.append(r)
277+
print(f"💬 Accumulator {a} needs request: {r}")
277278

279+
print(f"💬 Creating source '{source_name}' with {len(requests)} requests for accumulation")
278280
source = context.create_source(
279281
{
280282
source_name: dict(
@@ -330,7 +332,7 @@ def _execute(
330332
dates: list[datetime.datetime],
331333
source: Any,
332334
period,
333-
data_accumulation_period="1h",
335+
data_accumulation_period=None,
334336
) -> Any:
335337
"""Accumulation source callable function.
336338
Read the recipe for accumulation in the request dictionary, check main arguments and call computation.
@@ -354,7 +356,9 @@ def _execute(
354356
if "accumulation_period" in source:
355357
raise ValueError("'accumulation_period' should be define outside source for accumulate action as 'period'")
356358
user_accumulation_period = frequency_to_timedelta(period)
357-
data_accumulation_period = frequency_to_timedelta(data_accumulation_period)
359+
data_accumulation_period = (
360+
frequency_to_timedelta(data_accumulation_period) if data_accumulation_period is not None else None
361+
)
358362

359363
source_request = source
360364

@@ -363,10 +367,6 @@ def _execute(
363367

364368
source_name, source_request = next(iter(source.items()))
365369
source_request = source_request.copy()
366-
assert "param" in source_request, (
367-
"param should be defined inside source for accumulate action",
368-
source_request,
369-
)
370370

371371
return _compute_accumulations(
372372
context,

src/anemoi/datasets/create/sources/accumulation_utils/intervals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def find_matching_interval(self, field: Any) -> Interval | None:
196196
def todo(self):
197197
if self._todo is None:
198198
self._todo = set([p.time_request for p in self._intervals])
199-
self._len = len(keys)
199+
self._len = len(self._todo)
200200
self._done = set()
201201
assert self._len == len(self._todo), (self._len, len(self._todo))
202202

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
dates:
2+
frequency: 6h
3+
start: 2021-01-10 18:00:00
4+
end: 2021-01-12 12:00:00
5+
6+
# https://apps.ecmwf.int/mars-catalogue/?stream=enda&levtype=sfc&expver=1&month=aug&year=2020&date=2020-08-16&type=fc&class=ea
7+
# time : 6 18
8+
# step : 0/to/18/by/3
9+
10+
input:
11+
accumulate:
12+
period: 6h
13+
data_accumulation_period: 1h
14+
source:
15+
grib:
16+
path: accumulate-grib-index-meteo-france/input.grib
17+
18+
19+
checks:
20+
none: {}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
dates:
2+
frequency: 6h
3+
start: 2021-01-02 03:00:00
4+
end: 2021-01-02 12:00:00
5+
6+
# https://apps.ecmwf.int/mars-catalogue/?stream=enda&levtype=sfc&expver=1&month=aug&year=2020&date=2020-08-16&type=fc&class=ea
7+
# time : 6 18
8+
# step : 0/to/18/by/3
9+
10+
input:
11+
accumulate:
12+
period: 6h
13+
data_accumulation_period: 1h
14+
source:
15+
grib-index:
16+
indexdb: index.db
17+
levtype: sfc
18+
param: [ tp ]
19+
20+
21+
checks:
22+
none: {}

tests/create/test_sources.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,13 @@ def test_accumulate_grib_index(get_test_data: callable) -> None:
149149
"pipe": [
150150
{
151151
"accumulate": {
152+
"period": 24,
153+
"data_accumulation_period": "1h",
152154
"source": {
153155
"grib-index": {
154156
"indexdb": os.path.join(path_db, "grib-index-accumulate-tp.db"),
155157
"levtype": "sfc",
156158
"param": ["tp"],
157-
"accumulation_period": 3,
158159
},
159160
},
160161
}

0 commit comments

Comments
 (0)