Commit 2fa6d83

Do not aggregate the losses since last log step (#779)
Fixes #763
1 parent 82f7387 commit 2fa6d83
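
In short: instead of appending every step's loss to a Python list and calling .item() on each element at log time, train.py now reduces only the current step's detached loss tensor when metrics are logged. A rough before/after sketch, paraphrasing the diff below (should_log is a hypothetical stand-in for the log-frequency check in train.py, not a real variable there):

# Before (removed): one device-to-host sync per accumulated loss at every log step.
losses_since_last_log.append(loss)
if should_log:
    losses = [l.item() for l in losses_since_last_log]
    avg_loss, max_loss = sum(losses) / len(losses), max(losses)
    global_avg_loss = utils.dist_mean(avg_loss, world_mesh["dp_cp"])
    global_max_loss = utils.dist_max(max_loss, world_mesh["dp_cp"])
    losses_since_last_log.clear()

# After (added): reduce only the current step's detached loss tensor at log time.
if should_log:
    loss = loss.detach()
    global_avg_loss = utils.dist_mean(loss, world_mesh["dp_cp"])
    global_max_loss = utils.dist_max(loss, world_mesh["dp_cp"])

With the refactored helpers in torchtitan/utils.py accepting tensors directly, the only host sync on the distributed path is the single .item() inside dist_reduce per logged metric.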

2 files changed: +19 -17 lines changed

torchtitan/utils.py

Lines changed: 12 additions & 6 deletions
@@ -34,14 +34,20 @@ def get_device_info():
 device_type, device_module = get_device_info()
 
 
-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
+def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
+    if isinstance(x, DTensor):
+        # functional collectives do not support DTensor inputs
+        x = x.full_tensor()
+    assert x.numel() == 1  # required by `.item()`
+    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
 
 
-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()
+def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh)
+
+
+def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh)
 
 
 def _warn_overwrite_env(env, val):
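
A minimal usage sketch of the refactored helpers (not part of the commit; assumes a process launched with torchrun, an NCCL process group, a GPU, and a 1-D data-parallel DeviceMesh named dp_mesh):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from torchtitan import utils

dist.init_process_group("nccl")
dp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))

# dist_reduce asserts numel() == 1, so pass a scalar tensor on the local device.
loss = torch.tensor(1.234, device="cuda")
global_avg = utils.dist_mean(loss, dp_mesh)  # all-reduce with ReduceOp.AVG, returns a float
global_max = utils.dist_max(loss, dp_mesh)   # all-reduce with ReduceOp.MAX, returns a float

# DTensor inputs (e.g., a loss produced under tensor parallel) are converted via
# full_tensor() before the functional collective, per dist_reduce above.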

train.py

Lines changed: 7 additions & 11 deletions
@@ -231,7 +231,6 @@ def loss_fn(pred, labels):
     )
 
     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()
@@ -295,10 +294,11 @@ def loss_fn(pred, labels):
                     pp_schedule.step()
 
             # accumulate losses across pipeline microbatches
+            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses))
+                torch.mean(torch.stack(losses)).to(device)
                 if is_last_stage
-                else torch.Tensor([-1.0])
+                else torch.tensor([-1.0], device=device)
             )
         else:
             # Non-PP forward / backward
@@ -330,26 +330,23 @@ def loss_fn(pred, labels):
        # it issues a single all-reduce for all parameters at once for better performance
        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
 
-        losses_since_last_log.append(loss)
-
        # log metrics
        if (
            train_state.step == 1
            or train_state.step % job_config.metrics.log_freq == 0
        ):
-            losses = [loss.item() for loss in losses_since_last_log]
-            avg_loss, max_loss = sum(losses) / len(losses), max(losses)
            if (
                parallel_dims.dp_replicate_enabled
                or parallel_dims.dp_shard_enabled
                or parallel_dims.cp_enabled
            ):
+                loss = loss.detach()
                global_avg_loss, global_max_loss = (
-                    utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                    utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                    utils.dist_mean(loss, world_mesh["dp_cp"]),
+                    utils.dist_max(loss, world_mesh["dp_cp"]),
                )
            else:
-                global_avg_loss, global_max_loss = avg_loss, max_loss
+                global_avg_loss = global_max_loss = loss.item()
 
            # update train state
            train_state.log_steps.append(train_state.step)
@@ -399,7 +396,6 @@ def loss_fn(pred, labels):
                f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
            )
 
-            losses_since_last_log.clear()
            ntokens_since_last_log = 0
            data_loading_times.clear()
            time_last_log = time.perf_counter()
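
One small detail from the PP branch above: the loss tensor is now fed directly to the collectives instead of being turned into a Python float first, so it has to live on the training device on every rank. torch.Tensor([-1.0]) would by default have left the non-last-stage placeholder on the CPU, hence the switch to torch.tensor([-1.0], device=device). A tiny illustration (the device variable here is a local stand-in for the one in train.py):

import torch

# Pick the local accelerator if available; train.py derives this from its own device setup.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

old_placeholder = torch.Tensor([-1.0])                  # legacy constructor: CPU float32 tensor by default
new_placeholder = torch.tensor([-1.0], device=device)   # placed on the training device

print(old_placeholder.device, new_placeholder.device)   # e.g. "cpu cuda:0" on a GPU machine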
