@@ -179,10 +179,37 @@ def get(self) -> int:
179
179
def set (self , val ):
180
180
self .counter_ = val
181
181
182
@torch._library.register_fake_class("fbgemm::TensorQueue")
class FakeTensorQueue:
    """Fake (meta) counterpart of the ``fbgemm::TensorQueue`` torchbind class.

    Registered so ``torch.export`` can trace programs that hold a
    TensorQueue script object without touching the real C++ implementation.
    The queue is modeled as a plain Python list of tensors; ``init_tensor``
    is the value handed back when the queue is empty.
    """

    def __init__(self, queue, init_tensor):
        # queue: list of tensors mirroring the real queue's contents
        # init_tensor: fallback tensor returned by pop()/top() when empty
        self.queue = queue
        self.init_tensor = init_tensor

    @classmethod
    def __obj_unflatten__(cls, flattened_ctx):
        # flattened_ctx is an iterable of (attribute_name, value) pairs
        # produced by the real object's __obj_flatten__; rebuild from it.
        return cls(**dict(flattened_ctx))

    def push(self, x):
        # Enqueue at the tail.
        self.queue.append(x)

    def pop(self):
        # Dequeue from the head; empty queue yields the sentinel tensor.
        return self.queue.pop(0) if self.queue else self.init_tensor

    def top(self):
        # Peek at the head without removing; sentinel tensor when empty.
        return self.queue[0] if self.queue else self.init_tensor

    def size(self):
        return len(self.queue)
182
208
def tearDown(self):
    """Remove the fake-class registrations made for this test so the
    global torch._library fake-class registry is clean for later tests."""
    for qualname in ("fbgemm::AtomicCounter", "fbgemm::TensorQueue"):
        torch._library.fake_class_registry.deregister_fake_class(qualname)
    super().tearDown()
187
214
188
215
def _test_kjt_input_module (
@@ -517,7 +544,7 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
517
544
{},
518
545
strict = False ,
519
546
pre_dispatch = True ,
520
- ). run_decompositions ()
547
+ )
521
548
522
549
ep .module ()(kjt .values (), kjt .lengths ())
523
550
@@ -556,16 +583,17 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:
556
583
{},
557
584
strict = False ,
558
585
pre_dispatch = True ,
559
- ). run_decompositions ()
586
+ )
560
587
ep .module ()(kjt .values (), kjt .lengths ())
561
588
562
589
# PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
563
590
# ensure such node isn't present, as it causes issues with IR
564
591
for n in ep .graph_module .graph .nodes :
565
592
self .assertFalse ("auto_functionalized" in str (n .name ))
566
593
567
- # TODO: Fix Unflatten
568
- # torch.export.unflatten(ep)
594
+ torch .export .unflatten (ep )
595
+
596
+ ep (kjt .values (), kjt .lengths ())
569
597
570
598
def test_maybe_compute_kjt_to_jt_dict (self ) -> None :
571
599
kjt : KeyedJaggedTensor = make_kjt ([2 , 3 , 4 , 5 , 6 ], [1 , 2 , 1 , 1 ])
0 commit comments