diff --git a/scripts/extract_meta_info_stage1.py b/scripts/extract_meta_info_stage1.py index d25123e1..936cb06c 100644 --- a/scripts/extract_meta_info_stage1.py +++ b/scripts/extract_meta_info_stage1.py @@ -21,6 +21,8 @@ import os from pathlib import Path +import torch + def collect_video_folder_paths(root_path: Path) -> list: """ @@ -52,6 +54,10 @@ def construct_meta_info(frames_dir_path: Path) -> dict: print(f"Mask path not found: {mask_path}") return None + if torch.load(face_emb_path) is None: + print(f"Face emb is None: {face_emb_path}") + return None + return { "image_path": str(frames_dir_path), "mask_path": mask_path, diff --git a/scripts/train_stage1.py b/scripts/train_stage1.py index 9c6265fa..e9e7e847 100644 --- a/scripts/train_stage1.py +++ b/scripts/train_stage1.py @@ -16,6 +16,7 @@ """ import argparse +import copy import logging import math import os @@ -211,6 +212,7 @@ def log_validation( logger.info("Running validation... ") ori_net = accelerator.unwrap_model(net) + ori_net = copy.deepcopy(ori_net) reference_unet = ori_net.reference_unet denoising_unet = ori_net.denoising_unet face_locator = ori_net.face_locator @@ -278,6 +280,7 @@ def log_validation( canvas.save(out_file) del pipe + del ori_net torch.cuda.empty_cache() return pil_images