@@ -271,34 +271,34 @@ def test_dynamic_shape_ebc(self) -> None:
271271 # Serialize EBC
272272 collection = mark_dynamic_kjt (feature1 )
273273 model , sparse_fqns = encapsulate_ir_modules (model , JsonSerializer )
274- ep = torch .export .export (
275- model ,
276- (feature1 ,),
277- {},
278- dynamic_shapes = collection .dynamic_shapes (model , (feature1 ,)),
279- strict = False ,
280- # Allows KJT to not be unflattened and run a forward on unflattened EP
281- preserve_module_call_signature = tuple (sparse_fqns ),
282- )
283-
284- # Run forward on ExportedProgram
285- ep_output = ep .module ()(feature2 )
286-
287- # other asserts
288- for i , tensor in enumerate (ep_output ):
289- self .assertEqual (eager_out [i ].shape , tensor .shape )
290-
291- # Deserialize EBC
292- unflatten_ep = torch .export .unflatten (ep )
293- deserialized_model = decapsulate_ir_modules (unflatten_ep , JsonSerializer )
294- deserialized_model .load_state_dict (model .state_dict ())
295-
296- # Run forward on deserialized model
297- deserialized_out = deserialized_model (feature2 )
298-
299- for i , tensor in enumerate (deserialized_out ):
300- self .assertEqual (eager_out [i ].shape , tensor .shape )
301- assert torch .allclose (eager_out [i ], tensor )
274+ # ep = torch.export.export(
275+ # model,
276+ # (feature1,),
277+ # {},
278+ # dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
279+ # strict=False,
280+ # # Allows KJT to not be unflattened and run a forward on unflattened EP
281+ # preserve_module_call_signature=tuple(sparse_fqns),
282+ # )
283+
284+ # # Run forward on ExportedProgram
285+ # ep_output = ep.module()(feature2)
286+
287+ # # other asserts
288+ # for i, tensor in enumerate(ep_output):
289+ # self.assertEqual(eager_out[i].shape, tensor.shape)
290+
291+ # # Deserialize EBC
292+ # unflatten_ep = torch.export.unflatten(ep)
293+ # deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
294+ # deserialized_model.load_state_dict(model.state_dict())
295+
296+ # # Run forward on deserialized model
297+ # deserialized_out = deserialized_model(feature2)
298+
299+ # for i, tensor in enumerate(deserialized_out):
300+ # self.assertEqual(eager_out[i].shape, tensor.shape)
301+ # assert torch.allclose(eager_out[i], tensor)
302302
303303 def test_ir_emb_lookup_device (self ) -> None :
304304 model = self .generate_model ()
0 commit comments