import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from nunchaku import NunchakuFluxTransformer2dModel

from vars import LORA_PATHS, SVDQ_LORA_PATHS


def hash_str_to_int(s: str) -> int:
    """Hash a string to an integer with a polynomial rolling hash.

    Unlike Python's built-in ``hash``, the result is deterministic across
    processes, so it can be used to derive reproducible seeds from strings.
    """
    modulus = 10**9 + 7  # large prime modulus keeps the value bounded
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int
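

# A quick usage sketch (illustrative, not part of the module API): mapping a
# prompt string to a reproducible seed, e.g.
#
#   seed = hash_str_to_int("a photo of a cat")
#   generator = torch.Generator().manual_seed(seed)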


def get_pipeline(
    model_name: str,
    precision: str,
    use_qencoder: bool = False,
    lora_name: str = "None",
    lora_weight: float = 1,
    device: str | torch.device = "cuda",
    pipeline_init_kwargs: dict | None = None,
) -> FluxPipeline:
    """Build a FluxPipeline for the requested model, precision, and LoRA setup."""
    # Use None instead of a mutable default: this dict is mutated below, and a
    # shared `{}` default would leak state across calls.
    if pipeline_init_kwargs is None:
        pipeline_init_kwargs = {}
    if model_name == "schnell":
        if precision in ["int4", "fp4"]:
            # The SVDQuant (int4/fp4) transformers run only on CUDA devices.
            assert torch.device(device).type == "cuda", "int4/fp4 only supported on CUDA devices"
            if precision == "int4":
                transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
            else:
                assert precision == "fp4"
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/svdq-fp4-flux.1-schnell", precision="fp4"
                )
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                # Optionally replace the T5 text encoder with the quantized one.
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        else:
            assert precision == "bf16"
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
    elif model_name == "dev":
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
            if lora_name not in ["All", "None"]:
                # SVDQuant LoRA branches are fused into the quantized
                # transformer rather than loaded through diffusers/peft.
                transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
                transformer.set_lora_strength(lora_weight)
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
        else:
            assert precision == "bf16"
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
            if lora_name == "All":
                # Pre-load every LoRA for demo use, then zero out all scales;
                # individual adapters can be re-enabled later by setting their
                # scaling back to a nonzero weight.
                for name, path in LORA_PATHS.items():
                    pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        # Materialize the keys into a list so set_adapter does
                        # not hold a live view of the scaling dict.
                        m.set_adapter(list(m.scaling.keys()))
                        for name in m.scaling.keys():
                            m.scaling[name] = 0
            elif lora_name != "None":
                path = LORA_PATHS[lora_name]
                pipeline.load_lora_weights(
                    path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
                )
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_weight
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")
    pipeline = pipeline.to(device)
    return pipeline
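

# A minimal usage sketch, assuming a CUDA machine with the weights available;
# the prompt, step count, and output path below are illustrative only:
if __name__ == "__main__":
    pipeline = get_pipeline(model_name="schnell", precision="int4")
    prompt = "A cat holding a sign that says hello world"
    seed = hash_str_to_int(prompt)  # reproducible seed derived from the prompt
    image = pipeline(
        prompt,
        num_inference_steps=4,  # FLUX.1-schnell is distilled for few steps
        guidance_scale=0.0,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]
    image.save("example.png")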