Commit 52fae09

Add end_training/destroy_pg to everything and unpin numpy (#3030)
* Add end_training/destroy_pg to everything
* Carry over to AcceleratorState
* If forked, ignore
* More numpy fun
* Skip only init
1 parent 7ffe766
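For context, here is a minimal sketch of what end_training is responsible for after this change. This is not the Accelerate source: the trackers attribute, the tracker finish() call, and the teardown guard below are assumptions reconstructed from the commit message.

    import torch.distributed as dist

    class Accelerator:
        def end_training(self):
            # Close out any experiment trackers (the method's previous job).
            for tracker in self.trackers:  # hypothetical attribute
                tracker.finish()
            # New in this commit: also destroy the distributed process group.
            # Per the commit message, a forked subprocess skips this teardown
            # ("If forked, ignore"); that guard is omitted from this sketch.
            if dist.is_available() and dist.is_initialized():
                dist.destroy_process_group()

Because process-group teardown now lives inside end_training, every example script can call accelerator.end_training() unconditionally at the end of training; that is exactly what the diffs below add.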

40 files changed (+72 −19 lines)

examples/by_feature/automatic_gradient_accumulation.py

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ def inner_training_loop(batch_size):
     # And call it at the end with no arguments
     # Note: You could also refactor this outside of your training loop function
     inner_training_loop()
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/checkpointing.py

Lines changed: 1 addition & 0 deletions
@@ -276,6 +276,7 @@ def training_function(config, args):
             if args.output_dir is not None:
                 output_dir = os.path.join(args.output_dir, output_dir)
             accelerator.save_state(output_dir)
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/cross_validation.py

Lines changed: 1 addition & 0 deletions
@@ -255,6 +255,7 @@ def training_function(config, args):
     preds = torch.stack(test_predictions, dim=0).sum(dim=0).div(int(args.num_folds)).argmax(dim=-1)
     test_metric = metric.compute(predictions=preds, references=test_references)
     accelerator.print("Average test metrics from all folds:", test_metric)
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/ddp_comm_hook.py

Lines changed: 1 addition & 0 deletions
@@ -192,6 +192,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/deepspeed_with_config_support.py

Lines changed: 1 addition & 0 deletions
@@ -716,6 +716,7 @@ def group_texts(examples):
 
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump({"perplexity": perplexity, "eval_loss": eval_loss.item()}, f)
+    accelerator.end_training()
 
 
 if __name__ == "__main__":

examples/by_feature/early_stopping.py

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ def training_function(config, args):
 
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/fsdp_with_peak_mem_tracking.py

Lines changed: 1 addition & 2 deletions
@@ -399,8 +399,7 @@ def collate_fn(examples):
                 step=epoch,
             )
 
-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/gradient_accumulation.py

Lines changed: 1 addition & 0 deletions
@@ -197,6 +197,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/local_sgd.py

Lines changed: 1 addition & 0 deletions
@@ -202,6 +202,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()
 
 
 def main():

examples/by_feature/megatron_lm_gpt_pretraining.py

Lines changed: 1 addition & 0 deletions
@@ -703,6 +703,7 @@ def group_texts(examples):
 
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump({"perplexity": perplexity}, f)
+    accelerator.end_training()
 
 
 if __name__ == "__main__":
