
Commit 24d8ef1

benieric, pintaoz-aws, pintaoz, and Keshav Chandak (keshav-chandak) authored and root committed
feat: Make DistributedConfig Extensible (aws#5039)
* feat: Make DistributedConfig Extensible
* pylint
* Include none types when creating config jsons for safer reference
* fix: update test to account for changes
* format
* Add integ test
* pylint
* prepare release v2.240.0
* update development version to v2.240.1.dev0
* Fix key error in _send_metrics() (aws#5068)

Co-authored-by: pintaoz <[email protected]>

* fix: Added check for the presence of model package group before creating one (aws#5063)

Co-authored-by: Keshav Chandak <[email protected]>

* Use sagemaker session's s3_resource in download_folder (aws#5064)

Co-authored-by: pintaoz <[email protected]>

* remove union
* fix merge artifact
* Change dir path to distributed_drivers
* update paths

---------

Co-authored-by: ci <ci>
Co-authored-by: pintaoz-aws <[email protected]>
Co-authored-by: pintaoz <[email protected]>
Co-authored-by: Keshav Chandak <[email protected]>
Co-authored-by: Keshav Chandak <[email protected]>
1 parent 0c73ce0 commit 24d8ef1

22 files changed, +428 −192 lines

src/sagemaker/modules/distributed.py

+69 −13

@@ -13,9 +13,12 @@
 """Distributed module."""
 from __future__ import absolute_import

+import os
+
+from abc import ABC, abstractmethod
 from typing import Optional, Dict, Any, List
-from pydantic import PrivateAttr
 from sagemaker.modules.utils import safe_serialize
+from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH
 from sagemaker.modules.configs import BaseConfig


@@ -73,16 +76,37 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
         return hyperparameters


-class DistributedConfig(BaseConfig):
-    """Base class for distributed training configurations."""
+class DistributedConfig(BaseConfig, ABC):
+    """Abstract base class for distributed training configurations.
+
+    This class defines the interface that all distributed training configurations
+    must implement. It provides a standardized way to specify driver scripts and
+    their locations for distributed training jobs.
+    """
+
+    @property
+    @abstractmethod
+    def driver_dir(self) -> str:
+        """Directory containing the driver script.
+
+        This property should return the path to the directory containing
+        the driver script, relative to the container's working directory.

-    _type: str = PrivateAttr()
+        Returns:
+            str: Path to directory containing the driver script
+        """

-    def model_dump(self, *args, **kwargs):
-        """Dump the model to a dictionary."""
-        result = super().model_dump(*args, **kwargs)
-        result["_type"] = self._type
-        return result
+    @property
+    @abstractmethod
+    def driver_script(self) -> str:
+        """Name of the driver script.
+
+        This property should return the name of the Python script that implements
+        the distributed training driver logic.
+
+        Returns:
+            str: Name of the driver script file
+        """


 class Torchrun(DistributedConfig):
@@ -99,11 +123,27 @@ class Torchrun(DistributedConfig):
         The SageMaker Model Parallelism v2 parameters.
     """

-    _type: str = PrivateAttr(default="torchrun")
-
     process_count_per_node: Optional[int] = None
     smp: Optional["SMP"] = None

+    @property
+    def driver_dir(self) -> str:
+        """Directory containing the driver script.
+
+        Returns:
+            str: Path to directory containing the driver script
+        """
+        return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")
+
+    @property
+    def driver_script(self) -> str:
+        """Name of the driver script.
+
+        Returns:
+            str: Name of the driver script file
+        """
+        return "torchrun_driver.py"
+

 class MPI(DistributedConfig):
     """MPI.
@@ -119,7 +159,23 @@ class MPI(DistributedConfig):
         The custom MPI options to use for the training job.
     """

-    _type: str = PrivateAttr(default="mpi")
-
     process_count_per_node: Optional[int] = None
     mpi_additional_options: Optional[List[str]] = None
+
+    @property
+    def driver_dir(self) -> str:
+        """Directory containing the driver script.
+
+        Returns:
+            str: Path to directory containing the driver script
+        """
+        return os.path.join(SM_DRIVERS_LOCAL_PATH, "distributed_drivers")
+
+    @property
+    def driver_script(self) -> str:
+        """Name of the driver script.
+
+        Returns:
+            str: Name of the driver script
+        """
+        return "mpi_driver.py"

src/sagemaker/modules/templates.py

+4 −9

@@ -21,17 +21,12 @@

 EXECUTE_BASIC_SCRIPT_DRIVER = """
 echo "Running Basic Script driver"
-$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py
+$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/basic_script_driver.py
 """

-EXEUCTE_TORCHRUN_DRIVER = """
-echo "Running Torchrun driver"
-$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py
-"""
-
-EXECUTE_MPI_DRIVER = """
-echo "Running MPI driver"
-$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py
+EXEUCTE_DISTRIBUTED_DRIVER = """
+echo "Running {driver_name} Driver"
+$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script}
 """

 TRAIN_SCRIPT_TEMPLATE = """
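
The per-driver templates collapse into a single parameterized `EXEUCTE_DISTRIBUTED_DRIVER` template. A quick rendering sketch; the `driver_name` and `driver_script` values shown are illustrative:

```python
from sagemaker.modules.templates import EXEUCTE_DISTRIBUTED_DRIVER

# Render the shared template for a torchrun-based job.
snippet = EXEUCTE_DISTRIBUTED_DRIVER.format(
    driver_name="Torchrun", driver_script="torchrun_driver.py"
)
print(snippet)
# echo "Running Torchrun Driver"
# $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/torchrun_driver.py
```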

src/sagemaker/modules/train/container_drivers/__init__.py

+1 −1

@@ -10,5 +10,5 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""Sagemaker modules container_drivers directory."""
+"""Sagemaker modules container drivers directory."""
 from __future__ import absolute_import
src/sagemaker/modules/train/container_drivers/common/__init__.py (new file)

+14 −0

@@ -0,0 +1,14 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Sagemaker modules container drivers - common directory."""
+from __future__ import absolute_import

src/sagemaker/modules/train/container_drivers/utils.py renamed to src/sagemaker/modules/train/container_drivers/common/utils.py

+2 −2

@@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME
     return hyperparameters_dict


-def get_process_count(distributed_dict: Dict[str, Any]) -> int:
+def get_process_count(process_count: Optional[int] = None) -> int:
     """Get the number of processes to run on each node in the training job."""
     return (
-        int(distributed_dict.get("process_count_per_node", 0))
+        process_count
         or int(os.environ.get("SM_NUM_GPUS", 0))
         or int(os.environ.get("SM_NUM_NEURONS", 0))
         or 1
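
`get_process_count` now receives the already-extracted per-node count instead of the whole distributed dict; a falsy value (`None` or `0`) falls through to `SM_NUM_GPUS`, then `SM_NUM_NEURONS`, then `1`. A small usage sketch, assuming the module is importable as packaged and using made-up environment values:

```python
import os

from sagemaker.modules.train.container_drivers.common.utils import get_process_count

# An explicit per-node count wins.
assert get_process_count(4) == 4

# Without one, fall back to the accelerator count from the environment.
os.environ["SM_NUM_GPUS"] = "8"
assert get_process_count(None) == 8

# With nothing set at all, default to a single process per node.
del os.environ["SM_NUM_GPUS"]
os.environ.pop("SM_NUM_NEURONS", None)
assert get_process_count() == 1
```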
src/sagemaker/modules/train/container_drivers/distributed_drivers/__init__.py (new file)

+14 −0

@@ -0,0 +1,14 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Sagemaker modules container drivers - drivers directory."""
+from __future__ import absolute_import

src/sagemaker/modules/train/container_drivers/basic_script_driver.py renamed to src/sagemaker/modules/train/container_drivers/distributed_drivers/basic_script_driver.py

+8 −6

@@ -13,16 +13,19 @@
 """This module is the entry point for the Basic Script Driver."""
 from __future__ import absolute_import

+import os
 import sys
+import json
 import shlex

+from pathlib import Path
 from typing import List

-from utils import (
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from common.utils import (  # noqa: E402 # pylint: disable=C0413,E0611
     logger,
     get_python_executable,
-    read_source_code_json,
-    read_hyperparameters_json,
     execute_commands,
     write_failure_file,
     hyperparameters_to_cli_args,
@@ -31,11 +34,10 @@

 def create_commands() -> List[str]:
     """Create the commands to execute."""
-    source_code = read_source_code_json()
-    hyperparameters = read_hyperparameters_json()
+    entry_script = os.environ["SM_ENTRY_SCRIPT"]
+    hyperparameters = json.loads(os.environ["SM_HPS"])
     python_executable = get_python_executable()

-    entry_script = source_code["entry_script"]
     args = hyperparameters_to_cli_args(hyperparameters)
     if entry_script.endswith(".py"):
         commands = [python_executable, entry_script]
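
The drivers no longer read `sourcecode.json`/`distributed.json` from disk; the entry script and hyperparameters now arrive through environment variables. A condensed sketch of the new flow in `create_commands` (not the verbatim helper logic, and the sample values are made up):

```python
import json
import os
import shlex

# Values the launcher is expected to export for the container drivers.
os.environ.setdefault("SM_ENTRY_SCRIPT", "/opt/ml/input/data/code/train.py")
os.environ.setdefault("SM_HPS", json.dumps({"epochs": 3, "lr": 0.01}))

entry_script = os.environ["SM_ENTRY_SCRIPT"]
hyperparameters = json.loads(os.environ["SM_HPS"])

# Hyperparameters become CLI arguments appended after the entry script.
args = []
for key, value in hyperparameters.items():
    args += [f"--{key}", shlex.quote(str(value))]

command = ["python", entry_script, *args]  # "python" stands in for the resolved interpreter
print(" ".join(command))
# python /opt/ml/input/data/code/train.py --epochs 3 --lr 0.01
```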

src/sagemaker/modules/train/container_drivers/mpi_driver.py renamed to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_driver.py

+18 −17

@@ -16,18 +16,8 @@
 import os
 import sys
 import json
+from pathlib import Path

-from utils import (
-    logger,
-    read_source_code_json,
-    read_distributed_json,
-    read_hyperparameters_json,
-    hyperparameters_to_cli_args,
-    get_process_count,
-    execute_commands,
-    write_failure_file,
-    USER_CODE_PATH,
-)
 from mpi_utils import (
     start_sshd_daemon,
     bootstrap_master_node,
@@ -38,6 +28,16 @@
 )


+sys.path.insert(0, str(Path(__file__).parent.parent))
+from common.utils import (  # noqa: E402 # pylint: disable=C0413,E0611
+    logger,
+    hyperparameters_to_cli_args,
+    get_process_count,
+    execute_commands,
+    write_failure_file,
+)
+
+
 def main():
     """Main function for the MPI driver script.

@@ -58,9 +58,9 @@ def main():
     5. Exit

     """
-    source_code = read_source_code_json()
-    distribution = read_distributed_json()
-    hyperparameters = read_hyperparameters_json()
+    entry_script = os.environ["SM_ENTRY_SCRIPT"]
+    distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
+    hyperparameters = json.loads(os.environ["SM_HPS"])

     sm_current_host = os.environ["SM_CURRENT_HOST"]
     sm_hosts = json.loads(os.environ["SM_HOSTS"])
@@ -77,7 +77,8 @@

         host_list = json.loads(os.environ["SM_HOSTS"])
         host_count = int(os.environ["SM_HOST_COUNT"])
-        process_count = get_process_count(distribution)
+        process_count = int(distributed_config["process_count_per_node"] or 0)
+        process_count = get_process_count(process_count)

         if process_count > 1:
             host_list = ["{}:{}".format(host, process_count) for host in host_list]
@@ -86,8 +87,8 @@
             host_count=host_count,
             host_list=host_list,
             num_processes=process_count,
-            additional_options=distribution.get("mpi_additional_options", []),
-            entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]),
+            additional_options=distributed_config["mpi_additional_options"] or [],
+            entry_script_path=entry_script,
         )

         args = hyperparameters_to_cli_args(hyperparameters)
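
Two details of the hunks above are easy to miss: the per-node process count is now read from the serialized config (with `get_process_count` as the fallback), and each host is annotated with its slot count before being handed to mpirun. A small illustration with made-up values:

```python
# Stand-ins for what SM_HOSTS and SM_DISTRIBUTED_CONFIG would decode to.
host_list = ["algo-1", "algo-2"]
distributed_config = {"process_count_per_node": 8, "mpi_additional_options": None}

process_count = int(distributed_config["process_count_per_node"] or 0)  # 8
if process_count > 1:
    # mpirun-style "host:slots" entries, one per node.
    host_list = ["{}:{}".format(host, process_count) for host in host_list]

print(host_list)                                            # ['algo-1:8', 'algo-2:8']
print(distributed_config["mpi_additional_options"] or [])   # []
```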

src/sagemaker/modules/train/container_drivers/mpi_utils.py renamed to src/sagemaker/modules/train/container_drivers/distributed_drivers/mpi_utils.py

+12 −1

@@ -14,12 +14,23 @@
 from __future__ import absolute_import

 import os
+import sys
 import subprocess
 import time
+
+from pathlib import Path
 from typing import List

 import paramiko
-from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from common.utils import (  # noqa: E402 # pylint: disable=C0413,E0611
+    SM_EFA_NCCL_INSTANCES,
+    SM_EFA_RDMA_INSTANCES,
+    get_python_executable,
+    logger,
+)

 FINISHED_STATUS_FILE = "/tmp/done.algo-1"
 READY_FILE = "/tmp/ready.%s"

src/sagemaker/modules/train/container_drivers/torchrun_driver.py renamed to src/sagemaker/modules/train/container_drivers/distributed_drivers/torchrun_driver.py

+11 −10

@@ -15,20 +15,20 @@

 import os
 import sys
+import json

+from pathlib import Path
 from typing import List, Tuple

-from utils import (
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from common.utils import (  # noqa: E402 # pylint: disable=C0413,E0611
     logger,
-    read_source_code_json,
-    read_distributed_json,
-    read_hyperparameters_json,
     hyperparameters_to_cli_args,
     get_process_count,
     get_python_executable,
     execute_commands,
     write_failure_file,
-    USER_CODE_PATH,
     SM_EFA_NCCL_INSTANCES,
     SM_EFA_RDMA_INSTANCES,
 )
@@ -65,11 +65,12 @@ def setup_env():


 def create_commands():
     """Create the Torch Distributed command to execute"""
-    source_code = read_source_code_json()
-    distribution = read_distributed_json()
-    hyperparameters = read_hyperparameters_json()
+    entry_script = os.environ["SM_ENTRY_SCRIPT"]
+    distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
+    hyperparameters = json.loads(os.environ["SM_HPS"])

-    process_count = get_process_count(distribution)
+    process_count = int(distributed_config["process_count_per_node"] or 0)
+    process_count = get_process_count(process_count)
     host_count = int(os.environ["SM_HOST_COUNT"])

     torch_cmd = []
@@ -94,7 +95,7 @@ def create_commands():
         ]
     )

-    torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])])
+    torch_cmd.extend([entry_script])

     args = hyperparameters_to_cli_args(hyperparameters)
     torch_cmd += args
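
Per the commit message ("Include none types when creating config jsons for safer reference"), the serialized config keeps `None`-valued fields, so the drivers guard with `or` before casting. A sketch with an assumed `SM_DISTRIBUTED_CONFIG` payload (the exact keys the SDK emits may differ):

```python
import json
import os

# Assumed serialized form of a Torchrun config with no explicit process count.
os.environ["SM_DISTRIBUTED_CONFIG"] = json.dumps(
    {"process_count_per_node": None, "smp": None}
)

distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])

# None falls through to 0, so get_process_count() can pick the value
# from SM_NUM_GPUS / SM_NUM_NEURONS instead.
process_count = int(distributed_config["process_count_per_node"] or 0)
print(process_count)  # 0
```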

src/sagemaker/modules/train/container_drivers/scripts/__init__.py

+1 −1

@@ -10,5 +10,5 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-"""Sagemaker modules scripts directory."""
+"""Sagemaker modules container drivers - scripts directory."""
 from __future__ import absolute_import
