Skip to content

Commit

Permalink
upload with ssh (#94)
Browse files Browse the repository at this point in the history
* add copy to ssh target
  • Loading branch information
floriankrb authored Nov 14, 2024
1 parent 6cb6689 commit 4c0213e
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 72 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Keep it human-readable, your future self will thank you!

### Changed

- Upload with ssh (experimental)
- Remove upstream dependencies from downstream-ci workflow (temporary) (#83)
- ci: pin python versions to 3.9 ... 3.12 for checks (#93)
- Fix `__version__` import in init
Expand Down
101 changes: 34 additions & 67 deletions src/anemoi/datasets/commands/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@

import logging
import os
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed

import tqdm
from anemoi.utils.s3 import download
from anemoi.utils.s3 import upload
from anemoi.utils.remote import Transfer
from anemoi.utils.remote import TransferMethodNotImplementedError

from . import Command

Expand All @@ -29,54 +28,7 @@
isatty = False


class S3Downloader:
    """Copy a dataset from an S3 source to a local target using ``download``.

    Parameters
    ----------
    source : str
        Location to download from. A trailing "/" is appended when missing
        before the path is handed to ``download``.
    target : str
        Local destination. The special value "." means "use the last path
        component of ``source`` as the destination name".
    transfers : int
        Forwarded to ``download`` as ``threads``.
    overwrite, resume, verbosity
        Forwarded unchanged to ``download``.
    **kwargs
        Accepted so the class can be built from a full argument namespace;
        the extra values are not used.
    """

    def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs):
        self.source = source
        self.target = target
        self.transfers = transfers
        self.overwrite = overwrite
        self.resume = resume
        self.verbosity = verbosity

    def _resolve_target(self):
        """Return the local destination, deriving it from the source when target is "."."""
        if self.target != ".":
            return self.target
        # Strip trailing slashes first: os.path.basename("s3://bucket/name.zarr/")
        # is the empty string, which would make the download target an empty path.
        return os.path.basename(self.source.rstrip("/"))

    def run(self):
        """Resolve the target, optionally delete an existing one, then download."""
        self.target = self._resolve_target()

        if self.overwrite and os.path.exists(self.target):
            LOG.info(f"Deleting {self.target}")
            shutil.rmtree(self.target)

        download(
            self.source + "/" if not self.source.endswith("/") else self.source,
            self.target,
            overwrite=self.overwrite,
            resume=self.resume,
            verbosity=self.verbosity,
            threads=self.transfers,
        )


class S3Uploader:
    """Copy a source to an S3 target by delegating to ``upload``."""

    def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs):
        # Extra keyword arguments are accepted but not used.
        self.source, self.target = source, target
        self.transfers = transfers
        self.overwrite, self.resume = overwrite, resume
        self.verbosity = verbosity

    def run(self):
        """Call ``upload`` with the stored settings; ``transfers`` is passed as ``threads``."""
        options = dict(
            overwrite=self.overwrite,
            resume=self.resume,
            verbosity=self.verbosity,
            threads=self.transfers,
        )
        upload(self.source, self.target, **options)


class DefaultCopier:
class ZarrCopier:
def __init__(self, source, target, transfers, block_size, overwrite, resume, verbosity, nested, rechunk, **kwargs):
self.source = source
self.target = target
Expand All @@ -90,6 +42,14 @@ def __init__(self, source, target, transfers, block_size, overwrite, resume, ver

self.rechunking = rechunk.split(",") if rechunk else []

source_is_ssh = self.source.startswith("ssh://")
target_is_ssh = self.target.startswith("ssh://")

if source_is_ssh or target_is_ssh:
if self.rechunk:
raise NotImplementedError("Rechunking with SSH not implemented.")
assert NotImplementedError("SSH not implemented.")

def _store(self, path, nested=False):
if nested:
import zarr
Expand Down Expand Up @@ -337,26 +297,33 @@ def run(self, args):
if args.source == args.target:
raise ValueError("Source and target are the same.")

kwargs = vars(args)

if args.overwrite and args.resume:
raise ValueError("Cannot use --overwrite and --resume together.")

source_in_s3 = args.source.startswith("s3://")
target_in_s3 = args.target.startswith("s3://")

copier = None

if args.rechunk or (source_in_s3 and target_in_s3):
copier = DefaultCopier(**kwargs)
else:
if source_in_s3:
copier = S3Downloader(**kwargs)

if target_in_s3:
copier = S3Uploader(**kwargs)

if not args.rechunk:
# Rechunking is only supported for Zarr datasets and is implemented in this package,
# so the generic transfer below is only attempted when no rechunking is requested.
try:
if args.source.startswith("s3://") and not args.source.endswith("/"):
args.source = args.source + "/"
copier = Transfer(
args.source,
args.target,
overwrite=args.overwrite,
resume=args.resume,
verbosity=args.verbosity,
threads=args.transfers,
)
copier.run()
return
except TransferMethodNotImplementedError:
# Transfer comes from anemoi-utils, which is agnostic to the source and target formats:
# it transfers files and folders, ignoring the fact that they contain zarr data.
# If no transfer method is implemented, we fall back to the ZarrCopier.
pass

copier = ZarrCopier(**vars(args))
copier.run()
return


class Copy(CopyMixin, Command):
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/datasets/data/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class S3Store(ReadOnlyStore):
"""

def __init__(self, url, region=None):
from anemoi.utils.s3 import s3_client
from anemoi.utils.remote.s3 import s3_client

_, _, self.bucket, self.key = url.split("/", 3)
self.s3 = s3_client(self.bucket, region=region)
Expand Down
4 changes: 2 additions & 2 deletions tests/create/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def test_run(name):
# reference_path = os.path.join(HERE, name + "-reference.zarr")
s3_uri = TEST_DATA_ROOT + "/" + name + ".zarr"
# if not os.path.exists(reference_path):
# from anemoi.utils.s3 import download as s3_download
# s3_download(s3_uri + '/', reference_path, overwrite=True)
# from anemoi.utils.remote import transfer
# transfer(s3_uri + '/', reference_path, overwrite=True)

Comparer(name, output_path=output, reference_path=s3_uri).compare()
# Comparer(name, output_path=output, reference_path=reference_path).compare()
Expand Down
6 changes: 4 additions & 2 deletions tools/upload-sample-dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import os

from anemoi.utils.s3 import upload
from anemoi.utils.remote import transfer

LOG = logging.getLogger(__name__)

Expand All @@ -38,6 +38,8 @@
target = args.target
bucket = args.bucket

assert os.path.exists(source), f"Source {source} does not exist"

if not target.startswith("s3://"):
if target.startswith("/"):
target = target[1:]
Expand All @@ -46,5 +48,5 @@
target = os.path.join(bucket, target)

LOG.info(f"Uploading {source} to {target}")
upload(source, target, overwrite=args.overwrite)
transfer(source, target, overwrite=args.overwrite)
LOG.info("Upload complete")

0 comments on commit 4c0213e

Please sign in to comment.