Skip to content

Commit 941e4e7

Browse files
committed
add checkpoints
1 parent b770fc6 commit 941e4e7

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

anemoi/utils/checkpoints.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# (C) Copyright 2024 ECMWF.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
# In applying this licence, ECMWF does not waive the privileges and immunities
6+
# granted to it by virtue of its status as an intergovernmental organisation
7+
# nor does it submit to any jurisdiction.
8+
9+
import json
10+
import logging
11+
import os
12+
import zipfile
13+
14+
LOG = logging.getLogger(__name__)
15+
16+
DEFAULT_NAME = "anemoi-metadata.json"
17+
18+
19+
def load_metadata(path, name=DEFAULT_NAME):
20+
with zipfile.ZipFile(path, "r") as f:
21+
metadata = None
22+
for b in f.namelist():
23+
if os.path.basename(b) == name:
24+
if metadata is not None:
25+
LOG.warning(f"Found two '{name}' if {path}")
26+
metadata = b
27+
28+
if metadata is not None:
29+
with zipfile.ZipFile(path, "r") as f:
30+
return json.load(f.open(metadata, "r"))
31+
else:
32+
raise ValueError(f"Could not find {name} in {path}")
33+
34+
35+
def save_metadata(path, metadata, name=DEFAULT_NAME):
36+
with zipfile.ZipFile(path, "a") as zipf:
37+
base, _ = os.path.splitext(os.path.basename(path))
38+
zipf.writestr(
39+
f"{base}/{name}",
40+
json.dumps(metadata),
41+
)

0 commit comments

Comments
 (0)