From 50b846bf2dddef9c513044151fdbd6412c8c1aa2 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Wed, 22 Jan 2025 16:37:45 -0800
Subject: [PATCH] Import error in manager.py + switch to sync mode

---
 torchft/manager.py | 1 +
 train_fsdp.py      | 3 +--
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchft/manager.py b/torchft/manager.py
index 54132a5c..4c06ee45 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 
 from torchft.checkpointing import CheckpointServer
 
diff --git a/train_fsdp.py b/train_fsdp.py
index 22302d7a..7a5d07cd 100644
--- a/train_fsdp.py
+++ b/train_fsdp.py
@@ -126,6 +126,7 @@ def state_dict():
         load_state_dict=load_state_dict,
         state_dict=state_dict,
         replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
+        use_async_quorum=False,
     )
 
     mesh = hsdp_device_mesh(NUM_REPLICA_GROUPS, NUM_REPLICAS, "cuda" if torch.cuda.is_available() else "cpu", manager=manager)
@@ -136,8 +137,6 @@ def state_dict():
 
     optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))
 
-    optimizer.zero_grad()
-
     while manager.current_step() < 500:
         model.train()
         for batch in tqdm(train_dataloader):