@@ -78,32 +78,19 @@ def __init__(self, job_config: JobConfig):
         self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
-        ft_manager = ft.init_ft_manager(job_config)
 
         # init distributed
         world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
-        if not ft_manager.enabled:
-            self.parallel_dims = parallel_dims = ParallelDims(
-                dp_shard=parallelism_config.data_parallel_shard_degree,
-                dp_replicate=parallelism_config.data_parallel_replicate_degree,
-                cp=parallelism_config.context_parallel_degree,
-                tp=parallelism_config.tensor_parallel_degree,
-                pp=parallelism_config.pipeline_parallel_degree,
-                world_size=world_size,
-                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
-            )
-        else:
-            self.parallel_dims = parallel_dims = ft.FTParallelDims(
-                dp_shard=parallelism_config.data_parallel_shard_degree,
-                dp_replicate=parallelism_config.data_parallel_replicate_degree,
-                cp=parallelism_config.context_parallel_degree,
-                tp=parallelism_config.tensor_parallel_degree,
-                pp=parallelism_config.pipeline_parallel_degree,
-                world_size=world_size,
-                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
-                ft_manager=ft_manager,
-            )
+        self.parallel_dims = parallel_dims = ParallelDims(
+            dp_shard=parallelism_config.data_parallel_shard_degree,
+            dp_replicate=parallelism_config.data_parallel_replicate_degree,
+            cp=parallelism_config.context_parallel_degree,
+            tp=parallelism_config.tensor_parallel_degree,
+            pp=parallelism_config.pipeline_parallel_degree,
+            world_size=world_size,
+            enable_loss_parallel=not parallelism_config.disable_loss_parallel,
+        )
 
         dist_utils.init_distributed(job_config)
 
         # build meshes
@@ -114,6 +101,12 @@ def __init__(self, job_config: JobConfig):
         else:
             dp_degree, dp_rank = 1, 0
 
+        self.ft_manager = ft.init_ft_manager(job_config)
+        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
+        # dataloader must be changed.
+        if self.ft_manager.enabled:
+            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
+
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
         dist_utils.set_determinism(
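The get_dp_info call added above rescales the data-parallel coordinates so the dataloader shards data across every TorchFT replica group, not just within the local one. A minimal sketch of the idea, using hypothetical num_groups and replica_id parameters for illustration (not necessarily the real FTManager attributes):

def get_dp_info(
    dp_degree: int, dp_rank: int, num_groups: int, replica_id: int
) -> tuple[int, int]:
    # Each replica group runs its own sharded DP world of size dp_degree;
    # globally there are num_groups such worlds side by side.
    global_dp_degree = dp_degree * num_groups
    # Offset the local rank by this replica group's position.
    global_dp_rank = dp_degree * replica_id + dp_rank
    return global_dp_degree, global_dp_rank

For example, with dp_degree=4, two replica groups, and replica_id=1, a local dp_rank of 2 maps to global rank 6 out of a global degree of 8.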
@@ -131,11 +124,6 @@ def __init__(self, job_config: JobConfig):
             else None
         )
 
-        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
-        # dataloader must be changed.
-        if ft_manager.enabled:
-            dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)
-
         self.dataloader = self.train_spec.build_dataloader_fn(
             dp_world_size=dp_degree,
             dp_rank=dp_rank,
@@ -241,6 +229,9 @@ def __init__(self, job_config: JobConfig):
 
         self.model_parts = [model]
 
+        if self.ft_manager.enabled:
+            self.ft_manager.set_all_reduce_hook(self.model_parts)
+
         # initialize device memory monitor and get peak flops for MFU calculation
         device_memory_monitor = self.metrics_processor.device_memory_monitor
         gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
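The set_all_reduce_hook call is what keeps the replica groups in sync: once TorchFT owns replication, gradients still have to be averaged across its replicate process group after backward. A rough sketch of that idea using plain PyTorch parameter hooks, assuming a hypothetical replicate_pg process group; the actual torchft mechanism may differ:

import torch
import torch.distributed as dist

def set_all_reduce_hook(
    model_parts: list[torch.nn.Module], replicate_pg: dist.ProcessGroup
) -> None:
    def _average_grad(param: torch.Tensor) -> None:
        if param.grad is not None:
            # Average this parameter's gradient across the fault-tolerant
            # replicate group so every replica steps with the same update.
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=replicate_pg)

    for model in model_parts:
        for param in model.parameters():
            if param.requires_grad:
                # Fires once the gradient for this parameter is fully accumulated.
                param.register_post_accumulate_grad_hook(_average_grad)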
@@ -254,7 +245,7 @@ def __init__(self, job_config: JobConfig):
 
         # build optimizer after applying parallelisms to the model
         self.optimizers = self.train_spec.build_optimizers_fn(
-            self.model_parts, job_config, ft_manager
+            self.model_parts, job_config, self.ft_manager
         )
         self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
             self.optimizers, job_config
@@ -280,7 +271,7 @@ def __init__(self, job_config: JobConfig):
             lr_schedulers=self.lr_schedulers,
             states={"train_state": self},
             job_config=job_config,
-            ft_manager=ft_manager,
+            ft_manager=self.ft_manager,
         )
 
         self.train_context = dist_utils.get_train_context(
@@ -384,11 +375,13 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
             parallel_dims.dp_replicate_enabled
             or parallel_dims.dp_shard_enabled
             or parallel_dims.cp_enabled
+            or self.ft_manager.enabled
         ):
             loss = loss.detach()
+            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
             global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
+                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
             )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
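The extra ft_pg argument lets the loss reduction span the TorchFT replicate group on top of the dp_cp mesh. A simplified sketch of what dist_mean/dist_max could look like with that optional group (torchtitan's real helpers are built differently; this is only illustrative):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

def dist_mean(
    x: torch.Tensor, mesh: DeviceMesh, ft_pg: dist.ProcessGroup | None = None
) -> float:
    x = x.clone()
    # Average within the dp_cp mesh first.
    dist.all_reduce(x, op=dist.ReduceOp.AVG, group=mesh.get_group())
    if ft_pg is not None:
        # Then average across TorchFT replica groups (all groups have equal
        # size, so an average of averages is still the global mean).
        dist.all_reduce(x, op=dist.ReduceOp.AVG, group=ft_pg)
    return x.item()

def dist_max(
    x: torch.Tensor, mesh: DeviceMesh, ft_pg: dist.ProcessGroup | None = None
) -> float:
    x = x.clone()
    dist.all_reduce(x, op=dist.ReduceOp.MAX, group=mesh.get_group())
    if ft_pg is not None:
        dist.all_reduce(x, op=dist.ReduceOp.MAX, group=ft_pg)
    return x.item()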