Skip to content

Commit

Permalink
add resume from checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
neph1 committed Jan 7, 2025
1 parent 22fbe95 commit 43bb7fd
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion config/config_categories.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Dataset: data_root, video_column, caption_column, id_token, video_resolution_buckets, caption_dropout_p
Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size
Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
1 change: 1 addition & 0 deletions config/config_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ precompute_conditions: false
pretrained_model_name_or_path: ''
rank: 128
report_to: none
resume_from_checkpoint: ''
seed: 42
target_modules: to_q to_k to_v to_out.0
text_encoder_dtype: [bf16, fp16, fp32]
Expand Down
5 changes: 4 additions & 1 deletion run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
--checkpointing_steps {config.get('checkpointing_steps')} \
--checkpointing_limit {config.get('checkpointing_limit')} \
{'--enable_slicing' if config.get('enable_slicing') else ''} \
{'--enable_tiling' if config.get('enable_tiling') else ''}"
{'--enable_tiling' if config.get('enable_tiling') else ''} "

if config.get('resume_from_checkpoint'):
training_cmd += f"--resume_from_checkpoint {config.get('resume_from_checkpoint')}"

# Optimizer arguments
optimizer_cmd = f"--optimizer {config.get('optimizer')} \
Expand Down
6 changes: 3 additions & 3 deletions scripts/rename_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
def rename_keys(file, outfile: str) -> bool:
    """Strip the leading ``transformer.`` prefix from every key of a state dict.

    Loads the state dict from *file*, casts each tensor to float32 while
    renaming its key, then writes the result to *outfile* in float16.

    Args:
        file: Path of the input state-dict file.
        outfile: Path the renamed state dict is saved to.
    """
    sd, metadata = load_state_dict(file, torch.float32)

    # Build the renamed dict in a single pass instead of materializing
    # parallel key/value lists and zipping them back together.
    new_sd = {
        key.replace("transformer.", ""): value.to(torch.float32)
        for key, value in sd.items()
    }

    save_to_file(outfile, new_sd, torch.float16, metadata)
Expand Down
19 changes: 19 additions & 0 deletions test/test_trainer_config_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pytest
from unittest.mock import patch

import yaml

from trainer_config_validator import TrainerValidator

@pytest.fixture
Expand Down Expand Up @@ -55,6 +57,23 @@ def test_valid_config(valid_config):
with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
trainer_validator.validate()

def test_config_template():
    """The shipped config template validates once the required paths are filled in."""
    with open('config/config_template.yaml', "r") as file:
        config = yaml.safe_load(file)

    # The template ships with empty placeholders; supply the mandatory values.
    config['path_to_finetrainers'] = '/path/to/finetrainers'
    config['data_root'] = '/path/to/data'
    config['pretrained_model_name_or_path'] = 'pretrained_model'

    validator = TrainerValidator(config)
    # Pretend every filesystem check succeeds so only config validation runs.
    with patch('os.path.isfile', return_value=True), \
         patch('os.path.exists', return_value=True), \
         patch('os.path.isdir', return_value=True):
        validator.validate()

def test_validate_data_root_not_set(trainer_validator):
    """Validation fails fast when data_root is left empty."""
    trainer_validator.config['data_root'] = ''
    expected_message = "data_root is required"
    with pytest.raises(ValueError, match=expected_message):
        trainer_validator.validate()

def test_validate_data_root_invalid(trainer_validator):
trainer_validator.config['data_root'] = '/invalid/path'
with pytest.raises(ValueError, match="data_root path /invalid/path does not exist"):
Expand Down

0 comments on commit 43bb7fd

Please sign in to comment.