@@ -48,111 +48,106 @@ def __init__(
     def forward(
         self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2
     ):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_1 = self.text_encoder_model_1(
-                uncond_input_ids_1,
-                output_hidden_states=True,
-            )
-            neg_prompt_embeds_2 = self.text_encoder_model_2(
-                uncond_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
-            neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]
-
-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            neg_prompt_embeds_list = [
-                neg_prompt_embeds_1.hidden_states[-2],
-                neg_prompt_embeds_2.hidden_states[-2],
-            ]
-
-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-            neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)
-
-            bs_embed, seq_len, _ = prompt_embeds.shape
-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_1 = self.text_encoder_model_1(
+            uncond_input_ids_1,
+            output_hidden_states=True,
+        )
+        neg_prompt_embeds_2 = self.text_encoder_model_2(
+            uncond_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]
+        neg_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        neg_prompt_embeds_list = [
+            neg_prompt_embeds_1.hidden_states[-2],
+            neg_prompt_embeds_2.hidden_states[-2],
+        ]
+
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
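+        # When inputs are not pre-batched, tile the embeddings up to the configured batch size.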
+        if not self.batch_input:
+            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        if not self.batch_input:
+            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        if self.do_classifier_free_guidance:
             if not self.batch_input:
-                prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view(
+                    1, -1
+                )
+            neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
+            neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
+            if not self.batch_input:
+                neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
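+            # For classifier-free guidance, stack negative and positive embeddings along the batch dimension.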
+            prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
             if not self.batch_input:
-                add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
-            if self.do_classifier_free_guidance:
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        1, 1
-                    ).view(1, -1)
-                neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1)
-                neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1)
-                if not self.batch_input:
-                    neg_prompt_embeds = neg_prompt_embeds.repeat(self.batch_size, 1, 1)
-                prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0)
-                if not self.batch_input:
-                    neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
-                        self.batch_size, 1
-                    )
-                add_text_embeds = torch.cat(
-                    [neg_pooled_prompt_embeds, add_text_embeds], dim=0
+                neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(
+                    self.batch_size, 1
                 )
-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+            add_text_embeds = torch.cat(
+                [neg_pooled_prompt_embeds, add_text_embeds], dim=0
+            )
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds
 
     def forward_turbo(self, text_input_ids_1, text_input_ids_2):
-        with torch.no_grad():
-            prompt_embeds_1 = self.text_encoder_model_1(
-                text_input_ids_1,
-                output_hidden_states=True,
-            )
-            prompt_embeds_2 = self.text_encoder_model_2(
-                text_input_ids_2,
-                output_hidden_states=True,
-            )
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds_2[0]
+        prompt_embeds_1 = self.text_encoder_model_1(
+            text_input_ids_1,
+            output_hidden_states=True,
+        )
+        prompt_embeds_2 = self.text_encoder_model_2(
+            text_input_ids_2,
+            output_hidden_states=True,
+        )
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds_2[0]
 
-            prompt_embeds_list = [
-                prompt_embeds_1.hidden_states[-2],
-                prompt_embeds_2.hidden_states[-2],
-            ]
-            # neg_prompt_embeds_list = [
-            #     torch.zeros_like(prompt_embeds_list[0]), # dummy tensor
-            #     torch.zeros_like(prompt_embeds_list[1]), # dummy tensor
-            # ]
+        prompt_embeds_list = [
+            prompt_embeds_1.hidden_states[-2],
+            prompt_embeds_2.hidden_states[-2],
+        ]
+        # neg_prompt_embeds_list = [
+        #     torch.zeros_like(prompt_embeds_list[0]), # dummy tensor
+        #     torch.zeros_like(prompt_embeds_list[1]), # dummy tensor
+        # ]
 
-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
 
-            bs_embed, seq_len, _ = prompt_embeds.shape
+        bs_embed, seq_len, _ = prompt_embeds.shape
 
-            prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-            prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(
-                bs_embed * 1, -1
-            )
-            prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
-            add_text_embeds = pooled_prompt_embeds
-            add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
+        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
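+        # The turbo path uses no negative prompt, so the embeddings are always tiled to batch size.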
+        prompt_embeds = prompt_embeds.repeat(self.batch_size, 1, 1)
+        add_text_embeds = pooled_prompt_embeds
+        add_text_embeds = add_text_embeds.repeat(self.batch_size, 1)
 
-            add_text_embeds = add_text_embeds.to(self.torch_dtype)
-            prompt_embeds = prompt_embeds.to(self.torch_dtype)
-            return prompt_embeds, add_text_embeds
+        add_text_embeds = add_text_embeds.to(self.torch_dtype)
+        prompt_embeds = prompt_embeds.to(self.torch_dtype)
+        return prompt_embeds, add_text_embeds
 
 
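+# Gradients are never needed for export, so autograd is disabled for the whole function.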
+@torch.no_grad()
 def export_prompt_encoder(
     hf_model_name,
     hf_auth_token=None,
@@ -171,7 +166,7 @@ def export_prompt_encoder(
     attn_spec=None,
     weights_only=False,
     batch_input=False,
-    decomp_attn=False,  # Compatibility
+    decomp_attn=True,
 ):
     do_classifier_free_guidance = True
 
@@ -233,49 +228,63 @@ def export_prompt_encoder(
     if weights_only:
         return None, external_weight_path
 
-    class CompiledClip(CompiledModule):
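+    # Placeholder token-id tensors that pin the traced input shapes:
+    # (batch, max_length) int64 ids for both tokenizers and their negatives.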
+    example_inputs = {
+        "text_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "text_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_1": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+        "uncond_input_ids_2": torch.empty(
+            (input_batchsize, max_length), dtype=torch.int64
+        ),
+    }
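+    # Decompose scaled dot-product attention ops so the exported IR does not
+    # rely on fused SDPA kernels (note: decomp_attn is forced on here).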
+    decomp_list = []
+    decomp_attn = True
+    if decomp_attn == True:
+        decomp_list = [
+            torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+            torch.ops.aten._scaled_dot_product_flash_attention.default,
+            torch.ops.aten.scaled_dot_product_attention,
+        ]
+    with decompositions.extend_aot_decompositions(
+        from_current=True,
+        add_ops=decomp_list,
+    ):
         if external_weights:
-            params = export_parameters(
-                prompt_encoder_module,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(prompt_encoder_module)
-
-        def encode_prompts(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            uc_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward)(
-                t_ids_1, t_ids_2, uc_ids_1, uc_ids_2
-            )
-
-        def encode_prompts_turbo(
-            self,
-            t_ids_1=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-            t_ids_2=AbstractTensor(input_batchsize, max_length, dtype=torch.int64),
-        ):
-            return jittable(prompt_encoder_module.forward_turbo)(t_ids_1, t_ids_2)
-
-    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledClip(context=Context(), import_to=import_to)
-
-    module = CompiledModule.get_mlir_module(inst)
+            # Transformers (model source) registers position ids as non-persistent.
+            # This causes externalization to think it's a user input, and since it's not,
+            # we end up trying to do ops on a !torch.None instead of a tensor.
+            for buffer_name, buffer in prompt_encoder_module.named_buffers(
+                recurse=True
+            ):
+                mod_name_list = buffer_name.split(".")
+                buffer_id = mod_name_list.pop()
+                parent = prompt_encoder_module
+                for i in mod_name_list:
+                    parent = getattr(parent, i)
+                parent.register_buffer(buffer_id, buffer, persistent=True)
+            externalize_module_parameters(prompt_encoder_module)
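+        # Trace the module's forward() with the example inputs and import it as MLIR.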
+        output = export(
+            prompt_encoder_module,
+            kwargs=example_inputs,
+            module_name="compiled_clip",
+            function_name="encode_prompts",
+        )
+        module = output.mlir_module
 
     model_metadata_encode = {
         "model_name": hf_model_name + "_text_encoder",
         "input_shapes": [str((input_batchsize, max_length)) for i in range(4)],
         "input_dtypes": ["int64" for i in range(4)],
         "use_attention_mask": False,
     }
+
     module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run()
     module_str = str(module)
-
     if compile_to != "vmfb":
         return module_str
     else: