
Commit 42fdda1

Remove ParallelismConfig from PartialState (#3720)
* remove
* style
* fix
* valueerror instead
* add device_mesh
1 parent e23b004 · commit 42fdda1

3 files changed: +27 −32 lines changed

src/accelerate/accelerator.py

Lines changed: 8 additions & 28 deletions
@@ -451,23 +451,28 @@ def __init__(
             if "recipe_handler" in handler_attr and not self.has_fp8_handler:
                 self.has_fp8_handler = True
 
-        parallelism_config = self._setup_parallelism_config(parallelism_config, torch_tp_plugin)
+        if parallelism_config is None:
+            # TODO: Remove after deprecating tp_plugin
+            if torch_tp_plugin is not None:
+                parallelism_config = ParallelismConfig(tp_size=torch_tp_plugin.tp_size)
+            elif os.environ.get("ACCELERATE_USE_PARALLELISM_CONFIG", "false").lower() == "true":
+                parallelism_config = ParallelismConfig()
 
         kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
-        kwargs["parallelism_config"] = parallelism_config
         self.state = AcceleratorState(
             mixed_precision=mixed_precision,
             cpu=cpu,
             dynamo_plugin=dynamo_plugin,
             deepspeed_plugin=deepspeed_plugins,
             fsdp_plugin=fsdp_plugin,
             megatron_lm_plugin=megatron_lm_plugin,
+            parallelism_config=parallelism_config,
             _from_accelerator=True,
             **kwargs,
         )
 
         if self.parallelism_config:
-            self._build_torch_device_mesh(self.parallelism_config)
+            self.state.device_mesh = parallelism_config.get_device_mesh(self.device.type)
             self.parallelism_config._validate_accelerator(self)
 
         self.fp8_enabled = self.state.mixed_precision == "fp8" or mixed_precision == "fp8"
@@ -776,23 +781,6 @@ def should_save_model(self):
         # TODO: S1ro - this is a temporary solution until we figure out why `save_safe_file` is slow when not all processes
         return True
 
-    def _setup_parallelism_config(
-        self, parallelism_config: ParallelismConfig | None, torch_tp_plugin: TorchTensorParallelPlugin | None
-    ):
-        if parallelism_config is None:
-            if PartialState._shared_state != {} and PartialState().parallelism_config is not None:
-                if os.environ.get("ACCELERATE_USE_PARALLELISM_CONFIG", "false") == "true":
-                    raise ValueError(
-                        "Partial state contains a `parallelism_config` which is not None, but you configured `parallelism_config` from the `accelerate launch` CLI. We don't know which to use, please remove one of those configuration methods."
-                    )
-                parallelism_config = PartialState().parallelism_config
-            else:
-                # TODO: Remove after deprecating tp_plugin
-                tp_size = None if torch_tp_plugin is None else torch_tp_plugin.tp_size
-                parallelism_config = ParallelismConfig(tp_size=tp_size)
-
-        return parallelism_config
-
     @property
     def tensor_parallel_rank(self) -> int:
         """
@@ -843,14 +831,6 @@ def data_parallel_shard_rank(self) -> int:
             return 0
         raise RuntimeError("Shard-based data parallelism is not configured. Set `parallelism_config` first.")
 
-    def _build_torch_device_mesh(self, parallelism_config):
-        if PartialState._shared_state != {} and getattr(PartialState(), "device_mesh", None) is not None:
-            device_mesh = PartialState().device_mesh
-        else:
-            device_mesh = parallelism_config.build_device_mesh(self.device.type)
-        self.state.device_mesh = device_mesh
-        PartialState().device_mesh = device_mesh
-
     @contextmanager
     def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
         """

src/accelerate/parallelism_config.py

Lines changed: 18 additions & 2 deletions
@@ -15,7 +15,7 @@
 import os
 import warnings
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from torch.distributed.device_mesh import init_device_mesh
 
@@ -66,6 +66,8 @@ class ParallelismConfig:
     tp_handler: Union[None, TorchTensorParallelConfig] = None
     cp_handler: Union[None, TorchContextParallelConfig] = None
 
+    device_mesh = None
+
     def __repr__(self):
         return (
             "ParallelismConfig(\n "
@@ -178,7 +180,7 @@ def build_device_mesh(self, device_type: str):
         """
         mesh = self._get_mesh()
         if len(mesh) == 0:
-            return
+            return None
         mesh_dim_names, mesh_shape = mesh
         device_mesh = init_device_mesh(
             device_type,
@@ -194,6 +196,20 @@ def build_device_mesh(self, device_type: str):
 
         return device_mesh
 
+    def get_device_mesh(self, device_type: Optional[str] = None):
+        if self.device_mesh is None:
+            if device_type is not None:
+                self.device_mesh = self.build_device_mesh(device_type)
+            else:
+                raise ValueError("You need to pass a device_type e.g cuda to build the device mesh")
+        else:
+            if device_type is not None:
+                if self.device_mesh.device_type != device_type:
+                    raise ValueError(
+                        f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
+                    )
+        return self.device_mesh
+
     def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:
         """Generate mesh shape and dimension names for torch.distributed.init_device_mesh()."""
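
The new `ParallelismConfig.get_device_mesh` builds the mesh lazily on first call and caches it on the config; later calls reuse the cache, and asking for a different `device_type` than the cached mesh raises a ValueError. A small sketch of that behavior, assuming a distributed launch where the config spans more than one rank (the `tp_size` value is hypothetical):

    from accelerate.parallelism_config import ParallelismConfig

    pc = ParallelismConfig(tp_size=2)     # hypothetical size; any non-trivial mesh works
    mesh = pc.get_device_mesh("cuda")     # first call builds the mesh and caches it on the config
    cached = pc.get_device_mesh()         # no device_type needed once the mesh is cached
    # pc.get_device_mesh("cpu") would raise ValueError: the cached mesh was built for "cuda"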

src/accelerate/state.py

Lines changed: 1 addition & 2 deletions
@@ -180,8 +180,6 @@ def __init__(self, cpu: bool = False, **kwargs):
         if not self.initialized:
             self._cpu = cpu
             self.backend = None
-            self.parallelism_config = kwargs.pop("parallelism_config", None)
-            self.device_mesh = kwargs.pop("device_mesh", None)
             env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
             self.device = torch.device(env_device) if env_device is not None else None
             self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
@@ -919,6 +917,7 @@ def __init__(
             self.use_ipex = None
             self.torch_tp_plugin = torch_tp_plugin
             self.parallelism_config = parallelism_config
+            self.device_mesh = None
             mixed_precision = (
                 parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
                 if mixed_precision is None
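
After this commit `PartialState` no longer carries `parallelism_config` or `device_mesh`; both are tracked on `AcceleratorState`, where the mesh starts out as None and is filled in by the Accelerator. A hedged sketch of where to find them now:

    from accelerate import Accelerator

    accelerator = Accelerator()
    print(accelerator.state.parallelism_config)  # kept on AcceleratorState
    print(accelerator.state.device_mesh)         # None until a parallelism config builds a mesh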
