
checkpoint has error key #18

Open
xuchen-dev opened this issue Feb 6, 2025 · 3 comments

Comments

@xuchen-dev

[Image: screenshot of the reported error]

Hello, when I loaded the model trained with train_tdd.py, it failed with unexpected keys. When I manually removed the 'base_model.model' prefix from the keys, the predicted image was noisy.

@fastisrealslow
Collaborator

@WangCunzheng This looks like a key-naming issue. Cunzheng, please help convert the keys.

@WangCunzheng
Collaborator

WangCunzheng commented Feb 10, 2025 via email

@fastisrealslow
Collaborator

import os

import torch
from safetensors.torch import load_file, save_file


def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
    """Convert a PEFT LoRA state dict to the kohya-ss naming convention."""
    kohya_ss_state_dict = {}
    for peft_key, weight in module.items():
        # Swap the PEFT prefix ("unet.base_model.model") for the kohya prefix.
        kohya_key = peft_key.replace("unet.base_model.model", prefix)
        # kohya-ss names the LoRA matrices lora_down/lora_up instead of lora_A/lora_B.
        kohya_key = kohya_key.replace("lora_A", "lora_down")
        kohya_key = kohya_key.replace("lora_B", "lora_up")
        # Turn every dot except the last two into underscores (kohya layer naming).
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_ss_state_dict[kohya_key] = weight.to(dtype)
        # kohya-ss also expects an alpha tensor per LoRA module; 8 matches training.
        if "lora_down" in kohya_key:
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(8).to(dtype)

    return kohya_ss_state_dict


# Example: convert a TDD LoRA checkpoint and save it in kohya-ss format.
# lora_weight_1 = load_file('TDD_uc0.2_etas0.3_ddim250_adv/checkpoint-20000/unet_lora/pytorch_lora_weights.safetensors')
# lora_state_dict_1 = get_module_kohya_state_dict(lora_weight_1, "lora_unet", torch.float16)
# save_file(lora_state_dict_1, os.path.join('tar_files/tdd_out', "pytorch_lora_weights.safetensors"))
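To see what the conversion does to a single key, here is a pure-string sketch of the same renaming logic (no tensors involved; the sample key below is hypothetical, not taken from an actual checkpoint):

```python
def peft_key_to_kohya(peft_key: str, prefix: str = "lora_unet") -> str:
    # Same renaming steps as get_module_kohya_state_dict, applied to one key.
    k = peft_key.replace("unet.base_model.model", prefix)
    k = k.replace("lora_A", "lora_down")
    k = k.replace("lora_B", "lora_up")
    # Convert all but the last two dots to underscores (kohya-ss layout).
    k = k.replace(".", "_", k.count(".") - 2)
    return k

print(peft_key_to_kohya(
    "unet.base_model.model.down_blocks.0.attentions.0.to_q.lora_A.weight"
))
# lora_unet_down_blocks_0_attentions_0_to_q.lora_down.weight
```

The module path collapses into one underscore-joined name, while the final `lora_down.weight` suffix keeps its dots, which is the layout kohya-ss loaders expect.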

@WangCunzheng @xuchen-dev
