# See the License for the specific language governing permissions and
# limitations under the License.

- """
- The jobs sub-module contains methods needed to submit Distributed Data Parallel (DDP) jobs to Ray Clusters created by the CodeFlare SDK.
- """
-
import abc
from typing import TYPE_CHECKING, Optional, Dict, List
from pathlib import Path

- from torchx.components.dist import ddp
from torchx.runner import get_runner, Runner
from torchx.schedulers.ray_scheduler import RayScheduler
from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo
@@ -47,161 +42,3 @@ def status(self):

    def logs(self):
        pass
-
-
- class DDPJobDefinition(JobDefinition):
-     def __init__(
-         self,
-         script: Optional[str] = None,
-         m: Optional[str] = None,
-         script_args: Optional[List[str]] = None,
-         name: Optional[str] = None,
-         cpu: Optional[int] = None,
-         gpu: Optional[int] = None,
-         memMB: Optional[int] = None,
-         h: Optional[str] = None,
-         j: Optional[str] = None,
-         env: Optional[Dict[str, str]] = None,
-         max_retries: int = 0,
-         mounts: Optional[List[str]] = None,
-         rdzv_port: int = 29500,
-         rdzv_backend: Optional[str] = None,
-         scheduler_args: Optional[Dict[str, str]] = None,
-         image: Optional[str] = None,
-         workspace: Optional[str] = f"file://{Path.cwd()}",
-     ):
-         if bool(script) == bool(m):  # true when both or neither is set, i.e. not exactly one
-             raise ValueError(
-                 "Exactly one of the following arguments must be defined: [script, m]."
-             )
-         self.script = script
-         self.m = m
-         self.script_args: List[str] = script_args if script_args is not None else []
-         self.name = name
-         self.cpu = cpu
-         self.gpu = gpu
-         self.memMB = memMB
-         self.h = h
-         self.j = j
-         self.env: Dict[str, str] = env if env is not None else dict()
-         self.max_retries = max_retries
-         self.mounts: List[str] = mounts if mounts is not None else []
-         self.rdzv_port = rdzv_port
-         self.rdzv_backend = rdzv_backend
-         self.scheduler_args: Dict[str, str] = (
-             scheduler_args if scheduler_args is not None else dict()
-         )
-         self.image = image
-         self.workspace = workspace
-
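A job definition pairs an entrypoint (exactly one of script or m) with optional per-worker resources; anything left as None is filled in from the target cluster's config at dry-run time. A minimal construction sketch, with illustrative script name and resource values:

    job_def = DDPJobDefinition(
        script="train.py",              # or m="my_pkg.train", never both
        script_args=["--epochs", "5"],
        name="mnist-ddp",
        cpu=1,                          # per-worker CPUs
        gpu=1,                          # per-worker GPUs
        memMB=4096,                     # per-worker memory in MB
    )
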
-     def _dry_run(self, cluster: "Cluster"):
-         j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}"  # # of proc. = # of gpus
-         runner = get_runner(ray_client=cluster.job_client)
-         runner._scheduler_instances["ray"] = RayScheduler(
-             session_name=runner._name, ray_client=cluster.job_client
-         )
-         return (
-             runner.dryrun(
-                 app=ddp(
-                     *self.script_args,
-                     script=self.script,
-                     m=self.m,
-                     name=self.name,
-                     h=self.h,
-                     cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus,
-                     gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus,
-                     memMB=self.memMB
-                     if self.memMB is not None
-                     else cluster.config.max_memory * 1024,
-                     j=self.j if self.j is not None else j,
-                     env=self.env,
-                     max_retries=self.max_retries,
-                     rdzv_port=self.rdzv_port,
-                     rdzv_backend=self.rdzv_backend
-                     if self.rdzv_backend is not None
-                     else "static",
-                     mounts=self.mounts,
-                 ),
-                 scheduler=cluster.torchx_scheduler,
-                 cfg=cluster.torchx_config(**self.scheduler_args),
-                 workspace=self.workspace,
-             ),
-             runner,
-         )
-
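The fallback j spec assembled above is TorchX's `{nnodes}x{nproc_per_node}` string: one node per Ray worker and one process per GPU, falling back to a single process on GPU-less clusters. A worked example with assumed config values:

    num_workers = 2                          # cluster.config.num_workers (assumed)
    num_gpus = 4                             # cluster.config.num_gpus (assumed)
    j = f"{num_workers}x{max(num_gpus, 1)}"  # "2x4": 2 workers, 4 procs each
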
-     def _missing_spec(self, spec: str):
-         raise ValueError(f"Job definition missing arg: {spec}")
-
-     def _dry_run_no_cluster(self):
-         if self.scheduler_args is not None:
-             if self.scheduler_args.get("namespace") is None:
-                 self.scheduler_args["namespace"] = get_current_namespace()
-         runner = get_runner()
-         return (
-             runner.dryrun(
-                 app=ddp(
-                     *self.script_args,
-                     script=self.script,
-                     m=self.m,
-                     name=self.name
-                     if self.name is not None
-                     else self._missing_spec("name"),
-                     h=self.h,
-                     cpu=self.cpu
-                     if self.cpu is not None
-                     else self._missing_spec("cpu (# cpus per worker)"),
-                     gpu=self.gpu
-                     if self.gpu is not None
-                     else self._missing_spec("gpu (# gpus per worker)"),
-                     memMB=self.memMB
-                     if self.memMB is not None
-                     else self._missing_spec("memMB (memory in MB)"),
-                     j=self.j
-                     if self.j is not None
-                     else self._missing_spec(
-                         "j (`workers`x`procs`)"
-                     ),  # # of proc. = # of gpus
-                     env=self.env,  # should this still exist?
-                     max_retries=self.max_retries,
-                     rdzv_port=self.rdzv_port,  # should this still exist?
-                     rdzv_backend=self.rdzv_backend
-                     if self.rdzv_backend is not None
-                     else "c10d",
-                     mounts=self.mounts,
-                     image=self.image
-                     if self.image is not None
-                     else self._missing_spec("image"),
-                 ),
-                 scheduler="kubernetes_mcad",
-                 cfg=self.scheduler_args,
-                 workspace="",
-             ),
-             runner,
-         )
-
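Without a cluster to fill in defaults, every spec must be supplied up front or _missing_spec raises a ValueError, and the app is targeted at the kubernetes_mcad scheduler in the current namespace. A sketch of a fully specified definition for this path (the image reference is hypothetical):

    job_def = DDPJobDefinition(
        script="train.py",
        name="mnist-ddp",
        cpu=1,
        gpu=1,
        memMB=4096,
        j="2x1",                                  # 2 workers x 1 proc each
        image="quay.io/example/training:latest",  # hypothetical image
    )
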
-     def submit(self, cluster: "Cluster" = None) -> "Job":
-         return DDPJob(self, cluster)
-
-
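Submission is then a single call; with a Cluster argument the job is scheduled on that Ray cluster, and without one it takes the kubernetes_mcad path above. For example:

    job = job_def.submit(cluster)  # runs on the Ray cluster
    job = job_def.submit()         # no cluster: kubernetes_mcad, all specs required
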
- class DDPJob(Job):
-     def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster" = None):
-         self.job_definition = job_definition
-         self.cluster = cluster
-         if self.cluster:
-             definition, runner = job_definition._dry_run(cluster)
-         else:
-             definition, runner = job_definition._dry_run_no_cluster()
-         self._app_handle = runner.schedule(definition)
-         self._runner = runner
-         all_jobs.append(self)
-
-     def status(self) -> str:
-         return self._runner.status(self._app_handle)
-
-     def logs(self) -> str:
-         return "".join(self._runner.log_lines(self._app_handle, None))
-
-     def cancel(self):
-         self._runner.cancel(self._app_handle)
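
The returned DDPJob wraps a TorchX app handle, so monitoring and teardown go through the stored runner. A usage sketch:

    print(job.status())  # scheduler-reported status for the app handle
    print(job.logs())    # concatenated log lines from the job
    job.cancel()         # stop the job and free its resources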