Skip to content

Commit ada96e9

Browse files
authored
feat: accept hyphens in factory names (#116)
1 parent e43586b commit ada96e9

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

src/anemoi/utils/registry.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,17 @@ class Registry:
8888
The package name.
8989
key : str, optional
9090
The key to use for the registry, by default "_type".
91+
api_version : str, optional
92+
The API version, by default '1.0.0'.
9193
"""
9294

93-
def __init__(self, package: str, key: str = "_type"):
95+
def __init__(self, package: str, key: str = "_type", api_version: str = "1.0.0"):
9496
self.package = package
9597
self.__registered = {}
9698
self._sources = {}
9799
self.kind = package.split(".")[-1]
98100
self.key = key
101+
self.api_version = api_version
99102
_BY_KIND[self.kind] = self
100103

101104
@classmethod
@@ -133,6 +136,9 @@ def register(
133136
Wrapper, optional
134137
A wrapper if the factory is None, otherwise None.
135138
"""
139+
140+
name = name.replace("_", "-")
141+
136142
if factory is None:
137143
# This happens when the @register decorator is used
138144
return Wrapper(name, self)
@@ -177,6 +183,9 @@ def is_registered(self, name: str) -> bool:
177183
bool
178184
Whether the factory is registered.
179185
"""
186+
187+
name = name.replace("_", "-")
188+
180189
ok = name in self.factories
181190
if not ok:
182191
LOG.error(f"Cannot find '{name}' in {self.package}")
@@ -199,6 +208,9 @@ def lookup(self, name: str, *, return_none: bool = False) -> Optional[Callable]:
199208
Callable, optional
200209
The factory if found, otherwise None.
201210
"""
211+
212+
name = name.replace("_", "-")
213+
202214
if return_none:
203215
return self.factories.get(name)
204216

@@ -288,6 +300,9 @@ def create(self, name: str, *args: Any, **kwargs: Any) -> Any:
288300
Any
289301
The created instance.
290302
"""
303+
304+
name = name.replace("_", "-")
305+
291306
factory = self.lookup(name)
292307
return factory(*args, **kwargs)
293308

tests/test_remote.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77

88
import os
99
import shutil
10+
import sys
1011

1112
import pytest
1213

1314
from anemoi.utils.remote import TransferMethodNotImplementedError
1415
from anemoi.utils.remote import _find_transfer_class
1516
from anemoi.utils.remote import transfer
17+
from anemoi.utils.testing import packages_installed
1618

1719
IN_CI = (os.environ.get("GITHUB_WORKFLOW") is not None) or (os.environ.get("IN_CI_HPC") is not None)
1820

@@ -112,6 +114,7 @@ def test_transfer_find_none(source: str, target: str) -> None:
112114

113115

114116
@pytest.mark.skipif(IN_CI, reason="Test requires access to S3")
117+
@pytest.mark.skipif(not packages_installed("boto3"), reason="boto3 is not installed")
115118
def test_transfer_zarr_s3_to_local(tmpdir: pytest.TempPathFactory) -> None:
116119
"""Test transferring a Zarr file from S3 to local.
117120
@@ -132,6 +135,7 @@ def test_transfer_zarr_s3_to_local(tmpdir: pytest.TempPathFactory) -> None:
132135

133136

134137
@pytest.mark.skipif(IN_CI, reason="Test requires access to S3")
138+
@pytest.mark.skipif(not packages_installed("boto3"), reason="boto3 is not installed")
135139
def test_transfer_zarr_local_to_s3(tmpdir: pytest.TempPathFactory) -> None:
136140
"""Test transferring a Zarr file from local to S3.
137141
@@ -193,6 +197,7 @@ def compare(local1: str, local2: str) -> None:
193197

194198

195199
@pytest.mark.skipif(IN_CI, reason="Test requires access to S3")
200+
@pytest.mark.skipif(not packages_installed("boto3"), reason="boto3 is not installed")
196201
@pytest.mark.parametrize("path", ["directory/", "file"])
197202
def test_transfer_local_to_s3_to_local(path: str) -> None:
198203
"""Test transferring a file or directory from local to S3 and back to local.
@@ -224,6 +229,7 @@ def test_transfer_local_to_s3_to_local(path: str) -> None:
224229

225230

226231
@pytest.mark.skipif(IN_CI, reason="Test requires ssh access to localhost")
232+
@pytest.mark.skipif(sys.platform == "darwin", reason="Does not work on MacOS")
227233
@pytest.mark.parametrize("path", ["directory", "file"])
228234
@pytest.mark.parametrize("temporary_target", [True, False])
229235
def test_transfer_local_to_ssh(path: str, temporary_target: bool) -> None:

0 commit comments

Comments
 (0)