Skip to content

Commit 4c0213e

Browse files
authored
upload with ssh (#94)
* add copy to ssh target
1 parent 6cb6689 commit 4c0213e

File tree

5 files changed

+42
-72
lines changed

5 files changed

+42
-72
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Keep it human-readable, your future self will thank you!
3434

3535
### Changed
3636

37+
- Upload with ssh (experimental)
3738
- Remove upstream dependencies from downstream-ci workflow (temporary) (#83)
3839
- ci: pin python versions to 3.9 ... 3.12 for checks (#93)
3940
- Fix `__version__` import in init

src/anemoi/datasets/commands/copy.py

Lines changed: 34 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,13 @@
1010

1111
import logging
1212
import os
13-
import shutil
1413
import sys
1514
from concurrent.futures import ThreadPoolExecutor
1615
from concurrent.futures import as_completed
1716

1817
import tqdm
19-
from anemoi.utils.s3 import download
20-
from anemoi.utils.s3 import upload
18+
from anemoi.utils.remote import Transfer
19+
from anemoi.utils.remote import TransferMethodNotImplementedError
2120

2221
from . import Command
2322

@@ -29,54 +28,7 @@
2928
isatty = False
3029

3130

32-
class S3Downloader:
33-
def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs):
34-
self.source = source
35-
self.target = target
36-
self.transfers = transfers
37-
self.overwrite = overwrite
38-
self.resume = resume
39-
self.verbosity = verbosity
40-
41-
def run(self):
42-
if self.target == ".":
43-
self.target = os.path.basename(self.source)
44-
45-
if self.overwrite and os.path.exists(self.target):
46-
LOG.info(f"Deleting {self.target}")
47-
shutil.rmtree(self.target)
48-
49-
download(
50-
self.source + "/" if not self.source.endswith("/") else self.source,
51-
self.target,
52-
overwrite=self.overwrite,
53-
resume=self.resume,
54-
verbosity=self.verbosity,
55-
threads=self.transfers,
56-
)
57-
58-
59-
class S3Uploader:
60-
def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs):
61-
self.source = source
62-
self.target = target
63-
self.transfers = transfers
64-
self.overwrite = overwrite
65-
self.resume = resume
66-
self.verbosity = verbosity
67-
68-
def run(self):
69-
upload(
70-
self.source,
71-
self.target,
72-
overwrite=self.overwrite,
73-
resume=self.resume,
74-
verbosity=self.verbosity,
75-
threads=self.transfers,
76-
)
77-
78-
79-
class DefaultCopier:
31+
class ZarrCopier:
8032
def __init__(self, source, target, transfers, block_size, overwrite, resume, verbosity, nested, rechunk, **kwargs):
8133
self.source = source
8234
self.target = target
@@ -90,6 +42,14 @@ def __init__(self, source, target, transfers, block_size, overwrite, resume, ver
9042

9143
self.rechunking = rechunk.split(",") if rechunk else []
9244

45+
source_is_ssh = self.source.startswith("ssh://")
46+
target_is_ssh = self.target.startswith("ssh://")
47+
48+
if source_is_ssh or target_is_ssh:
49+
if self.rechunk:
50+
raise NotImplementedError("Rechunking with SSH not implemented.")
51+
assert NotImplementedError("SSH not implemented.")
52+
9353
def _store(self, path, nested=False):
9454
if nested:
9555
import zarr
@@ -337,26 +297,33 @@ def run(self, args):
337297
if args.source == args.target:
338298
raise ValueError("Source and target are the same.")
339299

340-
kwargs = vars(args)
341-
342300
if args.overwrite and args.resume:
343301
raise ValueError("Cannot use --overwrite and --resume together.")
344302

345-
source_in_s3 = args.source.startswith("s3://")
346-
target_in_s3 = args.target.startswith("s3://")
347-
348-
copier = None
349-
350-
if args.rechunk or (source_in_s3 and target_in_s3):
351-
copier = DefaultCopier(**kwargs)
352-
else:
353-
if source_in_s3:
354-
copier = S3Downloader(**kwargs)
355-
356-
if target_in_s3:
357-
copier = S3Uploader(**kwargs)
358-
303+
if not args.rechunk:
304+
# rechunking is only supported for ZARR datasets, it is implemented in this package
305+
try:
306+
if args.source.startswith("s3://") and not args.source.endswith("/"):
307+
args.source = args.source + "/"
308+
copier = Transfer(
309+
args.source,
310+
args.target,
311+
overwrite=args.overwrite,
312+
resume=args.resume,
313+
verbosity=args.verbosity,
314+
threads=args.transfers,
315+
)
316+
copier.run()
317+
return
318+
except TransferMethodNotImplementedError:
319+
# DataTransfer relies on anemoi-utils which is agnostic to the source and target format
320+
# it transfers file and folders, ignoring that it is zarr data
321+
# if it is not implemented, we fallback to the ZarrCopier
322+
pass
323+
324+
copier = ZarrCopier(**vars(args))
359325
copier.run()
326+
return
360327

361328

362329
class Copy(CopyMixin, Command):

src/anemoi/datasets/data/stores.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class S3Store(ReadOnlyStore):
7171
"""
7272

7373
def __init__(self, url, region=None):
74-
from anemoi.utils.s3 import s3_client
74+
from anemoi.utils.remote.s3 import s3_client
7575

7676
_, _, self.bucket, self.key = url.split("/", 3)
7777
self.s3 = s3_client(self.bucket, region=region)

tests/create/test_create.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ def test_run(name):
273273
# reference_path = os.path.join(HERE, name + "-reference.zarr")
274274
s3_uri = TEST_DATA_ROOT + "/" + name + ".zarr"
275275
# if not os.path.exists(reference_path):
276-
# from anemoi.utils.s3 import download as s3_download
277-
# s3_download(s3_uri + '/', reference_path, overwrite=True)
276+
# from anemoi.utils.remote import transfer
277+
# transfer(s3_uri + '/', reference_path, overwrite=True)
278278

279279
Comparer(name, output_path=output, reference_path=s3_uri).compare()
280280
# Comparer(name, output_path=output, reference_path=reference_path).compare()

tools/upload-sample-dataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import logging
2121
import os
2222

23-
from anemoi.utils.s3 import upload
23+
from anemoi.utils.remote import transfer
2424

2525
LOG = logging.getLogger(__name__)
2626

@@ -38,6 +38,8 @@
3838
target = args.target
3939
bucket = args.bucket
4040

41+
assert os.path.exists(source), f"Source {source} does not exist"
42+
4143
if not target.startswith("s3://"):
4244
if target.startswith("/"):
4345
target = target[1:]
@@ -46,5 +48,5 @@
4648
target = os.path.join(bucket, target)
4749

4850
LOG.info(f"Uploading {source} to {target}")
49-
upload(source, target, overwrite=args.overwrite)
51+
transfer(source, target, overwrite=args.overwrite)
5052
LOG.info("Upload complete")

0 commit comments

Comments
 (0)