From 4c0213e2e1a76a0ace06c60c6afd39223cdd38a1 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 14 Nov 2024 09:18:53 +0100 Subject: [PATCH] upload with ssh (#94) * add copy to ssh target --- CHANGELOG.md | 1 + src/anemoi/datasets/commands/copy.py | 101 +++++++++------------------ src/anemoi/datasets/data/stores.py | 2 +- tests/create/test_create.py | 4 +- tools/upload-sample-dataset.py | 6 +- 5 files changed, 42 insertions(+), 72 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c3c63faa..e66e2d098 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ Keep it human-readable, your future self will thank you! ### Changed +- Upload with ssh (experimental) - Remove upstream dependencies from downstream-ci workflow (temporary) (#83) - ci: pin python versions to 3.9 ... 3.12 for checks (#93) - Fix `__version__` import in init diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index 1ca9aef8b..9d66b2fc6 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -10,14 +10,13 @@ import logging import os -import shutil import sys from concurrent.futures import ThreadPoolExecutor from concurrent.futures import as_completed import tqdm -from anemoi.utils.s3 import download -from anemoi.utils.s3 import upload +from anemoi.utils.remote import Transfer +from anemoi.utils.remote import TransferMethodNotImplementedError from . import Command @@ -29,54 +28,7 @@ isatty = False -class S3Downloader: - def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs): - self.source = source - self.target = target - self.transfers = transfers - self.overwrite = overwrite - self.resume = resume - self.verbosity = verbosity - - def run(self): - if self.target == ".": - self.target = os.path.basename(self.source) - - if self.overwrite and os.path.exists(self.target): - LOG.info(f"Deleting {self.target}") - shutil.rmtree(self.target) - - download( - self.source + "/" if not self.source.endswith("/") else self.source, - self.target, - overwrite=self.overwrite, - resume=self.resume, - verbosity=self.verbosity, - threads=self.transfers, - ) - - -class S3Uploader: - def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs): - self.source = source - self.target = target - self.transfers = transfers - self.overwrite = overwrite - self.resume = resume - self.verbosity = verbosity - - def run(self): - upload( - self.source, - self.target, - overwrite=self.overwrite, - resume=self.resume, - verbosity=self.verbosity, - threads=self.transfers, - ) - - -class DefaultCopier: +class ZarrCopier: def __init__(self, source, target, transfers, block_size, overwrite, resume, verbosity, nested, rechunk, **kwargs): self.source = source self.target = target @@ -90,6 +42,14 @@ def __init__(self, source, target, transfers, block_size, overwrite, resume, ver self.rechunking = rechunk.split(",") if rechunk else [] + source_is_ssh = self.source.startswith("ssh://") + target_is_ssh = self.target.startswith("ssh://") + + if source_is_ssh or target_is_ssh: + if self.rechunk: + raise NotImplementedError("Rechunking with SSH not implemented.") + assert NotImplementedError("SSH not implemented.") + def _store(self, path, nested=False): if nested: import zarr @@ -337,26 +297,33 @@ def run(self, args): if args.source == args.target: raise ValueError("Source and target are the same.") - kwargs = vars(args) - if args.overwrite and args.resume: raise ValueError("Cannot use --overwrite and --resume together.") - source_in_s3 = args.source.startswith("s3://") - target_in_s3 = args.target.startswith("s3://") - - copier = None - - if args.rechunk or (source_in_s3 and target_in_s3): - copier = DefaultCopier(**kwargs) - else: - if source_in_s3: - copier = S3Downloader(**kwargs) - - if target_in_s3: - copier = S3Uploader(**kwargs) - + if not args.rechunk: + # rechunking is only supported for ZARR datasets, it is implemented in this package + try: + if args.source.startswith("s3://") and not args.source.endswith("/"): + args.source = args.source + "/" + copier = Transfer( + args.source, + args.target, + overwrite=args.overwrite, + resume=args.resume, + verbosity=args.verbosity, + threads=args.transfers, + ) + copier.run() + return + except TransferMethodNotImplementedError: + # DataTransfer relies on anemoi-utils which is agnostic to the source and target format + # it transfers file and folders, ignoring that it is zarr data + # if it is not implemented, we fallback to the ZarrCopier + pass + + copier = ZarrCopier(**vars(args)) copier.run() + return class Copy(CopyMixin, Command): diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 19246bc0d..f2c69ae1c 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -71,7 +71,7 @@ class S3Store(ReadOnlyStore): """ def __init__(self, url, region=None): - from anemoi.utils.s3 import s3_client + from anemoi.utils.remote.s3 import s3_client _, _, self.bucket, self.key = url.split("/", 3) self.s3 = s3_client(self.bucket, region=region) diff --git a/tests/create/test_create.py b/tests/create/test_create.py index d29497a05..ab612dbf0 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -273,8 +273,8 @@ def test_run(name): # reference_path = os.path.join(HERE, name + "-reference.zarr") s3_uri = TEST_DATA_ROOT + "/" + name + ".zarr" # if not os.path.exists(reference_path): - # from anemoi.utils.s3 import download as s3_download - # s3_download(s3_uri + '/', reference_path, overwrite=True) + # from anemoi.utils.remote import transfer + # transfer(s3_uri + '/', reference_path, overwrite=True) Comparer(name, output_path=output, reference_path=s3_uri).compare() # Comparer(name, output_path=output, reference_path=reference_path).compare() diff --git a/tools/upload-sample-dataset.py b/tools/upload-sample-dataset.py index 586f7fdf0..67d8ebd09 100755 --- a/tools/upload-sample-dataset.py +++ b/tools/upload-sample-dataset.py @@ -20,7 +20,7 @@ import logging import os -from anemoi.utils.s3 import upload +from anemoi.utils.remote import transfer LOG = logging.getLogger(__name__) @@ -38,6 +38,8 @@ target = args.target bucket = args.bucket +assert os.path.exists(source), f"Source {source} does not exist" + if not target.startswith("s3://"): if target.startswith("/"): target = target[1:] @@ -46,5 +48,5 @@ target = os.path.join(bucket, target) LOG.info(f"Uploading {source} to {target}") -upload(source, target, overwrite=args.overwrite) +transfer(source, target, overwrite=args.overwrite) LOG.info("Upload complete")