Skip to content

Commit

Permalink
train_universal enabled for partial model_state loading
Browse files Browse the repository at this point in the history
  • Loading branch information
dkimpara committed Jan 25, 2025
1 parent 932b8ba commit 935305b
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions applications/train_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from credit.metrics import LatWeightedMetrics
from credit.pbs import launch_script, launch_script_mpi
from credit.models import load_model
from credit.models.checkpoint import FSDPOptimizerWrapper, TorchFSDPCheckpointIO
from credit.models.checkpoint import FSDPOptimizerWrapper, TorchFSDPCheckpointIO, load_state_dict_error_handler


warnings.filterwarnings("ignore")
Expand All @@ -47,12 +47,12 @@
# https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


# TODO: enable partial loading and weight freezing
def load_model_states_and_optimizer(conf, model, device):
"""
Load the model states, optimizer, scheduler, and gradient scaler.
Args:
Args:jo
conf (dict): Configuration dictionary containing training parameters.
model (torch.nn.Module): The model to be trained.
device (torch.device): The device (CPU or GPU) where the model is located.
Expand Down Expand Up @@ -94,7 +94,7 @@ def load_model_states_and_optimizer(conf, model, device):
# Load an optimizer, gradient scaler, and learning rate scheduler, the optimizer must come after wrapping model using FSDP
if not load_weights: # Loaded after loading model weights when reloading
optimizer = torch.optim.AdamW(
model.parameters(),
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.95),
Expand All @@ -113,7 +113,7 @@ def load_model_states_and_optimizer(conf, model, device):
load_optimizer_conf or load_scaler_conf or load_scheduler_conf
):
optimizer = torch.optim.AdamW(
model.parameters(),
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.95),
Expand All @@ -124,7 +124,7 @@ def load_model_states_and_optimizer(conf, model, device):
f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
optimizer = torch.optim.AdamW(
model.parameters(),
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.95),
Expand All @@ -142,12 +142,17 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
model.module.load_state_dict(checkpoint["model_state_dict"])
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
model.load_state_dict(checkpoint["model_state_dict"])
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_state_dict_error_handler(load_msg)

# Load the learning rate scheduler and mixed precision grad scaler
scheduler = load_scheduler(optimizer, conf)
scaler = (
Expand All @@ -167,7 +172,7 @@ def load_model_states_and_optimizer(conf, model, device):
f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
optimizer = torch.optim.AdamW(
model.parameters(),
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.95),
Expand All @@ -191,14 +196,19 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
model.module.load_state_dict(checkpoint["model_state_dict"])
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
model.load_state_dict(checkpoint["model_state_dict"])
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_state_dict_error_handler(load_msg)

optimizer = torch.optim.AdamW(
model.parameters(),
filter(lambda p: p.requires_grad, model.parameters()),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.95),
Expand Down

0 comments on commit 935305b

Please sign in to comment.