# flux-demo.py
import time

import gradio as gr
import torch
import torch_tensorrt
from diffusers import FluxPipeline

DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
backbone = pipe.transformer
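# Keep a handle to the original transformer so the UI can swap back to the
# eager PyTorch backbone later (see model_change below).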
batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=8)
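# torch.export.Dim declares a symbolic dimension: any axis mapped to BATCH in
# dynamic_shapes below may vary between 1 and 8 at runtime without forcing a
# recompile.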
# These particular min/max values for the img_ids input are the ones torch
# dynamo recommends during model export; to see that recommendation, try
# exporting with min=1, max=4096.
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {},
    "img_ids": {},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
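# {0: BATCH} marks dimension 0 (the batch axis) of an input as dynamic; an
# empty dict keeps that input fully static, and None excludes the argument
# from dynamic-shape tracking.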
settings = {
    "strict": False,
    "allow_complex_guards_as_runtime_asserts": True,
    "enabled_precisions": {torch.float32},
    "truncate_double": True,
    "min_block_size": 1,
    "use_fp32_acc": True,
    "use_explicit_typing": True,
    "debug": False,
    "use_python_runtime": True,
    "immutable_weights": False,
    "enable_cuda_graph": True,
}
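# Notes on a few of these settings:
#   - immutable_weights=False keeps the engine refittable, which is what lets
#     the LoRA flow below swap weights without a full recompile.
#   - use_fp32_acc=True accumulates matmuls in FP32 for better numerical
#     accuracy while the model runs in FP16; it requires use_explicit_typing.
#   - min_block_size=1 lets even single-op subgraphs be converted to TensorRT.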
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
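# Register the dynamic-shape spec: the empty tuple covers positional args
# (none here); the dict declares the dynamic batch axis for each named input.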
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm
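# MutableTorchTensorRTModule behaves like the original nn.Module but lowers to
# a TensorRT engine on first use and tracks weight changes so it can refit
# instead of recompiling when possible.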


def generate_image(prompt, inference_step, batch_size=2):
    start_time = time.time()
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=inference_step,
        num_images_per_prompt=batch_size,
    ).images
    end_time = time.time()
    return image, end_time - start_time
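

# Warm-up: the first call through the TensorRT module triggers the (slow)
# engine compilation; later interactive requests reuse the compiled engine.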
generate_image(["Test"], 2)
torch.cuda.empty_cache()
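

# Dropdown callback: selecting the eager Torch model moves the original
# backbone back onto the GPU; selecting the TensorRT variant offloads it to
# the CPU and frees the cached GPU memory it was using.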
def model_change(model):
    if model == "Torch Model":
        pipe.transformer = backbone
        backbone.to(DEVICE)
    else:
        backbone.to("cpu")
        pipe.transformer = trt_gm
        torch.cuda.empty_cache()
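

# LoRA loading: fuse the LoRA weights into the backbone, then run one test
# generation. Because the engine was built with immutable_weights=False, the
# changed weights trigger a TensorRT refit rather than a full recompile.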
def load_lora(path):
    pipe.load_lora_weights(
        path,
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()
    print("LoRA loaded! Begin refitting")
    generate_image(["Test"], 2)
    print("Refitting Finished!")


# Create the Gradio interface
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
    gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

    with gr.Row():
        with gr.Column():
            # Input components
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="Enter your prompt here...", lines=3
            )
            model_dropdown = gr.Dropdown(
                choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                value="Torch-TensorRT Accelerated Model",
                label="Model Variant",
            )
            lora_upload_path = gr.Textbox(
                label="LoRA Path",
                placeholder="Enter the LoRA checkpoint path here",
                value="/home/TensorRT/examples/apps/NGRVNG.safetensors",
                lines=2,
            )
            num_steps = gr.Slider(
                minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
            )
            batch_size = gr.Slider(
                minimum=1, maximum=8, value=1, step=1, label="Batch Size"
            )
            generate_btn = gr.Button("Generate Image")
            load_lora_btn = gr.Button("Load LoRA")

        with gr.Column():
            # Output components
            output_image = gr.Gallery(label="Generated Image")
            time_taken = gr.Textbox(
                label="Generation Time (seconds)", interactive=False
            )

    # Swap backbones when the dropdown selection changes
    model_dropdown.change(model_change, inputs=[model_dropdown])

    load_lora_btn.click(
        fn=load_lora,
        inputs=[
            lora_upload_path,
        ],
    )

    # Wire the generate button to the generation function, reporting the
    # elapsed time alongside the images
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            num_steps,
            batch_size,
        ],
        outputs=[output_image, time_taken],
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()
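
# To try the demo (assuming a CUDA GPU with enough memory for FLUX.1-dev and
# access to the gated Hugging Face checkpoint):
#   python flux-demo.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) to open in a browser.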