torchft should support transferring StatefulDataLoader state from other workers #309

@TimothySeah

Description

When I attempt to restore StatefulDataLoader state alongside model state from other training workers using torchft:

from torchdata.stateful_dataloader import StatefulDataLoader

train_loader = StatefulDataLoader(
    train_data, batch_size=64, num_workers=2, sampler=sampler
)

def load_state_dict(state_dict):
    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optim"])
    train_loader.load_state_dict(state_dict["dataloader"])

def state_dict():
    return {
        "model": model.state_dict(),
        "optim": optimizer.state_dict(),
        "dataloader": train_loader.state_dict(),
    }

and simply stop training when the dataloader runs out of data:

for images, labels in train_loader:
    # train
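
For context, these callbacks are wired into the Manager roughly as follows. This is a minimal sketch loosely following torchft's DDP example; the process group, the min_replica_size value, and the exact constructor arguments are assumptions and may differ across torchft versions.

from torchft import Manager, Optimizer, ProcessGroupGloo

manager = Manager(
    pg=ProcessGroupGloo(),            # assumed process group for this sketch
    min_replica_size=1,               # assumed value
    load_state_dict=load_state_dict,  # applied on a recovering replica
    state_dict=state_dict,            # captured from a healthy replica and sent over
)
optimizer = Optimizer(manager, optimizer)  # torchft wrapper; step() is gated on should_commit()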

What ends up happening is:

  1. Rank 0 and rank 1 both call loss.backward(), which succeeds.
  2. Rank 1 dies.
  3. Rank 0 calls optimizer.step(). should_commit returns False, so the step is not incremented (see the sketch after this list).
  4. Rank 0 advances to the next batch.
  5. Rank 1 recovers. However, it is one batch behind rank 0 - the train loader state is not being recovered correctly.
  6. Rank 0 finishes.
  7. Rank 1 stalls on loss.backward() for the extra batch.
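
For reference, my per-step flow looks roughly like the sketch below; criterion is a placeholder name for the actual loss function, and the comment on optimizer.step() reflects my understanding of how the torchft Optimizer wrapper gates the update on should_commit.

for images, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(images), labels)  # criterion is a placeholder loss function
    loss.backward()
    optimizer.step()  # with the torchft Optimizer wrapper, the update is only
                      # applied when manager.should_commit() returns True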

Using use_async_quorum=False:

manager = Manager(
    use_async_quorum=False,
    ...
)

I still ran into the same issue. Playing around with drop_last didn't work for me either.

The "terminate after N steps" use case works perfectly, but many users might be interested in terminating after going through the dataset N times.
