Skip to content

Commit 9e454e0

Browse files
committed
feat: allow using json files to define datasets
1 parent 4d06077 commit 9e454e0

File tree

3 files changed

+66
-23
lines changed

3 files changed

+66
-23
lines changed

src/anemoi/datasets/data/misc.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
import calendar
1212
import datetime
13+
import json
1314
import logging
15+
import os
1416
from pathlib import PurePath
1517

1618
import numpy as np
@@ -195,7 +197,6 @@ def _concat_or_join(datasets, kwargs):
195197

196198
def _open(a):
197199
from .stores import Zarr
198-
from .stores import zarr_lookup
199200

200201
if isinstance(a, Dataset):
201202
return a.mutate()
@@ -204,7 +205,34 @@ def _open(a):
204205
return Zarr(a).mutate()
205206

206207
if isinstance(a, str):
207-
return Zarr(zarr_lookup(a)).mutate()
208+
from .stores import DATASET_FINDER
209+
210+
tried = []
211+
for name in DATASET_FINDER.ls(a):
212+
tried.append(name)
213+
214+
if name.endswith(".json"):
215+
DATASET_FINDER.log_open(a, name)
216+
if not os.path.exists(name):
217+
continue
218+
219+
obj = json.load(open(name))
220+
if isinstance(obj, dict):
221+
return _open_dataset(**obj).mutate()
222+
elif isinstance(obj, (list, tuple)):
223+
return _open_dataset(*obj).mutate()
224+
raise ValueError(f"Invalid content: {type(obj)} in {name}")
225+
226+
if name.endswith(".zarr") or name.endswith(".zip"):
227+
try:
228+
DATASET_FINDER.log_open(a, name)
229+
return Zarr(name).mutate()
230+
except zarr.errors.PathNotFoundError:
231+
pass
232+
233+
raise ValueError(f"Unsupported file: {name}")
234+
235+
raise ValueError(f"Cannot find a dataset that matched '{a}'. Tried: {tried}")
208236

209237
if isinstance(a, PurePath):
210238
return _open(str(a)).mutate()

src/anemoi/datasets/data/stores.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
import zarr
19+
from anemoi.utils.dates import frequency_to_string
1920
from anemoi.utils.dates import frequency_to_timedelta
2021

2122
from . import MissingDateError
@@ -289,8 +290,6 @@ def statistics_tendencies(self, delta=None):
289290
delta = self.frequency
290291
if isinstance(delta, int):
291292
delta = f"{delta}h"
292-
from anemoi.utils.dates import frequency_to_string
293-
from anemoi.utils.dates import frequency_to_timedelta
294293

295294
delta = frequency_to_timedelta(delta)
296295
delta = frequency_to_string(delta)
@@ -450,36 +449,47 @@ def label(self):
450449
return "zarr*"
451450

452451

453-
QUIET = set()
452+
class DatasetFinder:
453+
QUIET = set()
454454

455+
@cached_property
456+
def _config(self):
457+
return load_config()["datasets"]
455458

456-
def zarr_lookup(name, fail=True):
459+
def ls(self, name):
460+
if name in self._config["named"]:
461+
yield self._config["named"][name]
462+
return
463+
464+
if name.endswith(".zip") or name.endswith(".zarr") or name.endswith(".json"):
465+
yield name
466+
return
457467

458-
if name.endswith(".zarr") or name.endswith(".zip"):
459-
return name
468+
for location in self._config["path"]:
469+
if not location.endswith("/"):
470+
location += "/"
460471

461-
config = load_config()["datasets"]
472+
yield location + name + ".json"
473+
yield location + name + ".zarr"
462474

463-
if name in config["named"]:
464-
if name not in QUIET:
465-
LOG.info("Opening `%s` as `%s`", name, config["named"][name])
466-
QUIET.add(name)
467-
return config["named"][name]
475+
def log_open(self, name, full):
476+
if name not in self.QUIET:
477+
LOG.info("Opening `%s` as `%s`", name, full)
478+
self.QUIET.add(name)
479+
480+
481+
DATASET_FINDER = DatasetFinder()
482+
483+
484+
def zarr_lookup(name, fail=True):
468485

469486
tried = []
470-
for location in config["path"]:
471-
if not location.endswith("/"):
472-
location += "/"
473-
full = location + name + ".zarr"
487+
for full in DATASET_FINDER.ls(name):
474488
tried.append(full)
475489
try:
490+
DATASET_FINDER.log_open(name, full)
476491
z = open_zarr(full, dont_fail=True)
477492
if z is not None:
478-
# Cache for next time
479-
config["named"][name] = full
480-
if name not in QUIET:
481-
LOG.info("Opening `%s` as `%s`", name, full)
482-
QUIET.add(name)
483493
return full
484494
except zarr.errors.PathNotFoundError:
485495
pass

tests/test_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010

1111
import datetime
12+
import os
1213
from functools import cache
1314
from functools import wraps
1415
from unittest.mock import patch
@@ -156,6 +157,10 @@ def create_zarr(
156157
def zarr_from_str(name, mode):
157158
# Format: test-2021-2021-6h-o96-abcd-0
158159

160+
if name.endswith(".zarr"):
161+
name = os.path.basename(name)
162+
name = os.path.splitext(name)[0]
163+
159164
args = dict(
160165
test="test",
161166
start=2021,

0 commit comments

Comments
 (0)