
Commit af1b7fd

Add safe guard for provisioning/terminating TPU VM and fix spot launch TPU resource leak (skypilot-org#1500)
* safe guard
* terminate the cluster to be safe
* update
* rm
* better abstraction
* comment
* comments
* rename
* comments
* comment
* msg
* comment
* bug..
* msg
* miss one place
* output error msg
1 parent 2c0685d commit af1b7fd

8 files changed: +122 -43 lines changed

sky/backends/backend_utils.py

Lines changed: 51 additions & 27 deletions
@@ -1245,34 +1245,58 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
     query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
                  f'\\(labels.ray-cluster-name={cluster_name}\\) '
                  f'--zone={zone} --format=value\\(name\\)')
-    if not get_internal_ips:
-        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
-                     f' --zone {zone} --format="value[delimiter=\'\\n\']'
-                     '(networkEndpoints.accessConfig.externalIp)"')
-    else:
-        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'
-                     f' --zone {zone} --format="value[delimiter=\'\\n\']'
-                     '(networkEndpoints.ipAddress)"')
+    returncode, stdout, stderr = log_lib.run_with_log(query_cmd,
+                                                      '/dev/null',
+                                                      shell=True,
+                                                      stream_logs=False,
+                                                      require_outputs=True)
+    subprocess_utils.handle_returncode(
+        returncode,
+        query_cmd,
+        'Failed to run gcloud to get TPU VM IDs.',
+        stderr=stdout + stderr)
+    if len(stdout) == 0:
+        logger.debug('No TPU VMs found with cluster name '
+                     f'{cluster_name} in zone {zone}.')
+    if len(stdout.splitlines()) > 1:
+        # Rare case; this could mean resource leakage. Hint the user.
+        logger.warning('Found more than one TPU VM/Pod with the same cluster '
+                       f'name {cluster_name} in zone {zone}.')
+
+    all_ips = []
+    for tpu_id in stdout.splitlines():
+        tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id}'
+                     f' --zone {zone} --format=json')
+        returncode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
+                                                          '/dev/null',
+                                                          shell=True,
+                                                          stream_logs=False,
+                                                          require_outputs=True)
+        subprocess_utils.handle_returncode(
+            returncode,
+            tpuvm_cmd,
+            'Failed to run gcloud tpu-vm describe.',
+            stderr=stdout + stderr)
+
+        tpuvm_json = json.loads(stdout)
+        if tpuvm_json['state'] != 'READY':
+            # May be a leaked preempted resource.
+            logger.warning(f'TPU VM {tpu_id} is not in READY state. '
+                           'Could be a garbage resource. Skipping...')
+            continue
+
+        if not get_internal_ips:
+            ips = [
+                endpoint['accessConfig']['externalIp']
+                for endpoint in tpuvm_json['networkEndpoints']
+            ]
+        else:
+            ips = [
+                endpoint['ipAddress']
+                for endpoint in tpuvm_json['networkEndpoints']
+            ]
+        all_ips.extend(ips)
 
-    rcode, stdout, stderr = log_lib.run_with_log(tpuvm_cmd,
-                                                 '/dev/null',
-                                                 shell=True,
-                                                 stream_logs=False,
-                                                 require_outputs=True)
-    if rcode != 0:
-        failure_massage = ('Failed to run gcloud to get TPU VM Pod IPs.\n'
-                           '**** STDOUT ****\n'
-                           '{stdout}\n'
-                           '**** STDERR ****\n'
-                           '{stderr}\n'
-                           '**** CMD ****\n'
-                           '{tpuvm_cmd}')
-        with ux_utils.print_exception_no_traceback():
-            raise RuntimeError(
-                failure_massage.format(stdout=stdout,
-                                       stderr=stderr,
-                                       tpuvm_cmd=tpuvm_cmd))
-    all_ips = re.findall(IP_ADDR_REGEX, stdout)
     return all_ips

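For context, the sketch below approximates the new query -> describe -> parse flow with plain subprocess calls instead of SkyPilot's log_lib.run_with_log and subprocess_utils.handle_returncode helpers. The gcloud invocations and the networkEndpoints JSON shape come from the diff above; the function name and the simplified error handling (check=True) are illustrative assumptions, not part of the commit.

import json
import subprocess


def list_tpu_vm_ips(cluster_name: str, zone: str,
                    get_internal_ips: bool = False) -> list:
    # Query the names of TPU VMs/Pods labeled with the ray cluster name.
    query_cmd = ('gcloud compute tpus tpu-vm list '
                 f'--filter="(labels.ray-cluster-name={cluster_name})" '
                 f'--zone={zone} --format="value(name)"')
    query = subprocess.run(query_cmd, shell=True, capture_output=True,
                           text=True, check=True)

    all_ips = []
    for tpu_id in query.stdout.splitlines():
        # Describe each TPU VM individually, rather than piping the whole
        # list through one describe call as the old code did.
        describe_cmd = (f'gcloud compute tpus tpu-vm describe {tpu_id} '
                        f'--zone={zone} --format=json')
        out = subprocess.run(describe_cmd, shell=True, capture_output=True,
                             text=True, check=True).stdout
        tpuvm = json.loads(out)
        if tpuvm['state'] != 'READY':
            # Possibly a leaked, preempted TPU VM; skip it.
            continue
        for endpoint in tpuvm['networkEndpoints']:
            all_ips.append(endpoint['ipAddress'] if get_internal_ips else
                           endpoint['accessConfig']['externalIp'])
    return all_ips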
sky/backends/cloud_vm_ray_backend.py

Lines changed: 23 additions & 3 deletions
@@ -2693,9 +2693,29 @@ def teardown_no_lock(self,
                 f'gcloud compute tpus tpu-vm list --filter='
                 f'\\(labels.ray-cluster-name={cluster_name}\\) '
                 f'--zone={zone} --format=value\\(name\\)')
-            terminate_cmd = (
-                f'gcloud compute tpus tpu-vm delete --zone={zone}'
-                f' --quiet $({query_cmd})')
+            returncode, stdout, stderr = log_lib.run_with_log(
+                query_cmd,
+                log_abs_path,
+                shell=True,
+                stream_logs=False,
+                require_outputs=True)
+
+            # Skip the termination command if the TPU ID
+            # query command fails.
+            if returncode != 0:
+                terminate_cmd = (f'echo "cmd: {query_cmd}" && '
+                                 f'echo "{stdout}" && '
+                                 f'echo "{stderr}" >&2 && '
+                                 f'exit {returncode}')
+            else:
+                # Need to create a list, as GCP does not allow deleting
+                # multiple TPU VMs at once.
+                tpu_terminate_cmds = []
+                for tpu_id in stdout.splitlines():
+                    tpu_terminate_cmds.append(
+                        'gcloud compute tpus tpu-vm delete '
+                        f'--zone={zone} --quiet {tpu_id}')
+                terminate_cmd = ' && '.join(tpu_terminate_cmds)
         else:
             query_cmd = (
                 f'gcloud compute instances list --filter='
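GCP's CLI deletes a single TPU VM per invocation, so the new code chains one delete command per TPU ID with ' && '. A quick illustration of the resulting command, using hypothetical TPU IDs and a hypothetical zone:

tpu_ids = ['tpu-abc', 'tpu-def']  # hypothetical IDs returned by the query
zone = 'us-central1-f'            # hypothetical zone
terminate_cmd = ' && '.join(
    f'gcloud compute tpus tpu-vm delete --zone={zone} --quiet {tpu_id}'
    for tpu_id in tpu_ids)
print(terminate_cmd)
# gcloud compute tpus tpu-vm delete --zone=us-central1-f --quiet tpu-abc &&
# gcloud compute tpus tpu-vm delete --zone=us-central1-f --quiet tpu-def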

sky/clouds/cloud.py

Lines changed: 13 additions & 0 deletions
@@ -214,5 +214,18 @@ def accelerator_in_region_or_zone(self,
         """Returns whether the accelerator is valid in the region or zone."""
         raise NotImplementedError
 
+    def need_cleanup_after_preemption(self,
+                                      resource: 'resources.Resources') -> bool:
+        """Returns whether a spot resource needs cleanup after preemption.
+
+        In most cases, spot resources do not need cleanup after preemption:
+        the cluster can be relaunched with the same name and tag, regardless
+        of whether the preemption behavior is to terminate or stop the
+        cluster. The only exception so far is GCP's Spot TPU VM; we override
+        this method in gcp.py.
+        """
+        del resource
+        return False
+
     def __repr__(self):
         return self._REPR
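The base-class hook gives each cloud a single extension point. As a purely hypothetical sketch (not part of this commit), another cloud with a similar leak-on-preemption quirk would only need to override the method; SomeCloud and _LEAKY_SPOT_INSTANCE_TYPES are illustrative names, not SkyPilot APIs:

from sky.clouds.cloud import Cloud

# Hypothetical set of spot instance families that must be deleted and
# recreated after preemption.
_LEAKY_SPOT_INSTANCE_TYPES = {'example-spot-type'}


class SomeCloud(Cloud):
    """Illustrative cloud, not part of SkyPilot."""

    def need_cleanup_after_preemption(self, resource) -> bool:
        # Only flag the instance families that leak on preemption.
        return resource.instance_type in _LEAKY_SPOT_INSTANCE_TYPES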

sky/clouds/gcp.py

Lines changed: 12 additions & 0 deletions
@@ -468,6 +468,18 @@ def accelerator_in_region_or_zone(self,
         return service_catalog.accelerator_in_region_or_zone(
             accelerator, acc_count, region, zone, 'gcp')
 
+    def need_cleanup_after_preemption(self,
+                                      resources: 'resources.Resources') -> bool:
+        """Returns whether a spot resource needs cleanup after preemption."""
+        # Spot TPU VMs require manual cleanup after preemption:
+        # "If your Cloud TPU is preempted,
+        # you must delete it and create a new one ..."
+        # See: https://cloud.google.com/tpu/docs/preemptible#tpu-vm
+
+        # pylint: disable=import-outside-toplevel
+        from sky.utils import tpu_utils
+        return tpu_utils.is_tpu_vm(resources)
+
     @classmethod
     def get_project_id(cls, dryrun: bool = False) -> str:
         # TODO(zhwu): change the project id fetching with the following command

sky/resources.py

Lines changed: 4 additions & 0 deletions
@@ -260,6 +260,10 @@ def _set_accelerators(
     def is_launchable(self) -> bool:
         return self.cloud is not None and self._instance_type is not None
 
+    def need_cleanup_after_preemption(self) -> bool:
+        """Returns whether a spot resource needs cleanup after preemption."""
+        return self.cloud.need_cleanup_after_preemption(self)
+
     def _set_region_zone(self, region: Optional[str],
                          zone: Optional[str]) -> None:
         if region is None and zone is None:

sky/spot/controller.py

Lines changed: 7 additions & 0 deletions
@@ -145,6 +145,13 @@ def _run(self):
                 'cluster is healthy. Try to recover the job '
                 '(the cluster will not be restarted).')
 
+            resources = list(self._task.resources)[0]
+            if resources.need_cleanup_after_preemption():
+                # Some spot resources (e.g., Spot TPU VM) may need to be
+                # cleaned up after preemption.
+                logger.info('Cleaning up the preempted spot cluster...')
+                self._strategy_executor.terminate_cluster()
+
             # Try to recover the spot jobs, when the cluster is preempted
             # or the job status is failed to be fetched.
             spot_state.set_recovering(self._job_id)
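End to end, the controller's check dispatches through the new Resources method to the cloud-specific override. A minimal sketch, assuming a GCP Spot TPU VM task; the constructor arguments follow SkyPilot's Resources API, where accelerator_args={'tpu_vm': True} is what marks a TPU VM:

import sky

# A spot TPU VM on GCP (as opposed to a TPU Node).
resources = sky.Resources(cloud=sky.GCP(),
                          accelerators={'tpu-v2-8': 1},
                          accelerator_args={'tpu_vm': True},
                          use_spot=True)

# Resources delegates to GCP.need_cleanup_after_preemption(), which returns
# tpu_utils.is_tpu_vm(resources) -> True, so the controller terminates the
# preempted cluster before attempting recovery.
assert resources.need_cleanup_after_preemption()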

sky/spot/recovery_strategy.py

Lines changed: 3 additions & 5 deletions
@@ -276,10 +276,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
                                  launched_resources.region)
         return launch_time
 
-    def terminate_cluster(self, max_retry: int = 3) -> None:
-        super().terminate_cluster(max_retry)
-        self._launched_cloud_region = None
-
     def recover(self) -> float:
         # 1. Cancel the jobs and launch the cluster with the STOPPED status,
         # so that it will try on the current region first until timeout.
@@ -313,7 +309,9 @@ def recover(self) -> float:
             return launched_time
 
         # Step 2
-        logger.debug('Terminating unhealthy spot cluster.')
+        logger.debug('Terminating unhealthy spot cluster and '
+                     'resetting the cloud region.')
+        self._launched_cloud_region = None
         self.terminate_cluster()
 
         # Step 3

sky/utils/tpu_utils.py

Lines changed: 9 additions & 8 deletions
@@ -4,28 +4,29 @@
 from sky import resources as resources_lib
 
 
-def is_tpu(resources: resources_lib.Resources) -> bool:
-    if resources.accelerators is None:
+def is_tpu(resources: Optional[resources_lib.Resources]) -> bool:
+    if resources is None or resources.accelerators is None:
         return False
     acc, _ = list(resources.accelerators.items())[0]
     return acc.startswith('tpu')
 
 
-def is_tpu_vm(resources: resources_lib.Resources) -> bool:
-    if resources.accelerator_args is None:
+def is_tpu_vm(resources: Optional[resources_lib.Resources]) -> bool:
+    if resources is None or resources.accelerator_args is None:
         return False
     return resources.accelerator_args.get('tpu_vm', False)
 
 
-def is_tpu_vm_pod(resources: resources_lib.Resources) -> bool:
-    if not is_tpu_vm(resources):
+def is_tpu_vm_pod(resources: Optional[resources_lib.Resources]) -> bool:
+    if resources is None or not is_tpu_vm(resources):
         return False
     acc, _ = list(resources.accelerators.items())[0]
     return acc not in ['tpu-v2-8', 'tpu-v3-8']
 
 
-def get_num_tpu_devices(resources: resources_lib.Resources) -> Optional[int]:
-    if not is_tpu(resources):
+def get_num_tpu_devices(
+        resources: Optional[resources_lib.Resources]) -> Optional[int]:
+    if resources is None or not is_tpu(resources):
         return None
     acc, _ = list(resources.accelerators.items())[0]
     num_tpu_devices = int(int(acc.split('-')[2]) / 8)
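The widened Optional signatures let callers pass a possibly-missing Resources object straight through rather than guarding for None first; a small illustration:

from sky.utils import tpu_utils

# Each helper now fails soft on None instead of raising AttributeError.
assert tpu_utils.is_tpu(None) is False
assert tpu_utils.is_tpu_vm(None) is False
assert tpu_utils.is_tpu_vm_pod(None) is False
assert tpu_utils.get_num_tpu_devices(None) is None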
