From 00478cdef0e32a8bfb8ea5d1574f570ddc722c0e Mon Sep 17 00:00:00 2001
From: Paul Zhang
Date: Wed, 4 Dec 2024 09:54:20 -0800
Subject: [PATCH] Use CPU inputs for PT2 inference full model export

Summary:
Enable TorchRec + PT2 inference full model export, switching the path over
to use CPU inputs + SplitDispatchMode instead of meta inputs, which run into
issues further down the stack.

Differential Revision: D62468529
---
 torchrec/sparse/jagged_tensor.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index fa28309e3..8468c9977 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1175,7 +1175,12 @@ def _maybe_compute_length_per_key(
     values: Optional[torch.Tensor],
 ) -> List[int]:
     if length_per_key is None:
-        if len(keys) and values is not None and values.is_meta:
+        if (
+            len(keys)
+            and values is not None
+            and values.is_meta
+            and not is_non_strict_exporting()
+        ):
             # create dummy lengths per key when on meta device
             total_length = values.numel()
             _length = [total_length // len(keys)] * len(keys)
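
Note for reviewers: a minimal self-contained sketch of the guarded behavior
follows. It is illustrative only; the is_non_strict_exporting stub below
stands in for TorchRec's real helper (which reports whether PT2 non-strict
export is tracing), and the function signature is simplified relative to the
actual _maybe_compute_length_per_key.

from typing import List, Optional

import torch


def is_non_strict_exporting() -> bool:
    # Placeholder (assumption) for TorchRec's helper; the real check
    # reports whether PT2 non-strict export is currently tracing.
    return False


def _maybe_compute_length_per_key(
    keys: List[str],
    length_per_key: Optional[List[int]],
    lengths: Optional[torch.Tensor],
    values: Optional[torch.Tensor],
) -> List[int]:
    # Simplified sketch of the patched logic, not the full TorchRec function.
    if length_per_key is None:
        if (
            len(keys)
            and values is not None
            and values.is_meta
            and not is_non_strict_exporting()
        ):
            # Meta tensors carry no data, so fabricate dummy lengths by
            # splitting the total value count evenly across keys.
            total_length = values.numel()
            _length = [total_length // len(keys)] * len(keys)
            _length[0] += total_length % len(keys)  # absorb the remainder
            return _length
        # With real (CPU) inputs, the path this patch switches export to,
        # derive per-key lengths from the actual lengths tensor.
        assert lengths is not None
        stride = lengths.numel() // len(keys)
        return torch.sum(lengths.view(len(keys), stride), dim=1).tolist()
    return length_per_key


# With CPU inputs (as under the new export path), the meta shortcut is
# skipped and lengths come from real data:
lengths = torch.tensor([2, 0, 1, 3])
print(_maybe_compute_length_per_key(["f1", "f2"], None, lengths, torch.arange(6)))
# -> [2, 4]

# With meta inputs outside of export, the dummy-length shortcut still fires:
meta_values = torch.empty(6, device="meta")
print(_maybe_compute_length_per_key(["f1", "f2"], None, None, meta_values))
# -> [3, 3]

The new is_non_strict_exporting() guard only disables the meta-device
shortcut during non-strict export, so eager meta-device users keep the
dummy-length behavior unchanged.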