When I attempt to restore StatefulDataLoader state alongside model state from other training workers using torchft:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

train_loader = StatefulDataLoader(
    train_data, batch_size=64, num_workers=2, sampler=sampler
)


def load_state_dict(state_dict):
    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optim"])
    train_loader.load_state_dict(state_dict["dataloader"])


def state_dict():
    return {
        "model": model.state_dict(),
        "optim": optimizer.state_dict(),
        "dataloader": train_loader.state_dict(),
    }
```
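These callbacks are wired into the torchft Manager roughly as below. This is a minimal sketch following the pattern in torchft's DDP example; the process group, replica sizing, and exact constructor arguments are placeholders and may differ by version.

```python
import torch
from torchft import DistributedDataParallel, Manager, Optimizer, ProcessGroupGloo

# The Manager owns the state_dict / load_state_dict callbacks so it can
# ship state to recovering replicas. Arguments here are placeholders.
manager = Manager(
    pg=ProcessGroupGloo(),
    min_replica_size=1,
    load_state_dict=load_state_dict,
    state_dict=state_dict,
)

model = DistributedDataParallel(manager, model)
optimizer = Optimizer(manager, torch.optim.AdamW(model.parameters()))
```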
I then simply stop training when the dataloader runs out of data:

```python
for images, labels in train_loader:
    # train on this batch
    ...
```
What ends up happening is (see the sketch after this list for where these calls sit in the training loop):
- Rank 0 and rank 1 both call loss.backward(), which succeeds.
- Rank 1 dies.
- Rank 0 calls optimizer.step(). should_commit returns False, so the step is not incremented.
- Rank 0 advances to the next batch.
- Rank 1 recovers. However, it is one batch behind rank 0, so train loader recovery is not happening correctly.
- Rank 0 finishes.
- Rank 1 stalls on loss.backward() on the extra batch.
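To make the sequence above concrete, here is roughly where those calls sit in the loop. This is a sketch, not my exact script: criterion is a hypothetical stand-in, and the comment on should_commit reflects my understanding of how the wrapped optimizer behaves.

```python
import torch

criterion = torch.nn.CrossEntropyLoss()  # hypothetical stand-in for the real loss

for images, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    # With the torchft-wrapped optimizer, step() consults manager.should_commit();
    # if the quorum changed mid-step (e.g. rank 1 died), should_commit returns False
    # and the step is skipped. The batch has already been consumed from
    # train_loader, though, so this rank's dataloader state is now ahead of
    # what was committed.
    optimizer.step()
```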
Using use_async_quorum=False:

```python
manager = Manager(
    use_async_quorum=False,
    ...
)
```

I still ran into the same issue. Playing around with drop_last didn't work for me either.
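For concreteness, the drop_last experiment was just passing the flag through to the loader; the snippet below is my paraphrase of "playing around with it" rather than an exact configuration.

```python
train_loader = StatefulDataLoader(
    train_data, batch_size=64, num_workers=2, sampler=sampler, drop_last=True
)
```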
The "terminate after N steps" use case works perfectly, but many users might be interested in terminating after going through the dataset N times.