@@ -207,11 +207,12 @@ def train(
207207 eff_step = int (unreplicate (train_state .step )) // grad_accum_steps
208208
209209 while eff_step < train_config .max_steps :
210- all_train_metrics = []
211- for _ in range (grad_accum_steps ):
212- train_state , train_metrics = p_train_step (train_state , next (train_it ))
213- all_train_metrics .append (train_metrics )
214- writer .write_scalars (eff_step , aggregate_metrics (all_train_metrics ))
210+ if eff_step % train_config .checkpoint_interval == 0 :
211+ ckpt_train_state = unreplicate (train_state )
212+ checkpoint_manager .save (
213+ eff_step ,
214+ items = dict (train_state = jax .tree .map (np .array , ckpt_train_state )),
215+ )
215216
216217 if eff_step % train_config .validation_interval == 0 :
217218 all_valid_metrics = [
@@ -220,10 +221,10 @@ def train(
220221 ]
221222 writer .write_scalars (eff_step , aggregate_metrics (all_valid_metrics ))
222223
223- if eff_step % train_config . checkpoint_interval == 0 :
224- ckpt_train_state = unreplicate ( train_state )
225- checkpoint_manager . save (
226- eff_step ,
227- items = dict ( train_state = jax . tree . map ( np . array , ckpt_train_state )),
228- )
224+ all_train_metrics = []
225+ for _ in range ( grad_accum_steps ):
226+ train_state , train_metrics = p_train_step ( train_state , next ( train_it ))
227+ all_train_metrics . append ( train_metrics )
228+ writer . write_scalars ( eff_step , aggregate_metrics ( all_train_metrics ))
229+
229230 eff_step = int (unreplicate (train_state .step )) // grad_accum_steps
0 commit comments