Skip to content

Commit 2aa7748

Browse files
authored
Update mimi export test
Differential Revision: D72091091 Pull Request resolved: #9755
1 parent eda319e commit 2aa7748

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

examples/models/moshi/mimi/test_mimi.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,23 @@ def __init__(self, mimi: nn.Module):
173173
self.mimi_model = mimi
174174

175175
def forward(self, x):
176-
return self.mimi_model.decode(x)
176+
x = x.transpose(1, 2)
177+
x = self.mimi_model.upsample(x)
178+
(emb,) = self.mimi_model.decoder_transformer(x)
179+
emb.transpose(1, 2)
180+
with self.mimi_model._context_for_encoder_decoder:
181+
out = self.mimi_model.decoder(emb)
182+
return out
177183

178-
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
179-
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
180-
chunk = sample_pcm[..., 0:pcm_chunk_size]
181-
input = self.mimi.encode(chunk)
184+
emb_input = torch.rand(1, 1, 512, device="cpu")
182185

183186
mimi_decode = MimiDecode(self.mimi)
184-
exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False)
187+
mimi_decode.eval()
188+
mimi_decode(emb_input)
189+
190+
exported_decode: ExportedProgram = export(
191+
mimi_decode, (emb_input,), strict=False
192+
)
185193
quantization_config = get_symmetric_quantization_config(
186194
is_per_channel=True,
187195
is_dynamic=True,
@@ -190,12 +198,12 @@ def forward(self, x):
190198
quantizer.set_global(quantization_config)
191199
m = exported_decode.module()
192200
m = prepare_pt2e(m, quantizer)
193-
m(input)
201+
m(emb_input)
194202
m = convert_pt2e(m)
195203
print("quantized graph:")
196204
print(m.graph)
197205
# Export quantized module
198-
exported_decode: ExportedProgram = export(m, (input,), strict=False)
206+
exported_decode: ExportedProgram = export(m, (emb_input,), strict=False)
199207
# Lower
200208
edge_manager = to_edge_transform_and_lower(
201209
exported_decode,
@@ -208,16 +216,16 @@ def forward(self, x):
208216
with open(output_file, "wb") as file:
209217
exec_prog.write_to_file(file)
210218

211-
eager_res = mimi_decode(input)
219+
eager_res = mimi_decode(emb_input)
212220
runtime = Runtime.get()
213221
program = runtime.load_program(output_file)
214222
method = program.load_method("forward")
215-
flattened_x = tree_flatten(input)[0]
223+
flattened_x = tree_flatten(emb_input)[0]
216224
res = method.execute(flattened_x)
217225
# Compare results
218226
sqnr = compute_sqnr(eager_res, res[0])
219227
print(f"SQNR: {sqnr}")
220-
torch.testing.assert_close(eager_res, res[0], atol=1e-3, rtol=1e-3)
228+
torch.testing.assert_close(eager_res, res[0], atol=4e-3, rtol=1e-3)
221229

222230

223231
if __name__ == "__main__":

0 commit comments

Comments
 (0)