Skip to content

Commit 781e328

Browse files
committed
fix ear5 accumulate. + breaking change in the recipe for accumlate: ...
1 parent 4c42d0f commit 781e328

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

src/anemoi/datasets/create/sources/accumulate.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def __init__(
101101
if interval_class != itv.DefaultIntervalsCollection:
102102
LOG.warning("Non-default data IntervalsCollection (e.g MARS): ignoring data_accumulation_period")
103103
data_accumulation_period = frequency_to_timedelta("1h") # only to ensure compatibility
104-
self.interval_coll = interval_class(self.valid_date, user_accumulation_period, data_accumulation_period, **kwargs)
104+
self.interval_coll = interval_class(
105+
self.valid_date, user_accumulation_period, data_accumulation_period, **kwargs
106+
)
105107

106108
@property
107109
def requests(self) -> dict:
@@ -270,6 +272,7 @@ def _compute_accumulations(
270272
requests = []
271273
for a in accumulators:
272274
for r in a.requests:
275+
r = {**main_request, **r}
273276
requests.append(r)
274277

275278
source = context.create_source(
@@ -322,7 +325,13 @@ def _compute_accumulations(
322325
class Accumulations2Source(LegacySource):
323326

324327
@staticmethod
325-
def _execute(context: Any, dates: list[datetime.datetime], source: Any) -> Any:
328+
def _execute(
329+
context: Any,
330+
dates: list[datetime.datetime],
331+
source: Any,
332+
accumulation_period="1h",
333+
data_accumulation_period="1h",
334+
) -> Any:
326335
"""Accumulation source callable function.
327336
Read the recipe for accumulation in the request dictionary, check main arguments and call computation.
328337
@@ -342,21 +351,28 @@ def _execute(context: Any, dates: list[datetime.datetime], source: Any) -> Any:
342351
The accumulated data source.
343352
344353
"""
354+
if "accumulation_period" in source:
355+
raise ValueError("'accumulation_period' should be define outside source for accumulate action")
356+
user_accumulation_period = frequency_to_timedelta(accumulation_period)
357+
data_accumulation_period = frequency_to_timedelta(data_accumulation_period)
358+
359+
source_request = source
345360

346361
assert isinstance(source, dict)
347362
assert len(source) == 1
348-
assert "param" not in source, "param should be defined inside source for accumulate action"
349363

350364
source_name, source_request = next(iter(source.items()))
351365
source_request = source_request.copy()
352-
353-
assert "accumulation_period" in source_request, "'accumulation_period' keyword necessary"
366+
assert "param" in source_request, (
367+
"param should be defined inside source for accumulate action",
368+
source_request,
369+
)
354370

355371
return _compute_accumulations(
356372
context,
357373
dates,
358374
source_name,
359375
source_request,
360-
user_accumulation_period=frequency_to_timedelta(source_request.pop("accumulation_period")),
361-
data_accumulation_period=frequency_to_timedelta(source_request.get("data_accumulation_period", "1h")),
376+
user_accumulation_period=user_accumulation_period,
377+
data_accumulation_period=data_accumulation_period,
362378
)

tests/create/accumulation_1.yaml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@ dates:
55

66
input:
77
accumulate:
8-
mars:
9-
expver: "0001"
10-
class: ea
8+
accumulation_period: 6h
9+
source:
10+
mars:
11+
expver: "0001"
12+
class: ea
13+
#stream: oper
1114

12-
stream: enda
15+
#stream: enda
1316

14-
grid: 20./20.
15-
levtype: sfc
16-
param: [ tp, cp ]
17-
accumulation_period: 24
17+
grid: 20./20.
18+
levtype: sfc
19+
param: [ tp, cp ]
1820

1921
checks:
2022
none: {}

0 commit comments

Comments
 (0)