From a0a458bca8669d9235b3b0754d9417ed1aa3a802 Mon Sep 17 00:00:00 2001 From: James Dong Date: Tue, 8 Apr 2025 11:28:46 -0700 Subject: [PATCH] Refactor stride_per_key_per_rank to support torch.Tensor Summary: `stride_per_key_per_rank` should be a variable whose value is determined dynamically after input_dist. This change updates its type to `Union[Optional[torch.Tensor], Optional[List[List[int]]]]` so that it remains backward compatible. Differential Revision: D72658640 --- torchrec/sparse/jagged_tensor.py | 39 +++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 2bbe09149..6388cbe33 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -548,7 +548,7 @@ def _kjt_concat( lengths=torch.cat(length_list, dim=0), stride=stride, stride_per_key_per_rank=( - stride_per_key_per_rank if variable_stride_per_key else None + torch.tensor(stride_per_key_per_rank) if variable_stride_per_key else None ), length_per_key=length_per_key if has_length_per_key else None, inverse_indices=( @@ -1096,7 +1096,7 @@ def _maybe_compute_stride_kjt( stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], - stride_per_key_per_rank: Optional[List[List[int]]], + stride_per_key_per_rank: Union[Optional[torch.Tensor], Optional[List[List[int]]]], ) -> int: if stride is None: if len(keys) == 0: @@ -1668,7 +1668,7 @@ def _maybe_compute_lengths_offset_per_key( def _maybe_compute_stride_per_key( stride_per_key: Optional[List[int]], - stride_per_key_per_rank: Optional[List[List[int]]], + stride_per_key_per_rank: Union[Optional[torch.Tensor], Optional[List[List[int]]]], stride: Optional[int], keys: List[str], ) -> Optional[List[int]]: @@ -1684,7 +1684,7 @@ def _maybe_compute_stride_per_key( def _maybe_compute_variable_stride_per_key( variable_stride_per_key: Optional[bool], - stride_per_key_per_rank: Optional[List[List[int]]], + stride_per_key_per_rank: 
Union[Optional[torch.Tensor], Optional[List[List[int]]]], ) -> bool: if variable_stride_per_key is not None: return variable_stride_per_key @@ -1766,7 +1766,9 @@ def __init__( lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, stride: Optional[int] = None, - stride_per_key_per_rank: Optional[List[List[int]]] = None, + stride_per_key_per_rank: Union[ + Optional[torch.Tensor], Optional[List[List[int]]] + ] = None, # Below exposed to ensure torch.script-able stride_per_key: Optional[List[int]] = None, length_per_key: Optional[List[int]] = None, @@ -1788,9 +1790,9 @@ def __init__( self._lengths: Optional[torch.Tensor] = lengths self._offsets: Optional[torch.Tensor] = offsets self._stride: Optional[int] = stride - self._stride_per_key_per_rank: Optional[List[List[int]]] = ( - stride_per_key_per_rank - ) + self._stride_per_key_per_rank: Union[ + Optional[torch.Tensor], Optional[List[List[int]]] + ] = stride_per_key_per_rank self._stride_per_key: Optional[List[int]] = stride_per_key self._length_per_key: Optional[List[int]] = length_per_key self._offset_per_key: Optional[List[int]] = offset_per_key @@ -1827,7 +1829,9 @@ def from_offsets_sync( offsets: torch.Tensor, weights: Optional[torch.Tensor] = None, stride: Optional[int] = None, - stride_per_key_per_rank: Optional[List[List[int]]] = None, + stride_per_key_per_rank: Union[ + Optional[torch.Tensor], Optional[List[List[int]]] + ] = None, inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, ) -> "KeyedJaggedTensor": """ @@ -1840,7 +1844,7 @@ def from_offsets_sync( weights (Optional[torch.Tensor]): if the values have weights. Tensor with the same shape as values. stride (Optional[int]): number of examples per batch. 
- stride_per_key_per_rank (Optional[List[List[int]]]): batch size + stride_per_key_per_rank (Union[Optional[torch.Tensor], Optional[List[List[int]]]]): batch size (number of examples) per key per rank, with the outer list representing the keys and the inner list representing the values. inverse_indices (Optional[Tuple[List[str], torch.Tensor]]): inverse indices to @@ -1867,7 +1871,9 @@ def from_lengths_sync( lengths: torch.Tensor, weights: Optional[torch.Tensor] = None, stride: Optional[int] = None, - stride_per_key_per_rank: Optional[List[List[int]]] = None, + stride_per_key_per_rank: Union[ + Optional[torch.Tensor], Optional[List[List[int]]] + ] = None, inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None, ) -> "KeyedJaggedTensor": """ @@ -1881,7 +1887,7 @@ def from_lengths_sync( weights (Optional[torch.Tensor]): if the values have weights. Tensor with the same shape as values. stride (Optional[int]): number of examples per batch. - stride_per_key_per_rank (Optional[List[List[int]]]): batch size + stride_per_key_per_rank (Union[Optional[torch.Tensor], Optional[List[List[int]]]]): batch size (number of examples) per key per rank, with the outer list representing the keys and the inner list representing the values. inverse_indices (Optional[Tuple[List[str], torch.Tensor]]): inverse indices to @@ -2193,8 +2199,15 @@ def stride_per_key_per_rank(self) -> List[List[int]]: Returns: List[List[int]]: stride per key per rank of the KeyedJaggedTensor. """ + if self._stride_per_key_per_rank is None: + return [] + stride_per_key_per_rank = self._stride_per_key_per_rank - return stride_per_key_per_rank if stride_per_key_per_rank is not None else [] + return ( + stride_per_key_per_rank.tolist() + if isinstance(stride_per_key_per_rank, torch.Tensor) + else stride_per_key_per_rank + ) def variable_stride_per_key(self) -> bool: """