Skip to content

Commit d8a6b3b

Browse files
authored
Add zone support in YAML (skypilot-org#1014)
* Add zone support * Zone validation and acc validation * Remove CLI overwrite * yapf * fix * address comments * yapf * address comments * yapf * fix * fix and remove test * fix * fix test * fix huggingface smoke test * handle multi-node * Replace ValueError with assert * fix multi-node * replace sky.cloud with clouds.cloud * Add invalid zone test * add smoke test * yapf * Support region/zone close matching * Don't overwrite ray yaml * add comments * add zone info to usage_lib * update docstring * roll back sky.clouds * refactor * missed * Fix * yapf * add zone to cli * update test * fix * Add acc_count constraint * remove unused func * shorten * comments * fix
1 parent f02f0f6 commit d8a6b3b

22 files changed

+410
-103
lines changed

docs/source/reference/yaml-spec.rst

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ describe all fields available.
3434
# if this is specified.
3535
region: us-east-1
3636
37+
# The zone to use (optional).
38+
region: us-east-1a
39+
3740
# Accelerator name and count per node (optional).
3841
#
3942
# Use `sky show-gpus` to view available accelerator configurations.

examples/horovod_distributed_tf_app.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import sky
66
import time_estimators
7-
from sky import clouds
87

98
IPAddr = str
109

@@ -55,7 +54,7 @@ def run_fn(ip_list: List[IPAddr]) -> Dict[IPAddr, str]:
5554
estimated_size_gigabytes=70)
5655
train.set_outputs('resnet-model-dir', estimated_size_gigabytes=0.1)
5756
train.set_resources({
58-
sky.Resources(clouds.AWS(), 'p3.2xlarge'),
57+
sky.Resources(sky.AWS(), 'p3.2xlarge'),
5958
})
6059

6160
dag = sky.Optimizer.optimize(dag)

examples/ray_tune_app.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import List, Optional
33

44
import sky
5-
from sky import clouds
65

76
with sky.Dag() as dag:
87
# Total Nodes, INCLUDING Head Node
@@ -31,7 +30,7 @@ def run_fn(node_rank: int, ip_list: List[str]) -> Optional[str]:
3130
)
3231

3332
train.set_resources({
34-
sky.Resources(clouds.AWS(), 'p3.2xlarge'),
33+
sky.Resources(sky.AWS(), 'p3.2xlarge'),
3534
})
3635

3736
sky.launch(dag)

