Skip to content

Commit 50c39a6

Browse files
committed
fix tests
1 parent 0fd5220 commit 50c39a6

File tree

3 files changed

+31
-24
lines changed

3 files changed

+31
-24
lines changed

src/anemoi/datasets/data/stores.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import os
1414
import tempfile
15+
import threading
1516
import warnings
1617
from functools import cached_property
1718
from typing import Any
@@ -151,6 +152,7 @@ def __init__(self, store: ReadOnlyStore, options: Optional[Dict[str, Any]] = Non
151152
self.reused_objects = 0
152153
self.key_cache = set()
153154
self.path_cache = set()
155+
self.lock = threading.Lock()
154156

155157
self.tmpdir = tempfile.TemporaryDirectory(
156158
prefix="anemoi-datasets-ssd-",
@@ -168,28 +170,30 @@ def __del__(self) -> None:
168170

169171
def __getitem__(self, key: str) -> bytes:
170172

171-
path = os.path.join(self.tmpdir.name, key)
173+
with self.lock:
172174

173-
if key in self.key_cache or os.path.exists(path):
174-
self.key_cache.add(key)
175-
self.reused_objects += 1
176-
return open(path, "rb").read()
175+
path = os.path.join(self.tmpdir.name, key)
176+
177+
if key in self.key_cache or os.path.exists(path):
178+
self.key_cache.add(key)
179+
self.reused_objects += 1
180+
return open(path, "rb").read()
177181

178-
self.copied_objects += 1
179-
value = self.store[key]
182+
self.copied_objects += 1
183+
value = self.store[key]
180184

181-
parent = os.path.dirname(path)
182-
if parent not in self.path_cache:
183-
os.makedirs(parent, exist_ok=True)
184-
self.path_cache.add(parent)
185+
parent = os.path.dirname(path)
186+
if parent not in self.path_cache:
187+
os.makedirs(parent, exist_ok=True)
188+
self.path_cache.add(parent)
185189

186-
with open(path, "wb") as f:
187-
f.write(value)
190+
with open(path, "wb") as f:
191+
f.write(value)
188192

189-
self.total_size += len(value)
190-
self.key_cache.add(key)
193+
self.total_size += len(value)
194+
self.key_cache.add(key)
191195

192-
return value
196+
return value
193197

194198
def __len__(self) -> int:
195199
"""Return the number of items in the store."""
@@ -203,14 +207,17 @@ def __iter__(self) -> iter:
203207
def __contains__(self, key: str) -> bool:
204208
"""Check if the store contains a key."""
205209

206-
if key in self.key_cache:
207-
return True
210+
with self.lock:
208211

209-
path = os.path.join(self.tmpdir.name, key)
210-
if os.path.exists(path):
211-
return True
212+
if key in self.key_cache:
213+
return True
212214

213-
return key in self.store
215+
path = os.path.join(self.tmpdir.name, key)
216+
if os.path.exists(path):
217+
self.key_cache.add(key)
218+
return True
219+
220+
return key in self.store
214221

215222

216223
def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:

tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def mockup_open_zarr(func: Callable) -> Callable:
5757

5858
@wraps(func)
5959
def wrapper(*args, **kwargs):
60-
with patch("zarr.convenience.open", zarr_from_str):
60+
with patch("zarr.open", zarr_from_str):
6161
with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name):
6262
return func(*args, **kwargs)
6363

tests/test_data_gridded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def mockup_open_zarr(func: Callable) -> Callable:
4444

4545
@wraps(func)
4646
def wrapper(*args, **kwargs):
47-
with patch("zarr.convenience.open", zarr_from_str):
47+
with patch("zarr.open", zarr_from_str):
4848
with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name):
4949
return func(*args, **kwargs)
5050

0 commit comments

Comments
 (0)