From c62c6308e995517e9abbc8b4c02c4932d673d6e6 Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Tue, 7 May 2024 13:39:38 -0700
Subject: [PATCH] multidataloader - ensure state is restored for round robin (#823)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/823

Add a stateful API to `MultiIterator` and implement it for `RoundRobinIterator`.

Reviewed By: diego-urgell, JKSenthil

Differential Revision: D57012694

fbshipit-source-id: bc0ef26ede428bcc20009757c240c4968e89f029
---
 tests/utils/data/test_multi_dataloader.py | 88 +++++++++++++++++++++++
 torchtnt/utils/data/iterators.py          | 32 ++++++++-
 torchtnt/utils/data/multi_dataloader.py   | 24 ++++++-
 3 files changed, 142 insertions(+), 2 deletions(-)

diff --git a/tests/utils/data/test_multi_dataloader.py b/tests/utils/data/test_multi_dataloader.py
index c0a1820dae..6543e5c762 100644
--- a/tests/utils/data/test_multi_dataloader.py
+++ b/tests/utils/data/test_multi_dataloader.py
@@ -14,6 +14,7 @@
 
 import torch
 from torch.utils.data import DataLoader, Dataset
 
+from torchtnt.framework._test_utils import generate_random_dataloader
 from torchtnt.utils.data.iterators import (
     AllDatasetBatches,
@@ -22,6 +23,7 @@
     MultiIterator,
     RandomizedBatchSampler,
     RoundRobin,
+    RoundRobinIterator,
     StoppingMechanism,
 )
 from torchtnt.utils.data.multi_dataloader import MultiDataLoader
@@ -488,6 +490,92 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             0,
         )
 
+    def test_multi_dataloader_state_dict_with_iterator_state(self) -> None:
+        dataloader_1 = generate_random_dataloader(
+            num_samples=8, input_dim=1, batch_size=8
+        )
+        dataloader_2 = generate_random_dataloader(
+            num_samples=16, input_dim=1, batch_size=8
+        )
+        multi_dataloader = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            RoundRobin(),
+        )
+
+        multi_dl_state_dict = multi_dataloader.state_dict()
+        # before creating the iterator, we don't expect iterator_state to be present in the dl state dict
+        self.assertFalse("iterator_state" in multi_dl_state_dict)
+
+        multi_dl_iter = iter(multi_dataloader)
+        multi_dl_state_dict = multi_dataloader.state_dict()
+        self.assertTrue("iterator_state" in multi_dl_state_dict)
+        self.assertEqual(
+            multi_dl_state_dict["iterator_state"],
+            {"cur_dataloader": "1", "finished_dataloaders": []},
+        )
+        next(multi_dl_iter)  # should return batch from 1
+        next(multi_dl_iter)  # should return batch from 2
+        next(
+            multi_dl_iter
+        )  # should return batch from 2 after raising StopIteration from the first dl
+        multi_dl_state_dict = multi_dataloader.state_dict()
+        self.assertTrue("iterator_state" in multi_dl_state_dict)
+        self.assertEqual(
+            multi_dl_state_dict["iterator_state"],
+            {"cur_dataloader": "2", "finished_dataloaders": ["1"]},
+        )
+
+        # create a fresh dl and load the state dict. assert that the initial values are updated.
+        multi_dataloader_2 = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            RoundRobin(),
+        )
+        multi_dataloader_2.load_state_dict(multi_dl_state_dict)
+        round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
+        self.assertEqual(round_robin_iter.cur_dataloader, "2")
+        self.assertEqual(round_robin_iter.finished_dataloaders, ["1"])
+
+        # verify that after calling iter() again, values are reset
+        round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
+        self.assertEqual(round_robin_iter.cur_dataloader, "1")
+        self.assertEqual(round_robin_iter.finished_dataloaders, [])
+
+    def test_invalid_load_state_dict(self) -> None:
+        dataloader_1 = generate_random_dataloader(
+            num_samples=8, input_dim=1, batch_size=8
+        )
+        dataloader_2 = generate_random_dataloader(
+            num_samples=16, input_dim=1, batch_size=8
+        )
+        multi_dataloader = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            RoundRobin(),
+        )
+
+        # invalid state dict - the finished_dataloaders and cur_dataloader entries do not exist
+        multi_dataloader.load_state_dict(
+            {"finished_dataloaders": ["3"], "cur_dataloader": "4"}
+        )
+        round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader))
+        # ensure the iterator state is not changed
+        self.assertEqual(round_robin_iter.cur_dataloader, "1")
+        self.assertEqual(round_robin_iter.finished_dataloaders, [])
+
+    def test_state_dict_with_non_stateful_iterator(self) -> None:
+        dataloader_1 = generate_random_dataloader(
+            num_samples=8, input_dim=1, batch_size=8
+        )
+        dataloader_2 = generate_random_dataloader(
+            num_samples=16, input_dim=1, batch_size=8
+        )
+        multi_dataloader = MultiDataLoader(
+            self._get_dataloaders_dict(dataloader_1, dataloader_2),
+            DataIterationStrategy(),
+            CustomRandomIterator,
+        )
+        iter(multi_dataloader)
+        self.assertFalse("iterator_state" in multi_dataloader.state_dict())
+
     def _get_dataloaders_dict(
         self, first_dataloader: DataLoader, second_dataloader: DataLoader
     ) -> Dict[str, Union[DataLoader, Iterable[object]]]:

diff --git a/torchtnt/utils/data/iterators.py b/torchtnt/utils/data/iterators.py
index 2673eef972..bd761dfb4a 100644
--- a/torchtnt/utils/data/iterators.py
+++ b/torchtnt/utils/data/iterators.py
@@ -9,6 +9,8 @@
 
 from __future__ import annotations
 
+import logging
+
 import random
 from abc import abstractmethod
 from dataclasses import dataclass
@@ -32,11 +34,13 @@
 import torch
 import torch.distributed as dist
 
-
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
 
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 @dataclass
 class DataIterationStrategy:
     pass
@@ -75,6 +79,12 @@ def __str__(self) -> str:
     def __next__(self) -> Dict[str, Any]:
         pass
 
+    def state_dict(self) -> Dict[str, Any]:
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        pass
+
 
 class StoppingMechanism(Enum):
     ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED"
@@ -176,6 +186,26 @@ def __next__(self) -> Dict[str, Any]:
 
         return self.__next__()
 
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            "finished_dataloaders": self.finished_dataloaders,
+            "cur_dataloader": self.cur_dataloader,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        logger.info(
+            f"Loading RoundRobinIterator state. Finished dataloaders: {state_dict['finished_dataloaders']} and trying to set cur_dataloader to {state_dict['cur_dataloader']}"
+        )
+        self.finished_dataloaders = state_dict["finished_dataloaders"]
+        cur_dataloader = state_dict["cur_dataloader"]
+        if cur_dataloader not in self.dataloader_cycle:
+            logger.warning(
+                f"Did not find {cur_dataloader} in {list(self.dataloader_cycle)}. Skipping setting cur_dataloader"
+            )
+            return
+        while self.cur_dataloader != cur_dataloader:
+            self.cur_dataloader = next(self.dataloader_cycle)
+
 
 @dataclass
 class AllDatasetBatches(DataIterationStrategy):

diff --git a/torchtnt/utils/data/multi_dataloader.py b/torchtnt/utils/data/multi_dataloader.py
index 8acb2e5f21..8f517ae25a 100644
--- a/torchtnt/utils/data/multi_dataloader.py
+++ b/torchtnt/utils/data/multi_dataloader.py
@@ -12,6 +12,8 @@
 import logging
 from typing import Any, Dict, Iterable, Iterator, Optional, Type, TYPE_CHECKING, Union
 
+from pyre_extensions import none_throws
+
 from torchtnt.utils.data.iterators import (
     DataIterationStrategy,
     DataIterationStrategyRegistry,
 )
 from torchtnt.utils.stateful import Stateful
 
+
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
@@ -53,6 +56,7 @@ def __init__(
         self.individual_dataloaders = individual_dataloaders
         self.iteration_strategy = iteration_strategy
         self.iterator_cls = iterator_cls
+        self.current_iterator: Optional[MultiIterator] = None
         for name in list(individual_dataloaders.keys()):
             try:
                 next(iter(self.individual_dataloaders[name]))
@@ -64,6 +68,7 @@ def __init__(
                     f"Dataloader '{name}' which contains no data. "
                     "You might have empty dataloaders in the input dict."
                 )
+        self.iterator_state: Optional[Dict[str, Any]] = None
 
     def __iter__(self) -> Iterator[Dict[str, Any]]:
         """Iterator functions for the collection of dataloaders.
@@ -77,10 +82,15 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
         iterator_cls = DataIterationStrategyRegistry.get(self.iteration_strategy)
         # in practice, DataIterationStrategyRegistry.get() returns just concrete classes
         # pyre-ignore[45]: Cannot instantiate abstract class `MultiIterator`.
-        return iterator_cls(
+        self.current_iterator = iterator_cls(
             individual_dataloaders=self.individual_dataloaders,
             iteration_strategy=self.iteration_strategy,
         )
+        if self.iterator_state is not None:
+            self.current_iterator.load_state_dict(self.iterator_state)
+
+        self.iterator_state = None
+        return none_throws(self.current_iterator)
 
     def state_dict(self) -> Dict[str, Any]:
         """Return an aggregated state dict based on individual dataloaders.
@@ -95,6 +105,14 @@ def state_dict(self) -> Dict[str, Any]:
             if isinstance(dl, Stateful):
                 state_dict[name] = dl.state_dict()
 
+        if (current_iterator := self.current_iterator) is not None:
+            iterator_state = current_iterator.state_dict()
+            if iterator_state:
+                logger.info("Storing iterator state in MultiDataLoader state_dict")
+                # we make an implicit assumption here that none of the dataloaders have the "iterator_state" key,
+                # in order to be backwards compatible with already saved checkpoints (we don't want to modify the dataloaders' stateful names)
+                state_dict["iterator_state"] = iterator_state
+
         return state_dict
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -114,3 +132,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
                 )
                 continue
             dl.load_state_dict(contents)
+
+        if "iterator_state" in state_dict:
+            # this will be used during the next __iter__ call
+            self.iterator_state = state_dict["iterator_state"]
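
Usage sketch: a minimal example of the save/restore flow this patch enables, mirroring the tests above. It assumes the `MultiDataLoader`, `RoundRobin`, and `generate_random_dataloader` APIs exactly as they appear in this diff; the dataloader names "1" and "2" follow the naming used by `_get_dataloaders_dict` in the tests.

    from torchtnt.framework._test_utils import generate_random_dataloader
    from torchtnt.utils.data.iterators import RoundRobin
    from torchtnt.utils.data.multi_dataloader import MultiDataLoader

    dataloaders = {
        "1": generate_random_dataloader(num_samples=8, input_dim=1, batch_size=8),
        "2": generate_random_dataloader(num_samples=16, input_dim=1, batch_size=8),
    }

    multi_dl = MultiDataLoader(dataloaders, RoundRobin())
    it = iter(multi_dl)
    next(it)  # batch from "1"
    next(it)  # batch from "2"
    next(it)  # "1" is exhausted after one batch, so this returns another batch from "2"

    # the aggregated state dict now includes the round-robin iterator position
    snapshot = multi_dl.state_dict()
    assert snapshot["iterator_state"] == {
        "cur_dataloader": "2",
        "finished_dataloaders": ["1"],
    }

    # a fresh MultiDataLoader picks the saved position back up on its next iter() call
    restored = MultiDataLoader(dataloaders, RoundRobin())
    restored.load_state_dict(snapshot)
    restored_iter = iter(restored)  # resumes with cur_dataloader == "2", finished == ["1"]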