
Commit 8fcb3ba

[Distributed] create model on meta device (#1227)
1 parent: 77bac00

File tree

3 files changed: +27 -35 lines

dist_run.py

Lines changed: 14 additions & 12 deletions

@@ -328,15 +328,26 @@ def main(args):
     config.stage_idx = pp_rank
     config.n_stages = pp_degree

-    with device:
+    with torch.device("meta"):
         # TODO: we should create model instead of Transformer
         model = Transformer(config)

     # Distribute model on TP mesh
+    # (Surprisingly, this works even though model is on meta device and mesh is of
+    # cuda devices)
     model.distribute(tp_mesh)
     if rank == 0:
         logger.info(f"Model: {model}")

+    # Load weights
+    logger.info(f"Loading weights for {pp_rank=} on {device=}")
+    with CUDATrackTime() as timer:
+        _load_model_weights(model, distribution, device=device, model_config=config)
+
+    logger.info(
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+    )
+
     # Batch size. Since we push batches dynamically through the pipeline rather
     # than chunking them, this is effectively micro-batch size in pipeline
     # sense. Thus it is interchangeable with micro-batch size below.
@@ -352,17 +363,8 @@ def main(args):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)
-
-    # Load weights
-    logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    with CUDATrackTime() as timer:
-        _load_model_weights(model, distribution, device=device, model_config=config)
-    model.to(device)
-
-    logger.info(
-        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
-    )
+    with device:
+        model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes)

     # info on stage size and params
     stage_size = get_module_size(model)
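
For readers new to this pattern, here is a minimal, self-contained sketch of the meta-device flow the commit adopts: construct the module under torch.device("meta") so no parameter storage is allocated, then materialize storage on the real device and fill it with loaded weights. It uses plain nn.Module APIs (to_empty, load_state_dict with assign=True) and a toy TinyModel; it is not torchchat's Transformer or _load_model_weights.

    import torch
    import torch.nn as nn

    # A stand-in model; torchchat builds Transformer(config) instead.
    class TinyModel(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            return self.proj(x)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1) Construct on the meta device: shapes and dtypes exist, but no storage is allocated.
    with torch.device("meta"):
        model = TinyModel()
    assert model.proj.weight.is_meta

    # 2) Allocate uninitialized storage on the real device...
    model = model.to_empty(device=device)

    # 3) ...and fill it with real weights (random tensors stand in for a checkpoint here).
    state_dict = {k: torch.randn_like(v) for k, v in model.state_dict().items()}
    model.load_state_dict(state_dict, assign=True)

    print(next(model.parameters()).device)

Because sharding only needs shapes and the mesh rather than real storage, this presumably also explains why model.distribute(tp_mesh) works while the parameters are still on meta, as the diff comment notes with some surprise.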

torchchat/distributed/dtensor_utils.py

Lines changed: 5 additions & 12 deletions

@@ -8,24 +8,17 @@
 logger = SingletonLogger.get_logger()


-
-def is_dtensor(tensor):
-    """Check if a tensor is a DTensor by class or has a placements attribute (not sure if we want to use attr check)"""
-    return isinstance(tensor, DTensor) or hasattr(tensor, "placements")
-
-
-def load_into_dtensor(weight_tensor, model_dtensor):
+def convert_to_dtensor(weight_tensor, dtensor_template):
     """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it"""
-    weight_tensor = weight_tensor.to(model_dtensor.device)

-    if weight_tensor.shape != model_dtensor.shape:
+    if weight_tensor.shape != dtensor_template.shape:
         raise ValueError(
             f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
-            f"doesn't match DTensor shape {model_dtensor.shape}"
+            f"doesn't match DTensor shape {dtensor_template.shape}"
         )

-    placements = model_dtensor.placements
-    mesh = model_dtensor.device_mesh
+    placements = dtensor_template.placements
+    mesh = dtensor_template.device_mesh
     mesh_dims = mesh.ndim

     for placement in placements:
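
The hunk above stops at the placement loop, so the rest of convert_to_dtensor is not shown. As a hedged illustration of what such a helper can do with the template's mesh and placements, the sketch below re-shards a full checkpoint tensor with distribute_tensor from torch.distributed._tensor; this is one way to realize the idea, not the repository's implementation.

    import torch
    from torch.distributed._tensor import DTensor, distribute_tensor

    def convert_to_dtensor_sketch(weight_tensor: torch.Tensor, dtensor_template: DTensor) -> DTensor:
        """Illustrative helper: re-shard a full (unsharded) checkpoint tensor so it
        matches the mesh and placements of an existing model DTensor."""
        if weight_tensor.shape != dtensor_template.shape:
            raise ValueError(
                f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
                f"doesn't match DTensor shape {dtensor_template.shape}"
            )
        # distribute_tensor shards/replicates the full tensor across the template's
        # mesh according to the template's placements (e.g. Shard(0) or Replicate()).
        return distribute_tensor(
            weight_tensor,
            device_mesh=dtensor_template.device_mesh,
            placements=dtensor_template.placements,
        )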

torchchat/distributed/safetensor_utils.py

Lines changed: 8 additions & 11 deletions

@@ -13,8 +13,8 @@
 from torch.nn import Module
 from typing import Dict, Tuple, Set, Optional

-
-from torchchat.distributed.dtensor_utils import is_dtensor, load_into_dtensor
+from torch.distributed._tensor import DTensor
+from torchchat.distributed.dtensor_utils import convert_to_dtensor


 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -284,9 +284,7 @@ def update_state_dict(
             continue

         checkpoint_tensor = checkpoint[old_param]
-        stage_tensor = state_dict[param]
-
-        stage_is_dtensor = is_dtensor(stage_tensor)
+        model_tensor = state_dict[param]

         if "wq" in param:
             checkpoint_tensor = permute_weight_to_attn_heads(
@@ -297,17 +295,16 @@ def update_state_dict(
                 checkpoint_tensor, num_local_heads, head_dim, dim
             )

+        # Move checkpoint tensor to desired device
+        checkpoint_tensor = checkpoint_tensor.to(device)
+
         # here we need to check if the tensor is a DTensor and if so, adjust the
         # shape and placement to match the model DTensor.
-        if stage_is_dtensor:
-            model_tensor = load_into_dtensor(checkpoint_tensor, stage_tensor)
-            # logger.info(f"DTensor: Loaded {param} into {model_tensor=}")
-            state_dict[param] = model_tensor
+        if isinstance(model_tensor, DTensor):
+            state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor)
             count_dtensors_loaded += 1
-
         else:
             # regular tensor, just update directly
-            checkpoint_tensor = checkpoint_tensor.to(device)
             state_dict[param] = checkpoint_tensor

         # ensure matching dtypes
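
Putting the safetensor changes together, here is a hedged sketch of the per-parameter update path after this commit: move the checkpoint tensor to the target device, then branch on isinstance(model_tensor, DTensor) rather than the old duck-typed is_dtensor check. The function name and signature are illustrative, not torchchat's update_state_dict.

    from torch.distributed._tensor import DTensor

    def update_param_sketch(state_dict, name, checkpoint_tensor, device, convert_to_dtensor):
        """Illustrative per-parameter update: device move first, then DTensor-aware assignment."""
        model_tensor = state_dict[name]

        # The checkpoint tensor is moved to the real device in both branches,
        # since the model parameter may be a meta tensor or a sharded DTensor.
        checkpoint_tensor = checkpoint_tensor.to(device)

        if isinstance(model_tensor, DTensor):
            # Re-shard the full tensor to match the model's DTensor layout.
            state_dict[name] = convert_to_dtensor(checkpoint_tensor, model_tensor)
        else:
            # Regular tensor: assign directly.
            state_dict[name] = checkpoint_tensor
        return state_dict

Checking isinstance against the real DTensor class is stricter than the removed hasattr(tensor, "placements") test and avoids false positives from unrelated objects that happen to expose a placements attribute.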
