Skip to content

Commit ee73e7d

Browse files
authored
Clean up preempted resources for TPU (skypilot-org#1483)
* fix in controller
* remove debug msg
* msg
* handle job_status == None case and refactor
* space
* update
* comments
* comments
1 parent 172f6e3 commit ee73e7d

File tree

3 files changed

+41
-45
lines changed

3 files changed

+41
-45
lines changed

sky/backends/backend_utils.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1225,11 +1225,8 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
12251225

12261226
cluster_name = ray_config['cluster_name']
12271227
zone = ray_config['provider']['availability_zone']
1228-
# Excluding preempted VMs is safe as they are already terminated and
1229-
# do not charge.
12301228
query_cmd = (f'gcloud compute tpus tpu-vm list --filter='
1231-
f'"(labels.ray-cluster-name={cluster_name} AND '
1232-
f'state!=PREEMPTED)" '
1229+
f'\\(labels.ray-cluster-name={cluster_name}\\) '
12331230
f'--zone={zone} --format=value\\(name\\)')
12341231
if not get_internal_ips:
12351232
tpuvm_cmd = (f'gcloud compute tpus tpu-vm describe $({query_cmd})'

sky/backends/cloud_vm_ray_backend.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -2680,12 +2680,9 @@ def teardown_no_lock(self,
26802680
# check if gcloud includes TPU VM API
26812681
backend_utils.check_gcp_cli_include_tpu_vm()
26822682

2683-
# Excluding preempted VMs is safe as they are already
2684-
# terminated and do not charge.
26852683
query_cmd = (
26862684
f'gcloud compute tpus tpu-vm list --filter='
2687-
f'"(labels.ray-cluster-name={cluster_name} AND '
2688-
f'state!=PREEMPTED)" '
2685+
f'\\(labels.ray-cluster-name={cluster_name}\\) '
26892686
f'--zone={zone} --format=value\\(name\\)')
26902687
terminate_cmd = (
26912688
f'gcloud compute tpus tpu-vm delete --zone={zone}'

sky/spot/controller.py

+39-37
Original file line numberDiff line numberDiff line change
@@ -85,30 +85,6 @@ def _run(self):
8585
job_status = spot_utils.get_job_status(self._backend,
8686
self._cluster_name)
8787

88-
if job_status is not None and not job_status.is_terminal():
89-
need_recovery = False
90-
if self._task.num_nodes > 1:
91-
# Check the cluster status for multi-node jobs, since the
92-
# job may not be set to FAILED immediately when only some
93-
# of the nodes are preempted.
94-
(cluster_status,
95-
handle) = backend_utils.refresh_cluster_status_handle(
96-
self._cluster_name, force_refresh=True)
97-
if cluster_status != global_user_state.ClusterStatus.UP:
98-
# recover the cluster if it is not up.
99-
# The status could be None when the cluster is preempted
100-
# right after the job was found FAILED.
101-
cluster_status_str = ('is preempted'
102-
if cluster_status is None else
103-
f'status {cluster_status.value}')
104-
logger.info(f'Cluster {cluster_status_str}. '
105-
'Recovering...')
106-
need_recovery = True
107-
if not need_recovery:
108-
# The job and cluster are healthy, continue to monitor the
109-
# job status.
110-
continue
111-
11288
if job_status == job_lib.JobStatus.SUCCEEDED:
11389
end_time = spot_utils.get_job_timestamp(self._backend,
11490
self._cluster_name,
@@ -117,14 +93,35 @@ def _run(self):
11793
spot_state.set_succeeded(self._job_id, end_time=end_time)
11894
break
11995

120-
if job_status == job_lib.JobStatus.FAILED:
121-
# Check the status of the spot cluster. If it is not UP,
122-
# the cluster is preempted.
123-
(cluster_status,
124-
handle) = backend_utils.refresh_cluster_status_handle(
125-
self._cluster_name, force_refresh=True)
126-
if cluster_status == global_user_state.ClusterStatus.UP:
127-
# The user code has probably crashed.
96+
# For single-node jobs, nonterminated job_status indicates a
97+
# healthy cluster. We can safely continue monitoring.
98+
# For multi-node jobs, since the job may not be set to FAILED
99+
# immediately (depending on user program) when only some of the
100+
# nodes are preempted, need to check the actual cluster status.
101+
if (job_status is not None and not job_status.is_terminal() and
102+
self._task.num_nodes == 1):
103+
continue
104+
105+
# Pull the actual cluster status from the cloud provider to
106+
# determine whether the cluster is preempted.
107+
(cluster_status,
108+
handle) = backend_utils.refresh_cluster_status_handle(
109+
self._cluster_name, force_refresh=True)
110+
111+
if cluster_status != global_user_state.ClusterStatus.UP:
112+
# The cluster is (partially) preempted. It can be down, INIT
113+
# or STOPPED, based on the interruption behavior of the cloud.
114+
# Spot recovery is needed (will be done later in the code).
115+
cluster_status_str = ('' if cluster_status is None else
116+
f' (status: {cluster_status.value})')
117+
logger.info(
118+
f'Cluster is preempted{cluster_status_str}. Recovering...')
119+
else:
120+
if job_status is not None and not job_status.is_terminal():
121+
# The multi-node job is still running, continue monitoring.
122+
continue
123+
elif job_status == job_lib.JobStatus.FAILED:
124+
# The user code has probably crashed, fail immediately.
128125
end_time = spot_utils.get_job_timestamp(self._backend,
129126
self._cluster_name,
130127
get_end_time=True)
@@ -140,11 +137,16 @@ def _run(self):
140137
failure_type=spot_state.SpotStatus.FAILED,
141138
end_time=end_time)
142139
break
143-
# cluster can be down, INIT or STOPPED, based on the interruption
144-
# behavior of the cloud.
145-
# Failed to connect to the cluster or the cluster is partially down.
146-
# job_status is None or job_status == job_lib.JobStatus.FAILED
147-
logger.info('The cluster is preempted.')
140+
# Although the cluster is healthy, we fail to access the
141+
# job status. Try to recover the job (will not restart the
142+
# cluster, if the cluster is healthy).
143+
assert job_status is None, job_status
144+
logger.info('Failed to fetch the job status while the '
145+
'cluster is healthy. Try to recover the job '
146+
'(the cluster will not be restarted).')
147+
148+
# Try to recover the spot jobs, when the cluster is preempted
149+
# or the job status is failed to be fetched.
148150
spot_state.set_recovering(self._job_id)
149151
recovered_time = self._strategy_executor.recover()
150152
spot_state.set_recovered(self._job_id,

0 commit comments

Comments (0)