Skip to content

Commit fdfd4c0

Browse files
fix: training speed might be incorrect (#4806)
fix #4212 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Improved the logic for accumulating and reporting training time during model training, resulting in clearer and more consistent step counting and average time logging. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ba09114 commit fdfd4c0

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

deepmd/pt/train/training.py

Lines changed: 23 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -935,12 +935,19 @@ def log_loss_valid(_task_key="Default"):
935935
eta=eta,
936936
)
937937
)
938-
# the first training time is not accurate
939938
if (
940-
(_step_id + 1 - self.start_step) > self.disp_freq
941-
or self.num_steps - self.start_step < 2 * self.disp_freq
939+
(self.num_steps - self.start_step)
940+
<= 2 * self.disp_freq # not enough steps
941+
or (_step_id - self.start_step)
942+
>= self.disp_freq # skip first disp_freq steps
942943
):
943944
self.total_train_time += train_time
945+
if display_step_id == 1:
946+
self.timed_steps += 1
947+
else:
948+
self.timed_steps += min(
949+
self.disp_freq, _step_id - self.start_step
950+
)
944951

945952
if fout:
946953
if self.lcurve_should_print_header:
@@ -951,11 +958,14 @@ def log_loss_valid(_task_key="Default"):
951958
)
952959

953960
if (
954-
((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
955-
or (_step_id + 1) == self.num_steps
961+
(
962+
(display_step_id) % self.save_freq == 0
963+
and _step_id != self.start_step
964+
)
965+
or (display_step_id) == self.num_steps
956966
) and (self.rank == 0 or dist.get_rank() == 0):
957967
# Handle the case if rank 0 aborted and re-assigned
958-
self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt")
968+
self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")
959969

960970
module = (
961971
self.wrapper.module
@@ -982,6 +992,7 @@ def log_loss_valid(_task_key="Default"):
982992
self.wrapper.train()
983993
self.t0 = time.time()
984994
self.total_train_time = 0.0
995+
self.timed_steps = 0
985996
for step_id in range(self.start_step, self.num_steps):
986997
step(step_id)
987998
if JIT:
@@ -1021,24 +1032,12 @@ def log_loss_valid(_task_key="Default"):
10211032
with open("checkpoint", "w") as f:
10221033
f.write(str(self.latest_model))
10231034

1024-
elapsed_batch = self.num_steps - self.start_step
1025-
if self.timing_in_training and elapsed_batch // self.disp_freq > 0:
1026-
if self.start_step >= 2 * self.disp_freq:
1027-
log.info(
1028-
"average training time: %.4f s/batch (exclude first %d batches)",
1029-
self.total_train_time
1030-
/ (
1031-
elapsed_batch // self.disp_freq * self.disp_freq
1032-
- self.disp_freq
1033-
),
1034-
self.disp_freq,
1035-
)
1036-
else:
1037-
log.info(
1038-
"average training time: %.4f s/batch",
1039-
self.total_train_time
1040-
/ (elapsed_batch // self.disp_freq * self.disp_freq),
1041-
)
1035+
if self.timing_in_training and self.timed_steps:
1036+
msg = f"average training time: {self.total_train_time / self.timed_steps:.4f} s/batch"
1037+
excluded_steps = self.num_steps - self.start_step - self.timed_steps
1038+
if excluded_steps > 0:
1039+
msg += f" ({excluded_steps} batches excluded)"
1040+
log.info(msg)
10421041

10431042
if JIT:
10441043
pth_model_path = (

0 commit comments

Comments (0)