Skip to content

Commit 7be5368

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
Re-land D66465376 (#2637)
Summary: Re-land diff D66465376 NOTE: use jit.ignore on the forward function to get rid of jit script error with `TensorDict` ``` def test_td_scripting(self) -> None: class TestModule(torch.nn.Module): torch.jit.ignore # <----- test fails without this ignore def forward(self, x: Union[TensorDict, KeyedJaggedTensor]) -> torch.Tensor: if isinstance(x, TensorDict): keys = list(x.keys()) return torch.cat([x[key]._values for key in keys], dim=0) else: return x._values m = TestModule() gm = torch.fx.symbolic_trace(m) jm = torch.jit.script(gm) values = torch.tensor([0, 1, 2, 3, 2, 3, 4]) kjt = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2", "f3"], values=values, offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]), ) torch.testing.assert_allclose(jm(kjt), values) ``` Reviewed By: dstaay-fb Differential Revision: D66460392
1 parent 00d8ed2 commit 7be5368

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

torchrec/distributed/embedding.py

+8
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@
9797
except OSError:
9898
pass
9999

100+
try:
101+
from tensordict import TensorDict
102+
except ImportError:
103+
104+
class TensorDict:
105+
pass
106+
107+
100108
logger: logging.Logger = logging.getLogger(__name__)
101109

102110

torchrec/distributed/embeddingbag.py

+7
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@
100100
except OSError:
101101
pass
102102

103+
try:
104+
from tensordict import TensorDict
105+
except ImportError:
106+
107+
class TensorDict:
108+
pass
109+
103110

104111
def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
105112
return (

torchrec/modules/embedding_modules.py

+8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
2222

2323

24+
try:
25+
from tensordict import TensorDict
26+
except ImportError:
27+
28+
class TensorDict:
29+
pass
30+
31+
2432
@torch.fx.wrap
2533
def reorder_inverse_indices(
2634
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],

torchrec/sparse/jagged_tensor.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@
4949

5050
# OSS
5151
try:
52-
pass
52+
from tensordict import TensorDict
5353
except ImportError:
54-
pass
54+
55+
class TensorDict:
56+
pass
57+
5558

5659
logger: logging.Logger = logging.getLogger()
5760

0 commit comments

Comments
 (0)