multidataloader - ensure state is restored for round robin (#823)
Summary:
Pull Request resolved: #823

Add a stateful API to `MultiIterator` and implement it for `RoundRobinIterator`.
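
In practice this means a `MultiDataLoader` driven by `RoundRobin` can now checkpoint which dataloader is currently active and which ones are already exhausted, and resume from that point. A minimal sketch of the round trip (the `TensorDataset` loaders and the `"1"`/`"2"` keys are illustrative, mirroring the unit tests below):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from torchtnt.utils.data.iterators import RoundRobin
from torchtnt.utils.data.multi_dataloader import MultiDataLoader

# Two small loaders: "1" yields a single batch, "2" yields two batches.
dl1 = DataLoader(TensorDataset(torch.arange(8, dtype=torch.float32)), batch_size=8)
dl2 = DataLoader(TensorDataset(torch.arange(16, dtype=torch.float32)), batch_size=8)

multi_dl = MultiDataLoader({"1": dl1, "2": dl2}, RoundRobin())
it = iter(multi_dl)
next(it)  # batch from "1"
next(it)  # batch from "2"
next(it)  # "1" is exhausted, so this is the second batch from "2"

state = multi_dl.state_dict()
# state["iterator_state"] now records the round-robin position, e.g.
# {"cur_dataloader": "2", "finished_dataloaders": ["1"]}

# A fresh MultiDataLoader resumes from that position on its next iter() call.
restored = MultiDataLoader({"1": dl1, "2": dl2}, RoundRobin())
restored.load_state_dict(state)
iter(restored)  # starts with cur_dataloader == "2" and "1" marked finished
```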

Reviewed By: diego-urgell, JKSenthil

Differential Revision: D57012694

fbshipit-source-id: bc0ef26ede428bcc20009757c240c4968e89f029
galrotem authored and facebook-github-bot committed May 7, 2024
1 parent ad59364 commit c62c630
Showing 3 changed files with 142 additions and 2 deletions.
88 changes: 88 additions & 0 deletions tests/utils/data/test_multi_dataloader.py
@@ -14,6 +14,7 @@

import torch
from torch.utils.data import DataLoader, Dataset
from torchtnt.framework._test_utils import generate_random_dataloader

from torchtnt.utils.data.iterators import (
AllDatasetBatches,
@@ -22,6 +23,7 @@
MultiIterator,
RandomizedBatchSampler,
RoundRobin,
RoundRobinIterator,
StoppingMechanism,
)
from torchtnt.utils.data.multi_dataloader import MultiDataLoader
@@ -488,6 +490,92 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
0,
)

def test_multi_dataloader_state_dict_with_iterator_state(self) -> None:
dataloader_1 = generate_random_dataloader(
num_samples=8, input_dim=1, batch_size=8
)
dataloader_2 = generate_random_dataloader(
num_samples=16, input_dim=1, batch_size=8
)
multi_dataloader = MultiDataLoader(
self._get_dataloaders_dict(dataloader_1, dataloader_2),
RoundRobin(),
)

multi_dl_state_dict = multi_dataloader.state_dict()
# before creating the iterator we don't expect the iterator_state to be present in the dl state dict
self.assertFalse("iterator_state" in multi_dl_state_dict)

multi_dl_iter = iter(multi_dataloader)
multi_dl_state_dict = multi_dataloader.state_dict()
self.assertTrue("iterator_state" in multi_dl_state_dict)
self.assertEqual(
multi_dl_state_dict["iterator_state"],
{"cur_dataloader": "1", "finished_dataloaders": []},
)
next(multi_dl_iter) # should return batch from 1
next(multi_dl_iter) # should return batch from 2
next(
multi_dl_iter
) # should return batch from 2 after raising StopIteration from the first dl
multi_dl_state_dict = multi_dataloader.state_dict()
self.assertTrue("iterator_state" in multi_dl_state_dict)
self.assertEqual(
multi_dl_state_dict["iterator_state"],
{"cur_dataloader": "2", "finished_dataloaders": ["1"]},
)

# create fresh dl and load state dict. assert that the initial values are updated.
multi_dataloader_2 = MultiDataLoader(
self._get_dataloaders_dict(dataloader_1, dataloader_2),
RoundRobin(),
)
multi_dataloader_2.load_state_dict(multi_dl_state_dict)
round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
self.assertEqual(round_robin_iter.cur_dataloader, "2")
self.assertEqual(round_robin_iter.finished_dataloaders, ["1"])

# verify that after calling iter() again, values are reset
round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader_2))
self.assertEqual(round_robin_iter.cur_dataloader, "1")
self.assertEqual(round_robin_iter.finished_dataloaders, [])

def test_invalid_load_state_dict(self) -> None:
dataloader_1 = generate_random_dataloader(
num_samples=8, input_dim=1, batch_size=8
)
dataloader_2 = generate_random_dataloader(
num_samples=16, input_dim=1, batch_size=8
)
multi_dataloader = MultiDataLoader(
self._get_dataloaders_dict(dataloader_1, dataloader_2),
RoundRobin(),
)

# invalid state dict - finished dataloaders and curr dataloader do not exist
multi_dataloader.load_state_dict(
{"finished_dataloaders": ["3"], "cur_dataloader": "4"}
)
round_robin_iter = cast(RoundRobinIterator, iter(multi_dataloader))
# ensure the iterator state is not changed
self.assertEqual(round_robin_iter.cur_dataloader, "1")
self.assertEqual(round_robin_iter.finished_dataloaders, [])

def test_state_dict_with_non_stateful_iterator(self) -> None:
dataloader_1 = generate_random_dataloader(
num_samples=8, input_dim=1, batch_size=8
)
dataloader_2 = generate_random_dataloader(
num_samples=16, input_dim=1, batch_size=8
)
multi_dataloader = MultiDataLoader(
self._get_dataloaders_dict(dataloader_1, dataloader_2),
DataIterationStrategy(),
CustomRandomIterator,
)
iter(multi_dataloader)
self.assertFalse("iterator_state" in multi_dataloader.state_dict())

def _get_dataloaders_dict(
self, first_dataloader: DataLoader, second_dataloader: DataLoader
) -> Dict[str, Union[DataLoader, Iterable[object]]]:
32 changes: 31 additions & 1 deletion torchtnt/utils/data/iterators.py
@@ -9,6 +9,8 @@

from __future__ import annotations

import logging

import random
from abc import abstractmethod
from dataclasses import dataclass
@@ -32,11 +34,13 @@
import torch
import torch.distributed as dist


if TYPE_CHECKING:
from torch.utils.data import DataLoader


logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class DataIterationStrategy:
pass
@@ -75,6 +79,12 @@ def __str__(self) -> str:
def __next__(self) -> Dict[str, Any]:
pass

def state_dict(self) -> Dict[str, Any]:
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
pass


class StoppingMechanism(Enum):
ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED"
@@ -176,6 +186,26 @@ def __next__(self) -> Dict[str, Any]:

return self.__next__()

def state_dict(self) -> Dict[str, Any]:
return {
"finished_dataloaders": self.finished_dataloaders,
"cur_dataloader": self.cur_dataloader,
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
logger.info(
f"Loading RoundRobinIterator state. Finished dataloaders: {state_dict['finished_dataloaders']} and trying to set cur_dataloader to {self.cur_dataloader}"
)
self.finished_dataloaders = state_dict["finished_dataloaders"]
cur_dataloader = state_dict["cur_dataloader"]
if cur_dataloader not in self.dataloader_cycle:
logger.warning(
f"Did not find {cur_dataloader} in {list(self.dataloader_cycle)}. Skipping setting cur_dataloader"
)
return
while self.cur_dataloader != cur_dataloader:
self.cur_dataloader = next(self.dataloader_cycle)


@dataclass
class AllDatasetBatches(DataIterationStrategy):
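
The fast-forward in `load_state_dict` above amounts to advancing the cycling iterator of dataloader names until it lands on the saved entry, skipping the restore if the saved name is unknown. A standalone illustration of that idea with plain `itertools.cycle` (hypothetical variable names, not the library's internal types):

```python
from itertools import cycle

names = ["1", "2"]
dataloader_cycle = cycle(names)
cur_dataloader = next(dataloader_cycle)  # a fresh round-robin starts on "1"
finished_dataloaders = []

saved = {"cur_dataloader": "2", "finished_dataloaders": ["1"]}

if saved["cur_dataloader"] in names:  # guard against stale names, as in the warning path above
    finished_dataloaders = saved["finished_dataloaders"]
    while cur_dataloader != saved["cur_dataloader"]:
        cur_dataloader = next(dataloader_cycle)

assert cur_dataloader == "2"
assert finished_dataloaders == ["1"]
```
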
24 changes: 23 additions & 1 deletion torchtnt/utils/data/multi_dataloader.py
@@ -12,13 +12,16 @@
import logging
from typing import Any, Dict, Iterable, Iterator, Optional, Type, TYPE_CHECKING, Union

from pyre_extensions import none_throws

from torchtnt.utils.data.iterators import (
DataIterationStrategy,
DataIterationStrategyRegistry,
MultiIterator,
)
from torchtnt.utils.stateful import Stateful


if TYPE_CHECKING:
from torch.utils.data import DataLoader

@@ -53,6 +56,7 @@ def __init__(
self.individual_dataloaders = individual_dataloaders
self.iteration_strategy = iteration_strategy
self.iterator_cls = iterator_cls
self.current_iterator: Optional[MultiIterator] = None
for name in list(individual_dataloaders.keys()):
try:
next(iter(self.individual_dataloaders[name]))
Expand All @@ -64,6 +68,7 @@ def __init__(
f"Dataloader '{name}' which contains no data. "
"You might have empty dataloaders in the input dict."
)
self.iterator_state: Optional[Dict[str, Any]] = None

def __iter__(self) -> Iterator[Dict[str, Any]]:
"""Iterator functions for the collection of dataloaders.
@@ -77,10 +82,15 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
iterator_cls = DataIterationStrategyRegistry.get(self.iteration_strategy)
# in practice, DataIterationStrategyRegistry.get() returns just concrete classes
# pyre-ignore[45]: Cannot instantiate abstract class `MultiIterator`.
return iterator_cls(
self.current_iterator = iterator_cls(
individual_dataloaders=self.individual_dataloaders,
iteration_strategy=self.iteration_strategy,
)
if self.iterator_state is not None:
self.current_iterator.load_state_dict(self.iterator_state)

self.iterator_state = None
return none_throws(self.current_iterator)

def state_dict(self) -> Dict[str, Any]:
"""Return an aggregated state dict based on individual dataloaders.
@@ -95,6 +105,14 @@
if isinstance(dl, Stateful):
state_dict[name] = dl.state_dict()

if (current_iterator := self.current_iterator) is not None:
iterator_state = current_iterator.state_dict()
if iterator_state:
logger.info("Storing iterator state in MultiDataLoader state_dict")
# we implicitly assume that none of the dataloaders use the "iterator_state" key, in order to stay
# backwards compatible with already saved checkpoints (we don't want to modify the dataloaders' stateful names)
state_dict["iterator_state"] = iterator_state

return state_dict

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
Expand All @@ -114,3 +132,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
)
continue
dl.load_state_dict(contents)

if "iterator_state" in state_dict:
# this will be used during the next __iter__ call
self.iterator_state = state_dict["iterator_state"]
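
One design point worth noting: the restore is deliberately deferred. `load_state_dict` only stashes `iterator_state`; the next `__iter__` call applies it to the freshly built iterator and then clears it, so a second `iter()` call starts clean again (the reset behavior exercised in the tests above). A short sketch of the call order, reusing the illustrative loaders and saved state from the sketch in the Summary:

```python
restored = MultiDataLoader({"1": dl1, "2": dl2}, RoundRobin())
restored.load_state_dict(state)  # only records state["iterator_state"]; nothing is applied yet

first_iter = iter(restored)   # builds the RoundRobinIterator, applies the saved state, then clears it
second_iter = iter(restored)  # no stashed state left, so this pass starts from scratch on "1"
```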
