forked from mit-han-lab/nunchaku
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample.py
30 lines (25 loc) · 1.05 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import time

import torch
from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
# Example: generate one image per prompt with FLUX.1-schnell, using the
# Nunchaku SVDQuant INT4 transformer in place of the stock one. Each image is
# saved as example_<index>.png and the average generation time is printed.

BASE_MODEL_ID = "black-forest-labs/FLUX.1-schnell"
QUANTIZED_TRANSFORMER_ID = "mit-han-lab/svdq-int4-flux.1-schnell"

PROMPTS = [
    "A cat holding a sign that says hello world",
    "A dog playing with a ball in the park",
    "A beautiful sunset over the mountains",
    "A futuristic cityscape at night",
    "A group of people having a picnic in the park",
]

# Untimed generations performed before measuring. Without this the first
# timed call also pays one-off start-up costs, which skews the average.
# Set to 0 to include the cold first call in the measurement.
WARMUP_RUNS = 1


def load_pipeline(device="cuda"):
    """Build the FLUX.1-schnell pipeline with the INT4 Nunchaku transformer.

    Args:
        device: Device the assembled pipeline is moved to.

    Returns:
        The ``FluxPipeline`` instance, in bfloat16, on *device*.
    """
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(QUANTIZED_TRANSFORMER_ID)
    return FluxPipeline.from_pretrained(
        BASE_MODEL_ID, transformer=transformer, torch_dtype=torch.bfloat16
    ).to(device)


def generate(pipeline, prompt):
    """Return the first image the pipeline produces for *prompt*.

    Uses 4 inference steps and guidance_scale=0, the settings of the
    original example.
    """
    return pipeline(prompt, num_inference_steps=4, guidance_scale=0).images[0]


def main():
    """Generate and save one image per prompt; print the mean generation time."""
    if not PROMPTS:
        # Nothing to time; also avoids dividing by zero below.
        raise SystemExit("No prompts to generate.")

    pipeline = load_pipeline()

    for _ in range(WARMUP_RUNS):
        generate(pipeline, PROMPTS[0])

    total_time = 0.0
    for i, prompt in enumerate(PROMPTS):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        image = generate(pipeline, prompt)
        total_time += time.perf_counter() - start
        # Saving happens outside the timed region on purpose.
        image.save(f"example_{i}.png")

    average_time_per_image = total_time / len(PROMPTS)
    print(f"Average time per image: {average_time_per_image:.2f} seconds")


if __name__ == "__main__":
    main()