forked from mit-han-lab/nunchaku
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
71 lines (62 loc) · 2.2 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import time
import uuid
from typing import List
import torch
from cog import BasePredictor, Input, Path
from nunchaku.pipelines import flux as nunchaku_flux
# Base model identifier; passed as the first positional argument to
# nunchaku_flux.from_pretrained in Predictor.setup.
# NOTE(review): looks like a Hugging Face repo id — confirm.
MODEL_ID = "black-forest-labs/FLUX.1-schnell"
# Presumably a local cache directory name; no active code in the visible
# part of this file reads it.
MODEL_CACHE = "model-cache"
# Quantized model location; passed as qmodel_path= to
# nunchaku_flux.from_pretrained in Predictor.setup.
# NOTE(review): looks like a Hugging Face repo id — confirm.
QUANT_MODEL_PATH = "mit-han-lab/svdq-int4-flux.1-schnell"
class Predictor(BasePredictor):
    """Cog predictor that renders one image per request with a FLUX pipeline.

    ``setup`` builds the pipeline once and stores it on ``self.pipe``;
    ``predict`` runs it for a single prompt and writes the result to a JPEG.
    """

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading FLUX pipeline...")
        self.pipe = nunchaku_flux.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            qmodel_path=QUANT_MODEL_PATH,
            # Cache overrides are intentionally left disabled:
            # cache_dir=MODEL_CACHE,
            # local_files_only=False,
        )
        # NOTE(review): presumably trades speed for lower GPU memory use by
        # moving sub-modules on and off the device — confirm for this pipeline.
        self.pipe.enable_sequential_cpu_offload()

    @torch.inference_mode()
    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="a photo of an astronaut riding a horse on mars",
        ),
        width: int = Input(
            description="Width of output image",
            default=1024,
        ),
        height: int = Input(
            description="Height of output image",
            default=1024,
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=10, default=4
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            width: Requested output width, forwarded to the pipeline.
            height: Requested output height, forwarded to the pipeline.
            num_inference_steps: Number of denoising steps (1-10).
            seed: Seed for the CUDA generator. When ``None`` a random seed
                is drawn from the OS entropy source and printed so the run
                can be reproduced.

        Returns:
            A one-element list holding the path of the JPEG written under
            ``/tmp``.
        """
        if seed is None:
            # Draw 32 bits rather than 16: a 2-byte seed allows only 65,536
            # distinct values, so "random" runs would repeat images often.
            seed = int.from_bytes(os.urandom(4), "big")
        print(f"Using seed: {seed}")

        generator = torch.Generator("cuda").manual_seed(seed)
        output = self.pipe(
            prompt=prompt,
            generator=generator,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
        )

        # Create unique filename using timestamp and UUID so that successive
        # predictions do not overwrite each other's output.
        timestamp = int(time.time())
        random_id = str(uuid.uuid4())[:8]
        output_path = f"/tmp/out-{timestamp}-{random_id}.jpg"
        output.images[0].save(output_path, format='JPEG', quality=95)
        return [Path(output_path)]