Skip to content

Commit 8afcff3

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
OSS fix train_pipelines utils test (#2114)
Summary: Pull Request resolved: #2114 Reviewed By: dstaay-fb Differential Revision: D58594592 fbshipit-source-id: aeeeee09d14d9558f53d52ff68b5b00cfcf8565f
1 parent 6337072 commit 8afcff3

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

Diff for: torchrec/distributed/tests/test_pt2.py

+32-4
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,37 @@ def get(self) -> int:
179179
def set(self, val):
180180
self.counter_ = val
181181

182+
@torch._library.register_fake_class("fbgemm::TensorQueue")
183+
class FakeTensorQueue:
184+
def __init__(self, queue, init_tensor):
185+
self.queue = queue
186+
self.init_tensor = init_tensor
187+
188+
@classmethod
189+
def __obj_unflatten__(cls, flattened_ctx):
190+
return cls(**dict(flattened_ctx))
191+
192+
def push(self, x):
193+
self.queue.append(x)
194+
195+
def pop(self):
196+
if len(self.queue) == 0:
197+
return self.init_tensor
198+
return self.queue.pop(0)
199+
200+
def top(self):
201+
if len(self.queue) == 0:
202+
return self.init_tensor
203+
return self.queue[0]
204+
205+
def size(self):
206+
return len(self.queue)
207+
182208
def tearDown(self):
183209
torch._library.fake_class_registry.deregister_fake_class(
184210
"fbgemm::AtomicCounter"
185211
)
212+
torch._library.fake_class_registry.deregister_fake_class("fbgemm::TensorQueue")
186213
super().tearDown()
187214

188215
def _test_kjt_input_module(
@@ -517,7 +544,7 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
517544
{},
518545
strict=False,
519546
pre_dispatch=True,
520-
).run_decompositions()
547+
)
521548

522549
ep.module()(kjt.values(), kjt.lengths())
523550

@@ -556,16 +583,17 @@ def test_sharded_quant_fpebc_non_strict_export(self) -> None:
556583
{},
557584
strict=False,
558585
pre_dispatch=True,
559-
).run_decompositions()
586+
)
560587
ep.module()(kjt.values(), kjt.lengths())
561588

562589
# PT2 IR autofunctionalizes mutation funcs (bounds_check_indices)
563590
# ensure such node isn't present, as it causes issues with IR
564591
for n in ep.graph_module.graph.nodes:
565592
self.assertFalse("auto_functionalized" in str(n.name))
566593

567-
# TODO: Fix Unflatten
568-
# torch.export.unflatten(ep)
594+
torch.export.unflatten(ep)
595+
596+
ep(kjt.values(), kjt.lengths())
569597

570598
def test_maybe_compute_kjt_to_jt_dict(self) -> None:
571599
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

Diff for: torchrec/distributed/train_pipeline/tests/test_utils.py renamed to torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
1616

1717

18-
class TestUtils(unittest.TestCase):
18+
class TestTrainPipelineUtils(unittest.TestCase):
1919
def test_get_node_args_helper_call_module_kjt(self) -> None:
2020
graph = torch.fx.Graph()
2121
kjt_args = []

0 commit comments

Comments
 (0)