@@ -173,15 +173,23 @@ def __init__(self, mimi: nn.Module):
173
173
self .mimi_model = mimi
174
174
175
175
def forward(self, x):
    """Run the wrapped Mimi model's decode path step by step.

    Instead of a single ``self.mimi_model.decode(x)`` call, the stages are
    invoked explicitly: swap dims 1 and 2 of the input, upsample, run the
    decoder transformer, then run the decoder inside the model's
    ``_context_for_encoder_decoder`` context.

    Args:
        x: Input tensor with at least three dimensions; dims 1 and 2 are
            swapped before upsampling.
            # presumably (batch, time, channels) -- TODO confirm with caller

    Returns:
        Whatever ``self.mimi_model.decoder`` returns for the transformer
        output.
    """
    x = x.transpose(1, 2)
    x = self.mimi_model.upsample(x)
    # The transformer output is unpacked as a single-element sequence.
    (emb,) = self.mimi_model.decoder_transformer(x)
    # NOTE(review): an earlier revision had a bare `emb.transpose(1, 2)`
    # here. Tensor.transpose returns a new tensor and does not modify
    # `emb`, so that statement had no effect and the decoder always
    # received the transformer output unchanged. It is removed to keep
    # that behaviour explicit; if a layout swap really is required, it
    # must be assigned (`emb = emb.transpose(1, 2)`) -- confirm against
    # the layout the decoder expects.
    with self.mimi_model._context_for_encoder_decoder:
        return self.mimi_model.decoder(emb)
177
183
178
- sample_pcm = torch .tensor (self .sample_pcm , device = self .device )[None ]
179
- pcm_chunk_size = int (self .mimi .sample_rate / self .mimi .frame_rate )
180
- chunk = sample_pcm [..., 0 :pcm_chunk_size ]
181
- input = self .mimi .encode (chunk )
184
+ emb_input = torch .rand (1 , 1 , 512 , device = "cpu" )
182
185
183
186
mimi_decode = MimiDecode (self .mimi )
184
- exported_decode : ExportedProgram = export (mimi_decode , (input ,), strict = False )
187
+ mimi_decode .eval ()
188
+ mimi_decode (emb_input )
189
+
190
+ exported_decode : ExportedProgram = export (
191
+ mimi_decode , (emb_input ,), strict = False
192
+ )
185
193
quantization_config = get_symmetric_quantization_config (
186
194
is_per_channel = True ,
187
195
is_dynamic = True ,
@@ -190,12 +198,12 @@ def forward(self, x):
190
198
quantizer .set_global (quantization_config )
191
199
m = exported_decode .module ()
192
200
m = prepare_pt2e (m , quantizer )
193
- m (input )
201
+ m (emb_input )
194
202
m = convert_pt2e (m )
195
203
print ("quantized graph:" )
196
204
print (m .graph )
197
205
# Export quantized module
198
- exported_decode : ExportedProgram = export (m , (input ,), strict = False )
206
+ exported_decode : ExportedProgram = export (m , (emb_input ,), strict = False )
199
207
# Lower
200
208
edge_manager = to_edge_transform_and_lower (
201
209
exported_decode ,
@@ -208,16 +216,16 @@ def forward(self, x):
208
216
with open (output_file , "wb" ) as file :
209
217
exec_prog .write_to_file (file )
210
218
211
- eager_res = mimi_decode (input )
219
+ eager_res = mimi_decode (emb_input )
212
220
runtime = Runtime .get ()
213
221
program = runtime .load_program (output_file )
214
222
method = program .load_method ("forward" )
215
- flattened_x = tree_flatten (input )[0 ]
223
+ flattened_x = tree_flatten (emb_input )[0 ]
216
224
res = method .execute (flattened_x )
217
225
# Compare results
218
226
sqnr = compute_sqnr (eager_res , res [0 ])
219
227
print (f"SQNR: { sqnr } " )
220
- torch .testing .assert_close (eager_res , res [0 ], atol = 1e -3 , rtol = 1e-3 )
228
+ torch .testing .assert_close (eager_res , res [0 ], atol = 4e -3 , rtol = 1e-3 )
221
229
222
230
223
231
if __name__ == "__main__" :
0 commit comments