 """Miscellaneous Utils for Sky Data
 """
-from typing import Any, Tuple
+from multiprocessing import pool
+import os
+import subprocess
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple
 import urllib.parse

+from sky import exceptions
+from sky import sky_logging
 from sky.adaptors import aws, gcp
+from sky.utils import ux_utils

 Client = Any

+logger = sky_logging.init_logger(__name__)
+

 def split_s3_path(s3_path: str) -> Tuple[str, str]:
     """Splits S3 Path into Bucket name and Relative Path to Bucket

@@ -69,3 +78,118 @@ def is_cloud_store_url(url):
     result = urllib.parse.urlsplit(url)
     # '' means non-cloud URLs.
     return result.netloc
+
+
+def _group_files_by_dir(
+        source_list: List[str]) -> Tuple[Dict[str, List[str]], List[str]]:
+    """Groups a list of paths based on their directory.
+
+    Given a list of paths, generates a dict of {dir_name: List[file_name]}
+    which groups files with the same dir, and a list of dirs in the
+    source_list.
+
+    This is used to optimize uploads by reducing the number of calls to rsync.
+    E.g., ['a/b/c.txt', 'a/b/d.txt', 'a/e.txt'] will be grouped into
+    {'a/b': ['c.txt', 'd.txt'], 'a': ['e.txt']}, and these three files can be
+    uploaded in two rsync calls instead of three.
+
+    Args:
+        source_list: List[str]; List of paths to group
+    """
+    grouped_files = {}
+    dirs = []
+    for source in source_list:
+        source = os.path.abspath(os.path.expanduser(source))
+        if os.path.isdir(source):
+            dirs.append(source)
+        else:
+            base_path = os.path.dirname(source)
+            file_name = os.path.basename(source)
+            if base_path not in grouped_files:
+                grouped_files[base_path] = []
+            grouped_files[base_path].append(file_name)
+    return grouped_files, dirs
+
+
+def parallel_upload(source_path_list: List[Path],
+                    filesync_command_generator: Callable[[str, List[str]], str],
+                    dirsync_command_generator: Callable[[str, str], str],
+                    bucket_name: str,
+                    access_denied_message: str,
+                    create_dirs: bool = False,
+                    max_concurrent_uploads: Optional[int] = None) -> None:
+    """Helper function to run parallel uploads for a list of paths.
+
+    Used by S3Store and GCSStore to run rsync commands in parallel by
+    providing appropriate command generators.
+
+    Args:
+        source_path_list: List of paths to local files or directories
+        filesync_command_generator: Callable that generates rsync command
+            for a list of files belonging to the same dir.
+        dirsync_command_generator: Callable that generates rsync command
+            for a directory.
+        bucket_name: Name of the bucket the files are uploaded to; used
+            in error messages.
+        access_denied_message: Message to intercept from the underlying
+            upload utility when permissions are insufficient. Used in
+            exception handling.
+        create_dirs: If the local_path is a directory and this is set to
+            False, the contents of the directory are directly uploaded to
+            root of the bucket. If the local_path is a directory and this is
+            set to True, the directory is created in the bucket root and
+            contents are uploaded to it.
+        max_concurrent_uploads: Maximum number of concurrent threads to use
+            to upload files.
+    """
+    # Generate gsutil rsync command for files and dirs
+    commands = []
+    grouped_files, dirs = _group_files_by_dir(source_path_list)
+    # Generate file upload commands
+    for dir_path, file_names in grouped_files.items():
+        sync_command = filesync_command_generator(dir_path, file_names)
+        commands.append(sync_command)
+    # Generate dir upload commands
+    for dir_path in dirs:
+        if create_dirs:
+            dest_dir_name = os.path.basename(dir_path)
+        else:
+            dest_dir_name = ''
+        sync_command = dirsync_command_generator(dir_path, dest_dir_name)
+        commands.append(sync_command)
+
+    # Run commands in parallel
+    with pool.ThreadPool(processes=max_concurrent_uploads) as p:
+        p.starmap(
+            run_upload_cli,
+            zip(commands, [access_denied_message] * len(commands),
+                [bucket_name] * len(commands)))
+
+
+def run_upload_cli(command: str, access_denied_message: str, bucket_name: str):
+    """Runs an upload command and raises an error if the upload fails.
+
+    Raises:
+        PermissionError: If the upload utility reports
+            access_denied_message, i.e. the bucket is not writable.
+        exceptions.StorageUploadError: If the command exits with a non-zero
+            return code for any other reason.
+    """
+    # TODO(zhwu): Use log_lib.run_with_log() and redirect the output
+    # to a log file.
+    with subprocess.Popen(command,
+                          stderr=subprocess.PIPE,
+                          stdout=subprocess.DEVNULL,
+                          shell=True) as process:
+        stderr = []
+        while True:
+            line = process.stderr.readline()
+            if not line:
+                break
+            str_line = line.decode('utf-8')
+            stderr.append(str_line)
+            if access_denied_message in str_line:
+                process.kill()
+                with ux_utils.print_exception_no_traceback():
+                    raise PermissionError(
+                        'Failed to upload files to '
+                        'the remote bucket. The bucket does not have '
+                        'write permissions. It is possible that '
+                        'the bucket is public.')
+        returncode = process.wait()
+        if returncode != 0:
+            stderr = '\n'.join(stderr)
+            with ux_utils.print_exception_no_traceback():
+                logger.error(stderr)
+                raise exceptions.StorageUploadError(
+                    f'Upload to bucket failed for store {bucket_name}. '
+                    'Please check the logs.')
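
For reference, a minimal sketch of what the grouping helper above returns for the docstring's example, assuming the three paths are regular files relative to the current working directory ('/abs/path' below stands in for whatever os.path.abspath resolves to):

grouped_files, dirs = _group_files_by_dir(
    ['a/b/c.txt', 'a/b/d.txt', 'a/e.txt'])
# grouped_files == {'/abs/path/a/b': ['c.txt', 'd.txt'],
#                   '/abs/path/a': ['e.txt']}
# dirs == []  (no entry in the input list is itself a directory)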
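
And a hedged usage sketch of parallel_upload, relying on the module's existing imports: the two generator functions and the bucket name are hypothetical stand-ins (not the generators that S3Store/GCSStore actually provide), and the access-denied marker string is only an assumption about what the upload CLI prints; the point is the expected call signatures.

def my_filesync_command(src_dir_path: str, file_names: List[str]) -> str:
    # One command per local directory: copy the listed files into the bucket.
    sources = ' '.join(
        os.path.join(src_dir_path, name) for name in file_names)
    return f'gsutil -m cp {sources} gs://my-sky-bucket/'

def my_dirsync_command(src_dir_path: str, dest_dir_name: str) -> str:
    # One command per local directory tree: mirror it into the bucket.
    # dest_dir_name is '' unless create_dirs=True.
    return f'gsutil -m rsync -r {src_dir_path} gs://my-sky-bucket/{dest_dir_name}'

parallel_upload(
    ['~/dataset/train.csv', '~/dataset/test.csv', '~/images/'],
    filesync_command_generator=my_filesync_command,
    dirsync_command_generator=my_dirsync_command,
    bucket_name='my-sky-bucket',
    access_denied_message='AccessDeniedException',  # assumed marker string
    create_dirs=True,
    max_concurrent_uploads=4)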