
Commit 4af7aa5

Revert "E2E composability testing (pytorch#141398)"
This reverts commit ad93aa8. Reverted pytorch#141398 on behalf of https://github.com/atalman because pytorch#141868 needs to be reverted first; this change can be rebased and relanded afterwards ([comment](pytorch#141398 (comment))).
1 parent 683ec42 commit 4af7aa5

File tree: 1 file changed (+1, −194 lines)


test/distributed/_composable/test_composability/test_pp_composability.py

Lines changed: 1 addition & 194 deletions
@@ -6,14 +6,13 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint import FileSystemReader
 from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.pipelining import PipelineStage
 from torch.distributed.pipelining.schedules import (
@@ -24,11 +23,6 @@
     ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    parallelize_module,
-    RowwiseParallel,
-)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -64,20 +58,6 @@ def forward(self, x):
         return x


-class MLPModuleEven(torch.nn.Module):
-    def __init__(self, d_hid: int):
-        super().__init__()
-        self.net1 = nn.Linear(d_hid, d_hid)
-        self.net2 = nn.Linear(d_hid, d_hid)
-        self.net3 = nn.Linear(d_hid, d_hid * 2)
-
-    def forward(self, x):
-        x = F.relu(self.net1(x))
-        x = F.relu(self.net2(x))
-        x = F.relu(self.net3(x))
-        return x
-
-
 class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
@@ -374,179 +354,6 @@ def _dcp_test(self):

         _dcp_test(self)

-    @requires_nccl()
-    @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
-    @parametrize(
-        "ScheduleClass",
-        [
-            ScheduleGPipe,
-            Schedule1F1B,
-            ScheduleInterleaved1F1B,
-            ScheduleLoopedBFS,
-            ScheduleInterleavedZeroBubble,
-        ],
-    )
-    @parametrize(
-        "MixedPrecisionParam",
-        [
-            torch.bfloat16,
-            torch.float32,
-        ],
-    )
-    def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
-        device = torch.device("cuda", self.device)
-        torch.cuda.set_device(self.device)
-        store = torch.distributed.FileStore(self.file_name, self.world_size)
-        torch.distributed.init_process_group(
-            backend="nccl",
-            store=store,
-            rank=self.rank,
-            world_size=self.world_size,
-        )
-        dim = 8
-        tp_size = 2
-        pp_size = 2
-        num_microbatches = 8
-        dp_size = self.world_size // (tp_size * pp_size)
-        device_mesh = init_device_mesh(
-            "cuda",
-            mesh_shape=(dp_size, pp_size, tp_size),
-            mesh_dim_names=("dp", "pp", "tp"),
-        )
-        dp_mesh = device_mesh["dp"]
-        tp_mesh = device_mesh["tp"]
-        pp_mesh = device_mesh["pp"]
-        pp_group = device_mesh["pp"].get_group()
-
-        # create "entire model"
-        total_layers = 8
-        full_model = nn.ModuleList([MLPModuleEven(dim) for _ in range(total_layers)])
-        ref_model = nn.Sequential(*copy.deepcopy(full_model))
-        ref_model.to(self.device)
-
-        # dummy loss needed just to force backwards to run in schedule step
-        def loss_fn(y, target):
-            return y.sum()
-
-        # Apply DP to stage module
-        def apply_fsdp(partial_model):
-            # apply FSDP
-            mp_policy = MixedPrecisionPolicy(
-                param_dtype=MixedPrecisionParam,
-                reduce_dtype=torch.float32,
-            )
-            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-            for layer_id in range(len(partial_model)):
-                fully_shard(
-                    partial_model[layer_id],
-                    **fsdp_config,
-                    reshard_after_forward=False,
-                )
-            dp_model = fully_shard(partial_model, **fsdp_config)
-            return dp_model
-
-        def apply_tp(
-            model: nn.Module,
-            tp_mesh: DeviceMesh,
-        ):
-            parallelize_plan = {
-                "net1": ColwiseParallel(),
-                "net2": RowwiseParallel(),
-                "net3": ColwiseParallel(),
-            }
-            for layer in model:
-                parallelize_module(layer, tp_mesh, parallelize_plan)
-            return model
-
-        # Attach to a schedule
-        if issubclass(ScheduleClass, PipelineScheduleSingle):
-            stage_idx = pp_group.rank()
-            partial_model = nn.Sequential(
-                *full_model[stage_idx * 2 : stage_idx * 2 + 2]
-            )
-            partial_model.to(self.device)
-
-            tp_model = apply_tp(partial_model, tp_mesh)
-            dp_model = apply_fsdp(tp_model)
-            pipeline_stage = PipelineStage(
-                dp_model,
-                stage_idx,
-                pp_group.size(),
-                self.device,
-                group=pp_group,
-            )
-            partial_models = [pipeline_stage.submod]
-            pipeline_schedule = ScheduleClass(
-                pipeline_stage,
-                n_microbatches=num_microbatches,
-                loss_fn=loss_fn,
-            )
-        else:
-            n_virtual = 2
-            num_stages = pp_group.size() * n_virtual
-            stages = []
-            for i in range(n_virtual):
-                stage_idx = pp_group.rank() + n_virtual * i
-                # divide the model layers by the number of stages
-                partial_model = nn.Sequential(*full_model[stage_idx : stage_idx + 1])
-                partial_model.to(self.device)
-
-                tp_model = apply_tp(partial_model, tp_mesh)
-                dp_model = apply_fsdp(tp_model)
-                stage = PipelineStage(
-                    dp_model,
-                    stage_idx,
-                    num_stages,
-                    self.device,
-                    group=pp_group,
-                )
-
-                stages.append(stage)
-            partial_models = [pipeline_stage.submod for pipeline_stage in stages]
-            pipeline_schedule = ScheduleClass(
-                stages,
-                n_microbatches=num_microbatches,
-                loss_fn=loss_fn,
-            )
-
-        optimizer_kwargs = {
-            "lr": 0.01,
-            "betas": (0.9, 0.95),
-            "weight_decay": 0.1,
-            "fused": False,
-            "foreach": True,
-        }
-        optimizers = [
-            torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
-            for model in partial_models
-        ]
-
-        for train_step in range(5):
-            for optimizer in optimizers:
-                optimizer.zero_grad()
-            inputs = torch.rand((num_microbatches, dim), device=self.device)
-            labels = torch.rand((num_microbatches, dim), device=self.device)
-            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-            if pp_mesh.get_local_rank() == 0:
-                pipeline_schedule.step(inputs)
-            elif is_last_stage:
-                losses = []
-                pipeline_schedule.step(target=labels, losses=losses)
-            else:
-                pipeline_schedule.step()
-
-            # accumulate losses across pipeline microbatches
-            loss = (
-                torch.mean(torch.stack(losses))
-                if is_last_stage
-                else torch.Tensor([-1.0])
-            )
-            for optimizer in optimizers:
-                optimizer.step()
-
-        torch.distributed.destroy_process_group()
-

 instantiate_parametrized_tests(ComposabilityTest)

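For context, the snippet below distills the per-stage TP + FSDP composition pattern that the reverted test exercised (init_device_mesh with dp/pp/tp dims, parallelize_module with ColwiseParallel/RowwiseParallel, then fully_shard over the "dp" sub-mesh). It is a minimal illustrative sketch rather than code from the commit: the toy nn.Sequential block, the hard-coded (2, 2, 2) mesh shape, and the main() wrapper are assumptions, and it presumes one process per GPU on 8 GPUs with the NCCL backend.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


def main() -> None:
    # One process per GPU; torchrun provides LOCAL_RANK and the env:// rendezvous.
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # 3-D mesh: 2 data-parallel x 2 pipeline x 2 tensor-parallel ranks (8 GPUs total).
    mesh = init_device_mesh(
        "cuda", mesh_shape=(2, 2, 2), mesh_dim_names=("dp", "pp", "tp")
    )

    # Toy stage module standing in for the test's MLPModuleEven layers.
    d_hid = 8
    block = nn.Sequential(
        nn.Linear(d_hid, d_hid), nn.ReLU(), nn.Linear(d_hid, d_hid)
    ).cuda()

    # Tensor parallelism first: shard the first Linear column-wise and the
    # second Linear row-wise over the "tp" sub-mesh.
    parallelize_module(
        block, mesh["tp"], {"0": ColwiseParallel(), "2": RowwiseParallel()}
    )

    # Then shard the TP-parallelized block with FSDP over the "dp" sub-mesh;
    # in the full test the result is additionally wrapped in a PipelineStage.
    fully_shard(block, mesh=mesh["dp"])

    # One forward/backward to show the composed module trains end to end.
    out = block(torch.rand(4, d_hid, device="cuda"))
    out.sum().backward()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

Launched as, e.g., torchrun --nproc-per-node 8 sketch.py; the reverted test goes further by driving such stage modules through the parametrized pipeline schedules listed in the diff above.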