
Commit 6fdb9bb

Force clip sdpa decomposition
Fix turbine import. Revert change to prompt encoder filename. Update sdxl_prompt_encoder.py
1 parent b4ed496 commit 6fdb9bb
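
The headline change is that scaled dot-product attention is now always decomposed before export instead of being lowered as a fused SDPA op. A minimal sketch of the pattern the diff below uses, assuming the decompositions helper is the one provided by the turbine AOT tooling this file already imports (the exact import path is not shown in this commit):

    import torch
    # Assumed import path; the file's own turbine import is what actually supplies this helper.
    from iree.turbine.aot import decompositions

    # Decompose every SDPA variant so the exported program contains no fused attention op.
    decomp_list = [
        torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
        torch.ops.aten._scaled_dot_product_flash_attention.default,
        torch.ops.aten.scaled_dot_product_attention,
    ]

    with decompositions.extend_aot_decompositions(
        from_current=True,   # extend the currently active decomposition set (per the diff below)
        add_ops=decomp_list, # and force SDPA through it as well
    ):
        pass  # run the AOT export inside this context, as export_prompt_encoder does below

Note that the diff also hard-codes decomp_attn = True inside export_prompt_encoder, so the decomp_attn parameter is effectively forced on regardless of what the caller passes.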

File tree

1 file changed: +136 -127 lines changed


models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py

Lines changed: 136 additions & 127 deletions
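
One subtle piece of the new export path later in this diff is a buffer workaround: Transformers registers the CLIP position-ids buffer as non-persistent, weight externalization then mistakes it for a user input, and the export ends up operating on a !torch.None instead of a tensor. The commit fixes this by re-registering every buffer as persistent before externalizing. A self-contained sketch of that idea, using a hypothetical Demo module in place of the real prompt_encoder_module:

    import torch
    import torch.nn as nn

    def make_buffers_persistent(module: nn.Module) -> None:
        # Re-register each buffer on its owning submodule with persistent=True so it
        # stays in the state dict and is picked up by weight externalization.
        for name, buf in module.named_buffers(recurse=True):
            *parent_path, leaf = name.split(".")
            owner = module
            for attr in parent_path:
                owner = getattr(owner, attr)
            owner.register_buffer(leaf, buf, persistent=True)

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            # Mirrors what Transformers does for CLIP position ids (length 77 is illustrative).
            self.register_buffer(
                "position_ids", torch.arange(77).unsqueeze(0), persistent=False
            )

    demo = Demo()
    assert "position_ids" not in demo.state_dict()
    make_buffers_persistent(demo)
    assert "position_ids" in demo.state_dict()

Re-registering an existing buffer name is allowed by nn.Module.register_buffer, which is what makes this in-place flip possible; the diff applies the same loop to prompt_encoder_module.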
@@ -48,111 +48,106 @@ def __init__(
     def forward(
         self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
     ):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_1 = self.text_encoder_model_1(
-                uncond_input_ids_1,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_2 = self.text_encoder_model_2(
-                uncond_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
-            neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]
-
-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            neg_prompt_embeds_list = [
-                neg_prompt_embeds_1.hidden_states[-2],
-                neg_prompt_embeds_2.hidden_states[-2],
-            ]
-
-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-            neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)
-
-            bs_embed, seq_len, _ = prompt_embeds.shape
-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_1 = self.text_encoder_model_1(
+            uncond_input_ids_1,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_2 = self.text_encoder_model_2(
+            uncond_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]
+        neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        neg_prompt_embeds_list = [
+            neg_prompt_embeds_1.hidden_states[-2],
+            neg_prompt_embeds_2.hidden_states[-2],
+        ]
+
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
+        if not self.batch_input:
+            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        if not self.batch_input:
+            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        if self.do_classifier_free_guidance:
             if not self.batch_input:
-                prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view(
+                    1, -1
+                )
+            neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
+            neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
+            if not self.batch_input:
+                neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
+            prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
             if not self.batch_input:
-                add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
-            if self.do_classifier_free_guidance:
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        1, 1
-                    ).view(1, -1)
-                neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
-                neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
-                if not self.batch_input:
-                    neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
-                prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        self.batch_size, 1
-                    )
-                add_text_embeds = torch.cat(
-                    [neg_pooled_prompt_embeds, add_text_embeds], dim=0
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+                    self.batch_size, 1
                 )
-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+            add_text_embeds = torch.cat(
+                [neg_pooled_prompt_embeds, add_text_embeds], dim=0
+            )
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds
 
     def forward_turbo(self, text_input_ids_1, text_input_ids_2):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]
 
-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            # neg_prompt_embeds_list = [
-            #     torch.zeros_like(prompt_embeds_list[0]), # dummy tensor
-            #     torch.zeros_like(prompt_embeds_list[1]), # dummy tensor
-            # ]
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        # neg_prompt_embeds_list = [
+        #     torch.zeros_like(prompt_embeds_list[0]), # dummy tensor
+        #     torch.zeros_like(prompt_embeds_list[1]), # dummy tensor
+        # ]
 
-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
 
-            bs_embed, seq_len, _ = prompt_embeds.shape
+        bs_embed, seq_len, _ = prompt_embeds.shape
 
-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
-            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
-            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
+        prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
 
-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds
 
 
+@torch.no_grad()
 def export_prompt_encoder(
     hf_model_name,
     hf_auth_token=None,
@@ -171,7 +166,7 @@ def export_prompt_encoder(
     attn_spec=None,
     weights_only=False,
     batch_input=False,
-    decomp_attn=False,  # Compatibility
+    decomp_attn=True,
 ):
     do_classifier_free_guidance = True
 
@@ -233,49 +228,63 @@ def export_prompt_encoder(
     if weights_only:
         return None, external_weight_path
 
-    class CompiledClip(CompiledModule):
+    example_inputs = {
+        "text_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "text_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+    }
+    decomp_list = []
+    decomp_attn = True
+    if decomp_attn == True:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
         if external_weights:
-            params = export_parameters(
-                prompt_encoder_module,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(prompt_encoder_module)
-
-        def encode_prompts(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward)(
-                t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
-            )
-
-        def encode_prompts_turbo(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
-
-    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledClip(context=Context(), import_to=import_to)
-
-    module = CompiledModule.get_mlir_module(inst)
+            # Transformers (model source) registers position ids as non-persistent.
+            # This causes externalization to think it's a user input, and since it's not,
+            # we end up trying to do ops on a !torch.None instead of a tensor.
+            for buffer_name, buffer in prompt_encoder_module.named_buffers(
+                recurse=True
+            ):
+                mod_name_list = buffer_name.split(".")
+                buffer_id = mod_name_list.pop()
+                parent = prompt_encoder_module
+                for i in mod_name_list:
+                    parent = getattr(parent, i)
+                parent.register_buffer(buffer_id, buffer, persistent=True)
+            externalize_module_parameters(prompt_encoder_module)
+        output = export(
+            prompt_encoder_module,
+            kwargs=example_inputs,
+            module_name="compiled_clip",
+            function_name="encode_prompts",
+        )
+        module = output.mlir_module
 
     model_metadata_encode = {
         "model_name": hf_model_name + "_text_encoder",
         "input_shapes": [str((input_batchsize, max_length)) for i in range(4)],
         "input_dtypes": ["int64" for i in range(4)],
         "use_attention_mask": False,
     }
+
     module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run()
     module_str = str(module)
-
     if compile_to != "vmfb":
         return module_str
     else: