@@ -6,14 +6,13 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint import FileSystemReader
 from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.pipelining import PipelineStage
 from torch.distributed.pipelining.schedules import (
@@ -24,11 +23,6 @@
     ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    parallelize_module,
-    RowwiseParallel,
-)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -64,20 +58,6 @@ def forward(self, x):
         return x
 
 
-class MLPModuleEven(torch.nn.Module):
-    def __init__(self, d_hid: int):
-        super().__init__()
-        self.net1 = nn.Linear(d_hid, d_hid)
-        self.net2 = nn.Linear(d_hid, d_hid)
-        self.net3 = nn.Linear(d_hid, d_hid * 2)
-
-    def forward(self, x):
-        x = F.relu(self.net1(x))
-        x = F.relu(self.net2(x))
-        x = F.relu(self.net3(x))
-        return x
-
-
 class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
@@ -374,179 +354,6 @@ def _dcp_test(self):
 
         _dcp_test(self)
 
-    @requires_nccl()
-    @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
-    @parametrize(
-        "ScheduleClass",
-        [
-            ScheduleGPipe,
-            Schedule1F1B,
-            ScheduleInterleaved1F1B,
-            ScheduleLoopedBFS,
-            ScheduleInterleavedZeroBubble,
-        ],
-    )
-    @parametrize(
-        "MixedPrecisionParam",
-        [
-            torch.bfloat16,
-            torch.float32,
-        ],
-    )
-    def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
-        device = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
-        store = torch.distributed.FileStore(self.file_name, self.world_size)
-        torch.distributed.init_process_group(
-            backend="nccl",
-            store=store,
-            rank=self.rank,
-            world_size=self.world_size,
-        )
-        dim = 8
-        tp_size = 2
-        pp_size = 2
-        num_microbatches = 8
-        dp_size = self.world_size // (tp_size * pp_size)
-        device_mesh = init_device_mesh(
-            "cuda",
-            mesh_shape=(dp_size, pp_size, tp_size),
-            mesh_dim_names=("dp", "pp", "tp"),
-        )
-        dp_mesh = device_mesh["dp"]
-        tp_mesh = device_mesh["tp"]
-        pp_mesh = device_mesh["pp"]
-        pp_group = device_mesh["pp"].get_group()
-
-        # create "entire model"
-        total_layers = 8
-        full_model = nn.ModuleList([MLPModuleEven(dim) for _ in range(total_layers)])
-        ref_model = nn.Sequential(*copy.deepcopy(full_model))
-        ref_model.to(self.device)
-
-        # dummy loss needed just to force backwards to run in schedule step
-        def loss_fn(y, target):
-            return y.sum()
-
-        # Apply DP to stage module
-        def apply_fsdp(partial_model):
-            # apply FSDP
-            mp_policy = MixedPrecisionPolicy(
-                param_dtype=MixedPrecisionParam,
-                reduce_dtype=torch.float32,
-            )
-            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-            for layer_id in range(len(partial_model)):
-                fully_shard(
-                    partial_model[layer_id],
-                    **fsdp_config,
-                    reshard_after_forward=False,
-                )
-            dp_model = fully_shard(partial_model, **fsdp_config)
-            return dp_model
-
-        def apply_tp(
-            model: nn.Module,
-            tp_mesh: DeviceMesh,
-        ):
-            parallelize_plan = {
-                "net1": ColwiseParallel(),
-                "net2": RowwiseParallel(),
-                "net3": ColwiseParallel(),
-            }
-            for layer in model:
-                parallelize_module(layer, tp_mesh, parallelize_plan)
-            return model
-
-        # Attach to a schedule
-        if issubclass(ScheduleClass, PipelineScheduleSingle):
-            stage_idx = pp_group.rank()
-            partial_model = nn.Sequential(
-                *full_model[stage_idx * 2 : stage_idx * 2 + 2]
-            )
-            partial_model.to(self.device)
-
-            tp_model = apply_tp(partial_model, tp_mesh)
-            dp_model = apply_fsdp(tp_model)
-            pipeline_stage = PipelineStage(
-                dp_model,
-                stage_idx,
-                pp_group.size(),
-                self.device,
-                group=pp_group,
-            )
-            partial_models = [pipeline_stage.submod]
-            pipeline_schedule = ScheduleClass(
-                pipeline_stage,
-                n_microbatches=num_microbatches,
-                loss_fn=loss_fn,
-            )
-        else:
-            n_virtual = 2
-            num_stages = pp_group.size() * n_virtual
-            stages = []
-            for i in range(n_virtual):
-                stage_idx = pp_group.rank() + n_virtual * i
-                # divide the model layers by the number of stages
-                partial_model = nn.Sequential(*full_model[stage_idx : stage_idx + 1])
-                partial_model.to(self.device)
-
-                tp_model = apply_tp(partial_model, tp_mesh)
-                dp_model = apply_fsdp(tp_model)
-                stage = PipelineStage(
-                    dp_model,
-                    stage_idx,
-                    num_stages,
-                    self.device,
-                    group=pp_group,
-                )
-
-                stages.append(stage)
-            partial_models = [pipeline_stage.submod for pipeline_stage in stages]
-            pipeline_schedule = ScheduleClass(
-                stages,
-                n_microbatches=num_microbatches,
-                loss_fn=loss_fn,
-            )
-
-        optimizer_kwargs = {
-            "lr": 0.01,
-            "betas": (0.9, 0.95),
-            "weight_decay": 0.1,
-            "fused": False,
-            "foreach": True,
-        }
-        optimizers = [
-            torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
-            for model in partial_models
-        ]
-
-        for train_step in range(5):
-            for optimizer in optimizers:
-                optimizer.zero_grad()
-            inputs = torch.rand((num_microbatches, dim), device=self.device)
-            labels = torch.rand((num_microbatches, dim), device=self.device)
-            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-            if pp_mesh.get_local_rank() == 0:
-                pipeline_schedule.step(inputs)
-            elif is_last_stage:
-                losses = []
-                pipeline_schedule.step(target=labels, losses=losses)
-            else:
-                pipeline_schedule.step()
-
-            # accumulate losses across pipeline microbatches
-            loss = (
-                torch.mean(torch.stack(losses))
-                if is_last_stage
-                else torch.Tensor([-1.0])
-            )
-            for optimizer in optimizers:
-                optimizer.step()
-
-        torch.distributed.destroy_process_group()
-
 
 instantiate_parametrized_tests(ComposabilityTest)
 