@@ -231,7 +231,6 @@ def loss_fn(pred, labels):
     )

     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()
@@ -295,10 +294,11 @@ def loss_fn(pred, labels):
                 pp_schedule.step()

             # accumulate losses across pipeline microbatches
+            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses))
+                torch.mean(torch.stack(losses)).to(device)
                 if is_last_stage
-                else torch.Tensor([-1.0])
+                else torch.tensor([-1.0], device=device)
             )
         else:
             # Non-PP forward / backward
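A minimal sketch of the pattern in this hunk, assuming `device` is the local compute device defined earlier in train.py (`losses` and `is_last_stage` below are stand-ins for the variables the pipeline code provides). `torch.Tensor([-1.0])` always allocates on the CPU, whereas `torch.tensor([-1.0], device=device)` keeps the placeholder on the same device as the real loss, so both branches feed the later collectives without a host-device copy:

```python
import torch

# Stand-ins for names defined elsewhere in train.py.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
losses = [torch.rand(1, device=device) for _ in range(4)]  # per-microbatch losses
is_last_stage = True

# Non-last stages report a placeholder loss; keeping it on `device` means the
# subsequent dist_mean/dist_max reductions never receive a CPU tensor.
loss = (
    torch.mean(torch.stack(losses)).to(device)
    if is_last_stage
    else torch.tensor([-1.0], device=device)
)
print(loss.device)  # matches `device` on every stage
```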
@@ -330,26 +330,23 @@ def loss_fn(pred, labels):
             # it issues a single all-reduce for all parameters at once for better performance
             float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

-        losses_since_last_log.append(loss)
-
         # log metrics
         if (
             train_state.step == 1
             or train_state.step % job_config.metrics.log_freq == 0
         ):
-            losses = [loss.item() for loss in losses_since_last_log]
-            avg_loss, max_loss = sum(losses) / len(losses), max(losses)
             if (
                 parallel_dims.dp_replicate_enabled
                 or parallel_dims.dp_shard_enabled
                 or parallel_dims.cp_enabled
             ):
+                loss = loss.detach()
                 global_avg_loss, global_max_loss = (
-                    utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                    utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                    utils.dist_mean(loss, world_mesh["dp_cp"]),
+                    utils.dist_max(loss, world_mesh["dp_cp"]),
                 )
             else:
-                global_avg_loss, global_max_loss = avg_loss, max_loss
+                global_avg_loss = global_max_loss = loss.item()

         # update train state
         train_state.log_steps.append(train_state.step)
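With the per-step Python list gone, the logging path reduces the current on-device scalar directly. The bodies of `utils.dist_mean` and `utils.dist_max` are not shown in this diff; a plausible sketch of such helpers, assuming a scalar loss tensor, a 1-D `DeviceMesh` (as `world_mesh["dp_cp"]` is used here), and an NCCL backend (`ReduceOp.AVG` is not supported by Gloo), could look like:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # Average the scalar loss across every rank in the mesh's process group,
    # then pull the single reduced value to the host only for logging.
    x = x.clone()
    dist.all_reduce(x, op=dist.ReduceOp.AVG, group=mesh.get_group())
    return x.item()


def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # Elementwise max across ranks; for a one-element loss this is the global max.
    x = x.clone()
    dist.all_reduce(x, op=dist.ReduceOp.MAX, group=mesh.get_group())
    return x.item()
```

Detaching the loss first (the added `loss = loss.detach()`) keeps the autograd graph out of the collective, since the reduction only needs the value.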
@@ -399,7 +396,6 @@ def loss_fn(pred, labels):
                 f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
             )

-            losses_since_last_log.clear()
             ntokens_since_last_log = 0
             data_loading_times.clear()
             time_last_log = time.perf_counter()