@@ -179,10 +179,37 @@ def get(self) -> int:
179179 def set (self , val ):
180180 self .counter_ = val
181181
182+ @torch ._library .register_fake_class ("fbgemm::TensorQueue" )
183+ class FakeTensorQueue :
184+ def __init__ (self , queue , init_tensor ):
185+ self .queue = queue
186+ self .init_tensor = init_tensor
187+
188+ @classmethod
189+ def __obj_unflatten__ (cls , flattened_ctx ):
190+ return cls (** dict (flattened_ctx ))
191+
192+ def push (self , x ):
193+ self .queue .append (x )
194+
195+ def pop (self ):
196+ if len (self .queue ) == 0 :
197+ return self .init_tensor
198+ return self .queue .pop (0 )
199+
200+ def top (self ):
201+ if len (self .queue ) == 0 :
202+ return self .init_tensor
203+ return self .queue [0 ]
204+
205+ def size (self ):
206+ return len (self .queue )
207+
182208 def tearDown (self ):
183209 torch ._library .fake_class_registry .deregister_fake_class (
184210 "fbgemm::AtomicCounter"
185211 )
212+ torch ._library .fake_class_registry .deregister_fake_class ("fbgemm::TensorQueue" )
186213 super ().tearDown ()
187214
188215 def _test_kjt_input_module (
@@ -517,7 +544,7 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
517544 {},
518545 strict = False ,
519546 pre_dispatch = True ,
520- ). run_decompositions ()
547+ )
521548
522549 ep .module ()(kjt .values (), kjt .lengths ())
523550
@@ -556,16 +583,17 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:
556583 {},
557584 strict = False ,
558585 pre_dispatch = True ,
559- ). run_decompositions ()
586+ )
560587 ep .module ()(kjt .values (), kjt .lengths ())
561588
562589 # PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
563590 # ensure such node isn't present, as it causes issues with IR
564591 for n in ep .graph_module .graph .nodes :
565592 self .assertFalse ("auto_functionalized" in str (n .name ))
566593
567- # TODO: Fix Unflatten
568- # torch.export.unflatten(ep)
594+ torch .export .unflatten (ep )
595+
596+ ep (kjt .values (), kjt .lengths ())
569597
570598 def test_maybe_compute_kjt_to_jt_dict (self ) -> None :
571599 kjt : KeyedJaggedTensor = make_kjt ([2 , 3 , 4 , 5 , 6 ], [1 , 2 , 1 , 1 ])
0 commit comments