sky/backends/backend_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ def _process_cli_query(
10751075
f'{stdout}\n'
10761076
'**** STDERR ****\n'
10771077
f'{stderr}')
1078-
if (cloud == str(sky.Azure()) and returncode == 2 and
1078+
if (cloud == str(clouds.Azure()) and returncode == 2 and
10791079
'argument --ids: expected at least one argument' in stderr):
10801080
# Azure CLI has a returncode 2 when the cluster is not found, as
10811081
# --ids <empty> is passed to the query command. In that case, the

sky/backends/cloud_vm_ray_backend.py

+52-15
Original file line numberDiff line numberDiff line change
@@ -686,25 +686,32 @@ def _yield_region_zones(self, to_provision: resources_lib.Resources,
686686
prev_resources = handle.launched_resources
687687
if prev_resources is not None and cloud.is_same_cloud(
688688
prev_resources.cloud):
689-
if cloud.is_same_cloud(sky.GCP()) or cloud.is_same_cloud(
690-
sky.AWS()):
689+
if cloud.is_same_cloud(clouds.GCP()) or cloud.is_same_cloud(
690+
clouds.AWS()):
691691
region = config['provider']['region']
692692
zones = config['provider']['availability_zone']
693-
elif cloud.is_same_cloud(sky.Azure()):
693+
elif cloud.is_same_cloud(clouds.Azure()):
694694
region = config['provider']['location']
695695
zones = None
696-
elif cloud.is_same_cloud(sky.Local()):
696+
elif cloud.is_same_cloud(clouds.Local()):
697697
local_regions = clouds.Local.regions()
698698
region = local_regions[0].name
699699
zones = None
700700
else:
701701
assert False, cloud
702-
if region != prev_resources.region:
703-
raise ValueError(
704-
f'Region mismatch. The region in '
705-
f'{handle.cluster_yaml} '
706-
'has been changed from '
707-
f'{prev_resources.region} to {region}.')
702+
assert region == prev_resources.region, (
703+
f'Region mismatch. The region in '
704+
f'{handle.cluster_yaml} '
705+
'has been changed from '
706+
f'{prev_resources.region} to {region}.')
707+
assert (zones is None or prev_resources.zone is None or
708+
prev_resources.zone
709+
in zones), (f'{prev_resources.zone} not found in '
710+
f'zones of {handle.cluster_yaml}.')
711+
# Note that we don't overwrite the zone field in Ray YAML
712+
# even if prev_resources.zone != zones.
713+
# This is because Ray will consider the YAML hash changed
714+
# and not reuse the existing cluster.
708715
except FileNotFoundError:
709716
# Happens if no previous cluster.yaml exists.
710717
pass
@@ -798,10 +805,15 @@ def _yield_region_zones(self, to_provision: resources_lib.Resources,
798805
accelerators=to_provision.accelerators,
799806
use_spot=to_provision.use_spot,
800807
):
801-
# Do not retry on region if it's not in the requested region.
808+
# Only retry requested region/zones or all if not specified.
802809
if (to_provision.region is not None and
803810
region.name != to_provision.region):
804811
continue
812+
if to_provision.zone is not None:
813+
zones_name = [zone.name for zone in zones]
814+
if to_provision.zone not in zones_name:
815+
continue
816+
zones = [clouds.Zone(name=to_provision.zone)]
805817
yield (region, zones)
806818

807819
def _try_provision_tpu(self, to_provision: resources_lib.Resources,
@@ -1410,12 +1422,12 @@ def _update_cluster_region(self):
14101422
config = common_utils.read_yaml(self.cluster_yaml)
14111423
provider = config['provider']
14121424
cloud = self.launched_resources.cloud
1413-
if cloud.is_same_cloud(sky.Azure()):
1425+
if cloud.is_same_cloud(clouds.Azure()):
14141426
region = provider['location']
1415-
elif cloud.is_same_cloud(sky.GCP()) or cloud.is_same_cloud(
1416-
sky.AWS()):
1427+
elif cloud.is_same_cloud(clouds.GCP()) or cloud.is_same_cloud(
1428+
clouds.AWS()):
14171429
region = provider['region']
1418-
elif cloud.is_same_cloud(sky.Local()):
1430+
elif cloud.is_same_cloud(clouds.Local()):
14191431
# There is only 1 region for Local cluster, 'Local'.
14201432
local_regions = clouds.Local.regions()
14211433
region = local_regions[0].name
@@ -1495,6 +1507,16 @@ def check_resources_fit_cluster(self, handle: ResourceHandle,
14951507
'Task requested resources in region '
14961508
f'{task_resources.region!r}, but the existing cluster '
14971509
f'is in region {launched_resources.region!r}.')
1510+
if (task_resources.zone is not None and
1511+
task_resources.zone != launched_resources.zone):
1512+
zone_str = (f'is in zone {launched_resources.zone!r}.'
1513+
if launched_resources.zone is not None else
1514+
'does not have zone specified.')
1515+
with ux_utils.print_exception_no_traceback():
1516+
raise exceptions.ResourcesMismatchError(
1517+
'Task requested resources in zone '
1518+
f'{task_resources.zone!r}, but the existing cluster '
1519+
f'{zone_str}')
14981520
with ux_utils.print_exception_no_traceback():
14991521
raise exceptions.ResourcesMismatchError(
15001522
'Requested resources do not match the existing cluster.\n'
@@ -1620,6 +1642,21 @@ def _provision(self,
16201642
# TPU.
16211643
tpu_create_script=config_dict.get('tpu-create-script'),
16221644
tpu_delete_script=config_dict.get('tpu-delete-script'))
1645+
1646+
# Get actual zone info and save it into handle
1647+
get_zone_cmd = handle.launched_resources.cloud.get_zone_shell_cmd()
1648+
if get_zone_cmd is not None:
1649+
# We leave the zone field to None for multi-node cases
1650+
# if zone is not specified because head and worker nodes
1651+
# can be launched in different zones.
1652+
if (task.num_nodes == 1 or
1653+
handle.launched_resources.zone is not None):
1654+
returncode, stdout, _ = self.run_on_head(
1655+
handle, get_zone_cmd, require_outputs=True)
1656+
# zone will be checked during Resources cls initialization.
1657+
handle.launched_resources = handle.launched_resources.copy(
1658+
zone=stdout.strip())
1659+
16231660
usage_lib.messages.usage.update_cluster_resources(
16241661
handle.launched_nodes, handle.launched_resources)
16251662
usage_lib.messages.usage.update_final_cluster_status(

sky/cli.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,12 @@ def _parse_env_var(env_var: str) -> Tuple[str, str]:
249249
type=str,
250250
help=('The region to use. If specified, overrides the '
251251
'"resources.region" config. Passing "none" resets the config.')),
252+
click.option(
253+
'--zone',
254+
required=False,
255+
type=str,
256+
help=('The zone to use. If specified, overrides the '
257+
'"resources.zone" config. Passing "none" resets the config.')),
252258
click.option(
253259
'--num-nodes',
254260
required=False,
@@ -328,6 +334,7 @@ def _add_options(func):
328334

329335
def _parse_override_params(cloud: Optional[str] = None,
330336
region: Optional[str] = None,
337+
zone: Optional[str] = None,
331338
gpus: Optional[str] = None,
332339
instance_type: Optional[str] = None,
333340
use_spot: Optional[bool] = None,
@@ -345,6 +352,11 @@ def _parse_override_params(cloud: Optional[str] = None,
345352
override_params['region'] = None
346353
else:
347354
override_params['region'] = region
355+
if zone is not None:
356+
if zone.lower() == 'none':
357+
override_params['zone'] = None
358+
else:
359+
override_params['zone'] = zone
348360
if gpus is not None:
349361
if gpus.lower() == 'none':
350362
override_params['accelerators'] = None
@@ -684,6 +696,7 @@ def _make_dag_from_entrypoint_with_overrides(
684696
workdir: Optional[str] = None,
685697
cloud: Optional[str] = None,
686698
region: Optional[str] = None,
699+
zone: Optional[str] = None,
687700
gpus: Optional[str] = None,
688701
instance_type: Optional[str] = None,
689702
num_nodes: Optional[int] = None,
@@ -695,7 +708,6 @@ def _make_dag_from_entrypoint_with_overrides(
695708
spot_recovery: Optional[str] = None,
696709
) -> sky.Dag:
697710
entrypoint = ' '.join(entrypoint)
698-
699711
with sky.Dag() as dag:
700712
is_yaml, yaml_config = _check_yaml(entrypoint)
701713
if is_yaml:
@@ -726,6 +738,7 @@ def _make_dag_from_entrypoint_with_overrides(
726738

727739
override_params = _parse_override_params(cloud=cloud,
728740
region=region,
741+
zone=zone,
729742
gpus=gpus,
730743
instance_type=instance_type,
731744
use_spot=use_spot,
@@ -877,6 +890,7 @@ def launch(
877890
workdir: Optional[str],
878891
cloud: Optional[str],
879892
region: Optional[str],
893+
zone: Optional[str],
880894
gpus: Optional[str],
881895
instance_type: Optional[str],
882896
num_nodes: Optional[int],
@@ -908,6 +922,7 @@ def launch(
908922
workdir=workdir,
909923
cloud=cloud,
910924
region=region,
925+
zone=zone,
911926
gpus=gpus,
912927
instance_type=instance_type,
913928
num_nodes=num_nodes,
@@ -956,6 +971,7 @@ def exec(
956971
name: Optional[str],
957972
cloud: Optional[str],
958973
region: Optional[str],
974+
zone: Optional[str],
959975
workdir: Optional[str],
960976
gpus: Optional[str],
961977
instance_type: Optional[str],
@@ -1038,6 +1054,7 @@ def exec(
10381054
workdir=workdir,
10391055
cloud=cloud,
10401056
region=region,
1057+
zone=zone,
10411058
gpus=gpus,
10421059
instance_type=instance_type,
10431060
use_spot=use_spot,
@@ -2239,6 +2256,7 @@ def spot_launch(
22392256
workdir: Optional[str],
22402257
cloud: Optional[str],
22412258
region: Optional[str],
2259+
zone: Optional[str],
22422260
gpus: Optional[str],
22432261
instance_type: Optional[str],
22442262
num_nodes: Optional[int],
@@ -2262,6 +2280,7 @@ def spot_launch(
22622280
workdir=workdir,
22632281
cloud=cloud,
22642282
region=region,
2283+
zone=zone,
22652284
gpus=gpus,
22662285
instance_type=instance_type,
22672286
num_nodes=num_nodes,
@@ -2510,6 +2529,7 @@ def benchmark_launch(
25102529
workdir: Optional[str],
25112530
cloud: Optional[str],
25122531
region: Optional[str],
2532+
zone: Optional[str],
25132533
gpus: Optional[str],
25142534
num_nodes: Optional[int],
25152535
use_spot: Optional[bool],
@@ -2556,6 +2576,9 @@ def benchmark_launch(
25562576
if region is not None:
25572577
if any('region' in candidate for candidate in candidates):
25582578
raise click.BadParameter(f'region {message}')
2579+
if zone is not None:
2580+
if any('zone' in candidate for candidate in candidates):
2581+
raise click.BadParameter(f'zone {message}')
25592582
if gpus is not None:
25602583
if any('accelerators' in candidate for candidate in candidates):
25612584
raise click.BadParameter(f'gpus (accelerators) {message}')
@@ -2605,6 +2628,7 @@ def benchmark_launch(
26052628
config['num_nodes'] = num_nodes
26062629
override_params = _parse_override_params(cloud=cloud,
26072630
region=region,
2631+
zone=zone,
26082632
gpus=gpus,
26092633
use_spot=use_spot,
26102634
image_id=image_id,
@@ -2617,6 +2641,9 @@ def benchmark_launch(
26172641
if 'region' in resources_config:
26182642
if resources_config['region'] is None:
26192643
resources_config.pop('region')
2644+
if 'zone' in resources_config:
2645+
if resources_config['zone'] is None:
2646+
resources_config.pop('zone')
26202647
if 'accelerators' in resources_config:
26212648
if resources_config['accelerators'] is None:
26222649
resources_config.pop('accelerators')

sky/clouds/aws.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ def get_default_ami(cls, region_name: str, instance_type: str) -> str:
120120
assert region_name in amis, region_name
121121
return amis[region_name]
122122

123+
@classmethod
124+
def get_zone_shell_cmd(cls) -> Optional[str]:
125+
# The command for getting the current zone is from:
126+
# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instance-identity-documents.html # pylint: disable=line-too-long
127+
command_str = (
128+
'curl -s http://169.254.169.254/latest/dynamic/instance-identity/document' # pylint: disable=line-too-long
129+
' | python3 -u -c "import sys, json; '
130+
'print(json.load(sys.stdin)[\'availabilityZone\'])"')
131+
return command_str
132+
123133
#### Normal methods ####
124134

125135
def instance_type_to_hourly_cost(self, instance_type: str, use_spot: bool):
@@ -309,5 +319,13 @@ def get_credential_file_mounts(self) -> Dict[str, str]:
309319
def instance_type_exists(self, instance_type):
310320
return service_catalog.instance_type_exists(instance_type, clouds='aws')
311321

312-
def region_exists(self, region: str) -> bool:
313-
return service_catalog.region_exists(region, 'aws')
322+
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
323+
return service_catalog.validate_region_zone(region, zone, clouds='aws')
324+
325+
def accelerator_in_region_or_zone(self,
326+
accelerator: str,
327+
acc_count: int,
328+
region: Optional[str] = None,
329+
zone: Optional[str] = None) -> bool:
330+
return service_catalog.accelerator_in_region_or_zone(
331+
accelerator, acc_count, region, zone, 'aws')

sky/clouds/azure.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ def get_accelerators_from_instance_type(
159159
return service_catalog.get_accelerators_from_instance_type(
160160
instance_type, clouds='azure')
161161

162+
@classmethod
163+
def get_zone_shell_cmd(cls) -> Optional[str]:
164+
return None
165+
162166
def make_deploy_resources_variables(
163167
self, resources: 'resources.Resources',
164168
region: Optional['clouds.Region'],
@@ -281,8 +285,18 @@ def instance_type_exists(self, instance_type):
281285
return service_catalog.instance_type_exists(instance_type,
282286
clouds='azure')
283287

284-
def region_exists(self, region: str) -> bool:
285-
return service_catalog.region_exists(region, 'azure')
288+
def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
289+
return service_catalog.validate_region_zone(region,
290+
zone,
291+
clouds='azure')
292+
293+
def accelerator_in_region_or_zone(self,
294+
accelerator: str,
295+
acc_count: int,
296+
region: Optional[str] = None,
297+
zone: Optional[str] = None) -> bool:
298+
return service_catalog.accelerator_in_region_or_zone(
299+
accelerator, acc_count, region, zone, 'azure')
286300

287301
@classmethod
288302
def get_project_id(cls, dryrun: bool = False) -> str:

0 commit comments

Comments
 (0)