Skip to content

Commit 1414d14

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
Re-land D66465376 (#2637)
Summary: Re-land diff D66465376 NOTE: use jit.ignore on the forward function to get rid of jit script error with `TensorDict` ``` def test_td_scripting(self) -> None: class TestModule(torch.nn.Module): torch.jit.ignore # <----- test fails without this ignore def forward(self, x: Union[TensorDict, KeyedJaggedTensor]) -> torch.Tensor: if isinstance(x, TensorDict): keys = list(x.keys()) return torch.cat([x[key]._values for key in keys], dim=0) else: return x._values m = TestModule() gm = torch.fx.symbolic_trace(m) jm = torch.jit.script(gm) values = torch.tensor([0, 1, 2, 3, 2, 3, 4]) kjt = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2", "f3"], values=values, offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]), ) torch.testing.assert_allclose(jm(kjt), values) ``` Differential Revision: D66460392
1 parent f059a49 commit 1414d14

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

Diff for: torchrec/distributed/embedding.py

+8
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@
9797
except OSError:
9898
pass
9999

100+
try:
101+
from tensordict import TensorDict
102+
except ImportError:
103+
104+
class TensorDict:
105+
pass
106+
107+
100108
logger: logging.Logger = logging.getLogger(__name__)
101109

102110

Diff for: torchrec/distributed/embeddingbag.py

+7
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,13 @@
9898
except OSError:
9999
pass
100100

101+
try:
102+
from tensordict import TensorDict
103+
except ImportError:
104+
105+
class TensorDict:
106+
pass
107+
101108

102109
def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
103110
return (

Diff for: torchrec/modules/embedding_modules.py

+8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
2222

2323

24+
try:
25+
from tensordict import TensorDict
26+
except ImportError:
27+
28+
class TensorDict:
29+
pass
30+
31+
2432
@torch.fx.wrap
2533
def reorder_inverse_indices(
2634
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],

Diff for: torchrec/sparse/jagged_tensor.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@
4949

5050
# OSS
5151
try:
52-
pass
52+
from tensordict import TensorDict
5353
except ImportError:
54-
pass
54+
55+
class TensorDict:
56+
pass
57+
5558

5659
logger: logging.Logger = logging.getLogger()
5760

0 commit comments

Comments
 (0)