Skip to content

Commit

Permalink
Update checkpointing to record rollout state
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Jan 31, 2025
1 parent 2c4cf1b commit 8d8fdfd
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def description(self) -> str:
return (
f"Randomly select a rollout from the increasing range "
f"{range(self._minimum, self._maximum, self._step)}"
f"with the upper bound increasing by {self._step} every {self._every_n} {self._step_type}"
f"with the upper bound increasing by {self._step} every {self._every_n} {self._step_type!s}/s"
)


Expand Down
2 changes: 1 addition & 1 deletion training/src/anemoi/training/schedulers/rollout/stepped.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def maximum_rollout(self) -> int:
def description(self) -> str:
return (
"Stepped rollout scheduler stepping between "
f"{self._minimum} and {self._maximum} by {self._increment} for every {self._every_n} {self._step_type}/s."
f"{self._minimum} and {self._maximum} by {self._increment} for every {self._every_n} {self._step_type!s}/s."
)


Expand Down
3 changes: 2 additions & 1 deletion training/src/anemoi/training/schedulers/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import enum
from abc import ABC
from typing import Literal
from typing import Self

from typing_extensions import Self

from anemoi.training.schedulers.utils import get_closest_key
from anemoi.training.schedulers.utils import get_value_from_closest_key
Expand Down
13 changes: 9 additions & 4 deletions training/src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,18 +606,23 @@ def calculate_val_metrics(
return metrics

def on_load_checkpoint(self, checkpoint: dict) -> None:
_ = checkpoint
self.rollout = instantiate(self.config.training.rollout).get_state_from(self.rollout)
if "rollout" in checkpoint:
self.rollout = instantiate(self.config.training.rollout).get_state_from(checkpoint["rollout"])

def on_save_checkpoint(self, checkpoint: dict) -> None:
    # Checkpoint hook: record the rollout scheduler under the "rollout" key
    # so its state can be restored later (the matching load hook reads
    # checkpoint["rollout"]).
    # NOTE(review): stores the scheduler object itself — assumes it is
    # serializable by the checkpoint writer; confirm against the rollout
    # scheduler implementation.
    checkpoint["rollout"] = self.rollout

def on_train_batch_end(self, *_) -> None:
    # Trainer hook: advance the rollout scheduler by one step after every
    # training batch. All hook arguments from the trainer are ignored (*_).
    self.rollout.step()

def on_train_epoch_end(self, *_) -> None:
if self.trainer.limit_val_batches != 0:
if self.trainer.limit_val_batches == 0:
LOGGER.debug("Stepping Rollout on train epoch end")
self.rollout.step_epoch()

def on_validation_batch_end(self, *_) -> None:
def on_validation_epoch_end(self, *_) -> None:
if not self.trainer.sanity_checking:
LOGGER.debug("Stepping Rollout on validation epoch end")
self.rollout.step_epoch()

def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
Expand Down

0 comments on commit 8d8fdfd

Please sign in to comment.