Skip to content

Commit 753f82a

Browse files
xingyousong and copybara-github
authored and committed
Small rearrangement fix
PiperOrigin-RevId: 690797914
1 parent de5453c commit 753f82a

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

optformer/embed_then_regress/train.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -207,11 +207,12 @@ def train(
207207
eff_step = int(unreplicate(train_state.step)) // grad_accum_steps
208208

209209
while eff_step < train_config.max_steps:
210-
all_train_metrics = []
211-
for _ in range(grad_accum_steps):
212-
train_state, train_metrics = p_train_step(train_state, next(train_it))
213-
all_train_metrics.append(train_metrics)
214-
writer.write_scalars(eff_step, aggregate_metrics(all_train_metrics))
210+
if eff_step % train_config.checkpoint_interval == 0:
211+
ckpt_train_state = unreplicate(train_state)
212+
checkpoint_manager.save(
213+
eff_step,
214+
items=dict(train_state=jax.tree.map(np.array, ckpt_train_state)),
215+
)
215216

216217
if eff_step % train_config.validation_interval == 0:
217218
all_valid_metrics = [
@@ -220,10 +221,10 @@ def train(
220221
]
221222
writer.write_scalars(eff_step, aggregate_metrics(all_valid_metrics))
222223

223-
if eff_step % train_config.checkpoint_interval == 0:
224-
ckpt_train_state = unreplicate(train_state)
225-
checkpoint_manager.save(
226-
eff_step,
227-
items=dict(train_state=jax.tree.map(np.array, ckpt_train_state)),
228-
)
224+
all_train_metrics = []
225+
for _ in range(grad_accum_steps):
226+
train_state, train_metrics = p_train_step(train_state, next(train_it))
227+
all_train_metrics.append(train_metrics)
228+
writer.write_scalars(eff_step, aggregate_metrics(all_train_metrics))
229+
229230
eff_step = int(unreplicate(train_state.step)) // grad_accum_steps

0 commit comments

Comments (0)