
Commit 7f9c8cb

[DeepSpeed] sync gradient accum steps from deepspeed plugin (#3632)
* sync steps
* add a debug log when overriding
* make grad accum always consistent
* remove debug
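
The intent described above is that when no explicit gradient accumulation setting is passed to `Accelerator`, the value from the DeepSpeed configuration is adopted. A minimal usage sketch with illustrative values (not part of this commit; it assumes `deepspeed` is installed and the script runs under a DeepSpeed-enabled `accelerate launch`):

```python
# Sketch only: values are hypothetical, and this needs a DeepSpeed-enabled
# launch, e.g. `accelerate launch --use_deepspeed train.py`.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# The DeepSpeed side declares 4 accumulation steps.
deepspeed_plugin = DeepSpeedPlugin(gradient_accumulation_steps=4, zero_stage=2)

# No gradient_accumulation_steps is passed here, so with this change the
# Accelerator adopts the DeepSpeed value instead of defaulting to 1.
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

print(accelerator.gradient_accumulation_steps)  # expected to print 4
```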
1 parent 9888c7e commit 7f9c8cb

src/accelerate/accelerator.py

Lines changed: 21 additions & 0 deletions
@@ -508,11 +508,32 @@ def __init__(
                 raise ValueError(
                     "You can only pass one of `gradient_accumulation_steps` and `gradient_accumulation_plugin`. Please only pass in the created `GradientAccumulationPlugin` object."
                 )
+            gradient_accumulation_steps = gradient_accumulation_plugin.num_steps
         else:
             gradient_accumulation_steps = int(
                 parse_choice_from_env("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", gradient_accumulation_steps)
             )
+
+        # If using DeepSpeed, update gradient accumulation steps from the DeepSpeed plugin
+        if self.state.distributed_type == DistributedType.DEEPSPEED and self.state.deepspeed_plugin is not None:
+            deepspeed_gradient_accumulation_steps = self.state.deepspeed_plugin.get_value(
+                "gradient_accumulation_steps"
+            )
+            if deepspeed_gradient_accumulation_steps != gradient_accumulation_steps:
+                if gradient_accumulation_plugin is not None:
+                    logger.warning(
+                        f"Gradient accumulation steps mismatch: GradientAccumulationPlugin has {gradient_accumulation_steps}, "
+                        f"DeepSpeed config has {deepspeed_gradient_accumulation_steps}. Using DeepSpeed's value."
+                    )
+                gradient_accumulation_steps = deepspeed_gradient_accumulation_steps
+
+        # Create or update the gradient accumulation plugin
+        if gradient_accumulation_plugin is None:
             gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=gradient_accumulation_steps)
+        else:
+            # Update the plugin's num_steps if it was changed due to DeepSpeed sync
+            if gradient_accumulation_plugin.num_steps != gradient_accumulation_steps:
+                gradient_accumulation_plugin.num_steps = gradient_accumulation_steps
         self.gradient_state = GradientState(
             gradient_accumulation_plugin=gradient_accumulation_plugin,
         )
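
For the mismatch branch added above, a rough sketch of the expected behavior with hypothetical values: a user-supplied `GradientAccumulationPlugin` that disagrees with the DeepSpeed config triggers the new warning and is overridden in place before `GradientState` is built.

```python
# Sketch only: hypothetical values, and again assumes a DeepSpeed-enabled run.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, GradientAccumulationPlugin

# Illustrative conflict: DeepSpeed says 8 steps, the user plugin says 2.
deepspeed_plugin = DeepSpeedPlugin(gradient_accumulation_steps=8)
grad_accum_plugin = GradientAccumulationPlugin(num_steps=2)

accelerator = Accelerator(
    deepspeed_plugin=deepspeed_plugin,
    gradient_accumulation_plugin=grad_accum_plugin,
)

# Per the diff: a mismatch warning is logged, DeepSpeed's value wins, and the
# plugin's num_steps is updated to match.
assert accelerator.gradient_accumulation_steps == 8
```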
