Skip to content

Commit d36fdec

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
Re-land D66465376 (#2637)
Summary: Pull Request resolved: #2637 Re-land diff D66465376 NOTE: use jit.ignore on the forward function to get rid of jit script error with `TensorDict` ``` def test_td_scripting(self) -> None: class TestModule(torch.nn.Module): torch.jit.ignore # <----- test fails without this ignore def forward(self, x: Union[TensorDict, KeyedJaggedTensor]) -> torch.Tensor: if isinstance(x, TensorDict): keys = list(x.keys()) return torch.cat([x[key]._values for key in keys], dim=0) else: return x._values m = TestModule() gm = torch.fx.symbolic_trace(m) jm = torch.jit.script(gm) values = torch.tensor([0, 1, 2, 3, 2, 3, 4]) kjt = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2", "f3"], values=values, offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]), ) torch.testing.assert_allclose(jm(kjt), values) ``` Reviewed By: dstaay-fb Differential Revision: D66460392 fbshipit-source-id: 6fe35ebf2d1ebbac11b7cbba5efda1af1026028e
1 parent 6f4bfe2 commit d36fdec

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

torchrec/distributed/embedding.py

+8
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@
9797
except OSError:
9898
pass
9999

100+
try:
101+
from tensordict import TensorDict
102+
except ImportError:
103+
104+
class TensorDict:
105+
pass
106+
107+
100108
logger: logging.Logger = logging.getLogger(__name__)
101109

102110

torchrec/distributed/embeddingbag.py

+7
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@
102102
except OSError:
103103
pass
104104

105+
try:
106+
from tensordict import TensorDict
107+
except ImportError:
108+
109+
class TensorDict:
110+
pass
111+
105112

106113
def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
107114
return (

torchrec/modules/embedding_modules.py

+8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
2222

2323

24+
try:
25+
from tensordict import TensorDict
26+
except ImportError:
27+
28+
class TensorDict:
29+
pass
30+
31+
2432
@torch.fx.wrap
2533
def reorder_inverse_indices(
2634
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],

torchrec/sparse/jagged_tensor.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@
4949

5050
# OSS
5151
try:
52-
pass
52+
from tensordict import TensorDict
5353
except ImportError:
54-
pass
54+
55+
class TensorDict:
56+
pass
57+
5558

5659
logger: logging.Logger = logging.getLogger()
5760

0 commit comments

Comments
 (0)