# Copyright 2023 IBM, Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
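"""jobs.py

Job and job-definition abstractions for the CodeFlare SDK. `DDPJobDefinition`
wraps the TorchX `ddp` component and can be dry-run or submitted either to a
Ray cluster managed by the SDK or, when no cluster is given, directly to
Kubernetes via the `kubernetes_mcad` scheduler.
"""
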
import abc
from typing import TYPE_CHECKING, Optional, Dict, List
from pathlib import Path

from torchx.components.dist import ddp
from torchx.runner import get_runner, Runner
from torchx.schedulers.ray_scheduler import RayScheduler
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo

if TYPE_CHECKING:
    from ..cluster.cluster import Cluster
from ..cluster.cluster import get_current_namespace

# Module-level registry of every job submitted in this session.
all_jobs: List["Job"] = []


class JobDefinition(metaclass=abc.ABCMeta):
    """Abstract base for job definitions that can be dry-run and submitted."""

    def _dry_run(self, cluster: "Cluster"):
        pass

    def submit(self, cluster: "Cluster"):
        pass


class Job(metaclass=abc.ABCMeta):
    """Abstract base for submitted jobs, exposing status and log access."""

    def status(self):
        pass

    def logs(self):
        pass


class DDPJobDefinition(JobDefinition):
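    """Definition of a distributed (DDP) training job built on the TorchX `ddp` component.

    Resource arguments (cpu, gpu, memMB, j) left unset are filled from the target
    cluster's configuration at dry-run time; without a cluster they must be given
    explicitly, along with an image.
    """
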
def __init__(
self,
script: Optional[str] = None,
m: Optional[str] = None,
script_args: Optional[List[str]] = None,
name: Optional[str] = None,
cpu: Optional[int] = None,
gpu: Optional[int] = None,
memMB: Optional[int] = None,
h: Optional[str] = None,
j: Optional[str] = None,
env: Optional[Dict[str, str]] = None,
max_retries: int = 0,
mounts: Optional[List[str]] = None,
rdzv_port: int = 29500,
        rdzv_backend: Optional[str] = None,
scheduler_args: Optional[Dict[str, str]] = None,
image: Optional[str] = None,
workspace: Optional[str] = f"file://{Path.cwd()}",
):
        if bool(script) == bool(m):  # exactly one of script / m must be set
            raise ValueError(
                "Exactly one of the following arguments must be defined: [script, m]."
            )
self.script = script
self.m = m
self.script_args: List[str] = script_args if script_args is not None else []
self.name = name
self.cpu = cpu
self.gpu = gpu
self.memMB = memMB
self.h = h
self.j = j
self.env: Dict[str, str] = env if env is not None else dict()
self.max_retries = max_retries
self.mounts: List[str] = mounts if mounts is not None else []
self.rdzv_port = rdzv_port
self.rdzv_backend = rdzv_backend
self.scheduler_args: Dict[str, str] = (
scheduler_args if scheduler_args is not None else dict()
)
self.image = image
self.workspace = workspace

    def _dry_run(self, cluster: "Cluster"):
        """Dry-run the job on the given Ray cluster; returns (AppDryRunInfo, Runner)."""
        # Default procs per worker: one per GPU, or 1 when the cluster has no GPUs.
        j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}"
runner = get_runner(ray_client=cluster.job_client)
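        # Register a Ray scheduler bound to the cluster's Ray job client so the
        # dry run (and a later schedule call) targets the remote cluster.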
runner._scheduler_instances["ray"] = RayScheduler(
session_name=runner._name, ray_client=cluster.job_client
)
return (
runner.dryrun(
app=ddp(
*self.script_args,
script=self.script,
m=self.m,
name=self.name,
h=self.h,
cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus,
memMB=self.memMB
if self.memMB is not None
else cluster.config.max_memory * 1024,
j=self.j if self.j is not None else j,
env=self.env,
max_retries=self.max_retries,
rdzv_port=self.rdzv_port,
rdzv_backend=self.rdzv_backend
if self.rdzv_backend is not None
else "static",
mounts=self.mounts,
),
scheduler=cluster.torchx_scheduler,
cfg=cluster.torchx_config(**self.scheduler_args),
workspace=self.workspace,
),
runner,
)

    def _missing_spec(self, spec: str):
        raise ValueError(f"Job definition missing arg: {spec}")

    def _dry_run_no_cluster(self):
        """Dry-run the job for direct Kubernetes (MCAD) submission, without a cluster."""
        if self.scheduler_args is not None:
            if self.scheduler_args.get("namespace") is None:
                self.scheduler_args["namespace"] = get_current_namespace()
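        # With no managed cluster, use a default runner and target the
        # kubernetes_mcad scheduler; resource specs must be set explicitly.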
runner = get_runner()
return (
runner.dryrun(
app=ddp(
*self.script_args,
script=self.script,
m=self.m,
name=self.name
if self.name is not None
else self._missing_spec("name"),
h=self.h,
cpu=self.cpu
if self.cpu is not None
else self._missing_spec("cpu (# cpus per worker)"),
gpu=self.gpu
if self.gpu is not None
else self._missing_spec("gpu (# gpus per worker)"),
memMB=self.memMB
if self.memMB is not None
else self._missing_spec("memMB (memory in MB)"),
j=self.j
if self.j is not None
else self._missing_spec(
"j (`workers`x`procs`)"
                    ),  # procs per worker = number of GPUs
env=self.env, # should this still exist?
max_retries=self.max_retries,
rdzv_port=self.rdzv_port, # should this still exist?
rdzv_backend=self.rdzv_backend
if self.rdzv_backend is not None
else "c10d",
mounts=self.mounts,
image=self.image
if self.image is not None
else self._missing_spec("image"),
),
scheduler="kubernetes_mcad",
cfg=self.scheduler_args,
workspace="",
),
runner,
)

    def submit(self, cluster: Optional["Cluster"] = None) -> "Job":
        """Submit the job, scheduling it on the given cluster when one is provided."""
        return DDPJob(self, cluster)


class DDPJob(Job):
    """A scheduled DDP job created from a DDPJobDefinition."""

    def __init__(
        self, job_definition: "DDPJobDefinition", cluster: Optional["Cluster"] = None
    ):
        self.job_definition = job_definition
        self.cluster = cluster
        # Dry-run against the cluster's Ray scheduler when one is provided,
        # otherwise fall back to direct Kubernetes (MCAD) submission.
        if self.cluster:
            definition, runner = job_definition._dry_run(cluster)
        else:
            definition, runner = job_definition._dry_run_no_cluster()
        self._app_handle = runner.schedule(definition)
        self._runner = runner
        all_jobs.append(self)

    def status(self) -> str:
        return self._runner.status(self._app_handle)

    def logs(self) -> str:
        return "".join(self._runner.log_lines(self._app_handle, None))

    def cancel(self):
        self._runner.cancel(self._app_handle)
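

# Example usage (illustrative sketch, not part of the module). Cluster creation
# follows the SDK's Cluster/ClusterConfiguration API; the names, script path,
# and "2x1" workers-x-procs layout below are hypothetical.
#
#   from codeflare_sdk.cluster.cluster import Cluster, ClusterConfiguration
#   from codeflare_sdk.job.jobs import DDPJobDefinition
#
#   cluster = Cluster(ClusterConfiguration(name="mnisttest", num_workers=2, num_gpus=1))
#   cluster.up()
#   cluster.wait_ready()
#
#   job_def = DDPJobDefinition(name="mnistjob", script="mnist.py", j="2x1")
#   job = job_def.submit(cluster)   # schedules on the cluster's Ray scheduler
#   print(job.status())
#   print(job.logs())
#   job.cancel()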