diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index f1b654e0e..6cb5c2e81 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -21,6 +21,9 @@ class TrainingConfig(BaseModel):
     device: str = "cpu"
+    automatic_mixed_precision: bool = (
+        True  # Enables automatic mixed precision, which keeps gradients in float32 while computing in a lower precision. Only has an effect when dtype is not float32.
+    )
     dtype: str = "float32"
     duration: TimeValue = Iteration(1)  # TimeValue(number=1, unit=TimeUnit.ITERATION)
     seed: int = 0
diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py
index 392779a24..36d0c9cb6 100644
--- a/src/refiners/training_utils/trainer.py
+++ b/src/refiners/training_utils/trainer.py
@@ -5,8 +5,9 @@
 import torch
 from loguru import logger
-from torch import Tensor, device as Device, dtype as DType, nn
+from torch import Tensor, device as Device, dtype as DType, float16, float32, nn
 from torch.autograd import backward
+from torch.cuda.amp import GradScaler, autocast
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
@@ -100,11 +101,13 @@ def decorator(func: Callable[[Any, ModelConfigT], ModuleT]) -> ModuleT:
     def wrapper(self: Trainer[BaseConfig, Any], config: ModelConfigT) -> fl.Module:
         name = func.__name__
         model = func(self, config)
-        model = model.to(self.device, dtype=self.dtype)
+        model = model.to(self.device)
         if config.requires_grad is not None:
             logger.info(f"Setting requires_grad to {config.requires_grad} for model: {name}")
             model.requires_grad_(requires_grad=config.requires_grad)
         learnable_parameters = [param for param in model.parameters() if param.requires_grad]
+        if not self.config.training.automatic_mixed_precision:
+            model.to(dtype=self.dtype)
         numel = sum(param.numel() for param in learnable_parameters)
         logger.info(f"Number of learnable parameters in {name}: {human_readable_number(numel)}")
         self.models[name] = ModelItem(
@@ -181,6 +184,12 @@ def dtype(self) -> DType:
         logger.info(f"Using dtype: {dtype}")
         return dtype

+    @cached_property
+    def scaler(self) -> GradScaler | None:
+        if self.dtype != float16 or not self.config.training.automatic_mixed_precision:
+            return None
+        return GradScaler()
+
     @property
     def learnable_parameters(self) -> list[nn.Parameter]:
         """Returns a list of learnable parameters in all models"""
@@ -348,17 +357,32 @@ def compute_loss(self, batch: Batch) -> Tensor: ...
     def compute_evaluation(self) -> None:
         pass

+    def backward_step(self, scaled_loss: Tensor) -> None:
+        if self.scaler is None:
+            backward(tensors=scaled_loss)
+            return
+        self.scaler.scale(scaled_loss).backward()  # type: ignore
+
+    def optimizer_step(self) -> None:
+        if self.scaler is not None:
+            self.scaler.unscale_(self.optimizer)
+        max_norm = self.config.training.gradient_clipping_max_norm or float("inf")
+        self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(self.learnable_parameters, max_norm=max_norm).item()
+        if self.scaler is None:
+            self.optimizer.step()
+            return
+        self.scaler.step(self.optimizer)  # type: ignore
+        self.scaler.update()  # adjust the scale factor for the next iteration
+
     def backward(self) -> None:
         """Backward pass on the loss."""
         self._call_callbacks(event_name="on_backward_begin")
         scaled_loss = self.loss / self.clock.num_step_per_iteration
-        backward(tensors=scaled_loss)
+        self.backward_step(scaled_loss)
         self._call_callbacks(event_name="on_backward_end")
         if self.clock.is_optimizer_step:
             self._call_callbacks(event_name="on_optimizer_step_begin")
-            max_norm = self.config.training.gradient_clipping_max_norm or float("inf")
-            self.grad_norm = nn.utils.clip_grad.clip_grad_norm_(self.learnable_parameters, max_norm=max_norm).item()
-            self.optimizer.step()
+            self.optimizer_step()
             self.optimizer.zero_grad()
             self._call_callbacks(event_name="on_optimizer_step_end")
             if self.clock.is_due(self.config.lr_scheduler.update_interval):
@@ -371,7 +395,8 @@ def backward(self) -> None:
     def step(self, batch: Batch) -> None:
         """Perform a single training step."""
         self._call_callbacks(event_name="on_compute_loss_begin")
-        loss = self.compute_loss(batch=batch)
+        with autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision):
+            loss = self.compute_loss(batch=batch)
         self.loss = loss
         self._call_callbacks(event_name="on_compute_loss_end")
         self.backward()
@@ -412,7 +437,8 @@ def evaluate(self) -> None:
         """Evaluate the model."""
         self.set_models_to_mode(mode="eval")
         self._call_callbacks(event_name="on_evaluate_begin")
-        self.compute_evaluation()
+        with autocast(dtype=self.dtype, enabled=self.config.training.automatic_mixed_precision):
+            self.compute_evaluation()
         self._call_callbacks(event_name="on_evaluate_end")
         self.set_models_to_mode(mode="train")
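
For reference, the two new trainer hooks follow the standard torch.cuda.amp recipe: run the forward pass under autocast, scale the loss before backward so float16 gradients do not underflow, unscale before gradient clipping, then let the scaler drive the optimizer step. The self-contained sketch below illustrates that recipe outside the trainer; the toy model, optimizer, data, and hyperparameters are illustrative assumptions, not part of this patch.

# Standalone illustration of the AMP recipe used by backward_step / optimizer_step above.
# The toy model, optimizer, data, and hyperparameters are assumptions for this sketch only.
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # with enabled=False everything degrades to plain float32 training

model = nn.Linear(16, 1).to(device)  # parameters stay in float32, as in the patched wrapper
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler(enabled=use_amp)  # a disabled scaler falls back to plain backward/step

x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

for _ in range(10):
    optimizer.zero_grad()
    # Forward pass and loss are computed in float16 under autocast.
    with autocast(dtype=torch.float16, enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y)
    # Scale the loss so small float16 gradients do not underflow, then backpropagate.
    scaler.scale(loss).backward()
    # Unscale before clipping so max_norm applies to the true gradient magnitudes.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # skips the step if inf/nan gradients were detected
    scaler.update()         # adjusts the scale factor for the next iteration

With this patch, the equivalent behaviour is enabled by setting dtype = "float16" in the training config and leaving automatic_mixed_precision at its default of True; for any other dtype the scaler property stays None and the trainer uses the plain backward/step path.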