@@ -1096,13 +1096,15 @@ def _maybe_compute_stride_kjt(
     stride: Optional[int],
     lengths: Optional[torch.Tensor],
     offsets: Optional[torch.Tensor],
-    stride_per_key_per_rank: Optional[List[List[int]]],
+    stride_per_key_per_rank: Optional[torch.IntTensor],
 ) -> int:
     if stride is None:
         if len(keys) == 0:
             stride = 0
-        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
-            stride = max([sum(s) for s in stride_per_key_per_rank])
+        elif (
+            stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
+        ):
+            stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
         elif offsets is not None and offsets.numel() > 0:
             stride = (offsets.numel() - 1) // len(keys)
         elif lengths is not None:
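
A quick illustrative sketch (not part of the commit) of why the old list-based stride and the new tensor-based path above should agree: the 2D stride table has one row per key and one column per rank, so summing along dim=1 and taking the max reproduces max(sum(s) for s in lists). Values below are made up.

import torch

stride_lists = [[2, 3], [4, 2]]                # assumed toy table: 2 keys x 2 ranks
stride_tensor = torch.IntTensor(stride_lists)  # shape (num_keys, num_ranks)

old_stride = max([sum(s) for s in stride_lists])          # list path (removed above)
new_stride = int(stride_tensor.sum(dim=1).max().item())   # tensor path (added above)
assert old_stride == new_stride == 6
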
@@ -1481,8 +1483,8 @@ def _strides_from_kjt(
 def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
     # empty like function fx wrapped, also avoids device hardcoding
     stride, stride_per_key_per_rank = (
-        (None, kjt.stride_per_key_per_rank())
-        if kjt.variable_stride_per_key()
+        (None, kjt._stride_per_key_per_rank)
+        if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key()
         else (kjt.stride(), None)
     )
 
@@ -1668,14 +1670,20 @@ def _maybe_compute_lengths_offset_per_key(
 
 def _maybe_compute_stride_per_key(
     stride_per_key: Optional[List[int]],
-    stride_per_key_per_rank: Optional[List[List[int]]],
+    stride_per_key_per_rank: Optional[torch.IntTensor],
     stride: Optional[int],
     keys: List[str],
 ) -> Optional[List[int]]:
     if stride_per_key is not None:
         return stride_per_key
     elif stride_per_key_per_rank is not None:
-        return [sum(s) for s in stride_per_key_per_rank]
+        if stride_per_key_per_rank.dim() != 2:
+            # after permute the kjt could be empty
+            return []
+        rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            pt2_checks_all_is_size(rt)
+        return rt
     elif stride is not None:
         return [stride] * len(keys)
     else:
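
A small sketch (illustrative only) of the per-key path added above: sum(dim=1).tolist() yields one total stride per key, and a tensor that is not 2D (e.g. a KJT emptied by permute) falls into the early-return branch.

import torch

table = torch.IntTensor([[2, 3], [4, 2]])   # assumed (num_keys=2, num_ranks=2)
assert table.sum(dim=1).tolist() == [5, 6]  # one total stride per key

empty = torch.IntTensor([])                 # e.g. after permuting away all keys
assert empty.dim() != 2                     # would hit the `return []` guard above
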
@@ -1766,7 +1774,9 @@ def __init__(
         lengths: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
         stride: Optional[int] = None,
-        stride_per_key_per_rank: Optional[List[List[int]]] = None,
+        stride_per_key_per_rank: Optional[
+            Union[torch.IntTensor, List[List[int]]]
+        ] = None,
         # Below exposed to ensure torch.script-able
         stride_per_key: Optional[List[int]] = None,
         length_per_key: Optional[List[int]] = None,
@@ -1788,8 +1798,14 @@ def __init__(
         self._lengths: Optional[torch.Tensor] = lengths
         self._offsets: Optional[torch.Tensor] = offsets
         self._stride: Optional[int] = stride
-        self._stride_per_key_per_rank: Optional[List[List[int]]] = (
-            stride_per_key_per_rank
+        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
+            # in pt2.compile, stride_per_key_per_rank has to be torch.Tensor or None;
+            # it does not take List[List[int]]
+            assert not isinstance(stride_per_key_per_rank, list)
+        self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
+            torch.IntTensor(stride_per_key_per_rank, device="cpu")
+            if isinstance(stride_per_key_per_rank, list)
+            else stride_per_key_per_rank
         )
         self._stride_per_key: Optional[List[int]] = stride_per_key
         self._length_per_key: Optional[List[int]] = length_per_key
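
Hedged usage sketch of the constructor change above: stride_per_key_per_rank may now be passed either as the legacy List[List[int]] (eagerly converted to a CPU IntTensor) or as an IntTensor directly; under torch.compile only the tensor/None form is accepted. The keys/values/lengths below are made-up toy data and the import path is assumed.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor  # assumed import path

toy = dict(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 1, 1, 2]),
)
kjt_from_list = KeyedJaggedTensor(stride_per_key_per_rank=[[1, 1], [1, 1]], **toy)
kjt_from_tensor = KeyedJaggedTensor(
    stride_per_key_per_rank=torch.IntTensor([[1, 1], [1, 1]]), **toy
)
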
@@ -1815,10 +1831,11 @@ def _init_pt2_checks(self) -> None:
             return
         if self._stride_per_key is not None:
             pt2_checks_all_is_size(self._stride_per_key)
-        if self._stride_per_key_per_rank is not None:
-            # pyre-ignore [16]
-            for s in self._stride_per_key_per_rank:
-                pt2_checks_all_is_size(s)
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
+        if _stride_per_key_per_rank is not None:
+            stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
+            for stride_per_rank in stride_per_key_per_rank:
+                pt2_checks_all_is_size(stride_per_rank)
 
     @staticmethod
     def from_offsets_sync(
@@ -2028,7 +2045,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor":
         kjt_stride, kjt_stride_per_key_per_rank = (
             (stride_per_key[0], None)
             if all(s == stride_per_key[0] for s in stride_per_key)
-            else (None, [[stride] for stride in stride_per_key])
+            else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1))
         )
         kjt = KeyedJaggedTensor(
             keys=kjt_keys,
@@ -2193,12 +2210,29 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
         Returns:
             List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
         """
-        stride_per_key_per_rank = self._stride_per_key_per_rank
-        return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
+        # make a local reference to the instance attribute so jit.script behaves
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
+        if (
+            not torch.jit.is_scripting()
+            and is_torchdynamo_compiling()
+            and _stride_per_key_per_rank is not None
+        ):
+            stride_per_key_per_rank = _stride_per_key_per_rank.tolist()
+            for stride_per_rank in stride_per_key_per_rank:
+                pt2_checks_all_is_size(stride_per_rank)
+            return stride_per_key_per_rank
+        return (
+            []
+            if _stride_per_key_per_rank is None
+            else _stride_per_key_per_rank.tolist()
+        )
 
     def variable_stride_per_key(self) -> bool:
         """
         Returns whether the KeyedJaggedTensor has variable stride per key.
+        NOTE: `self._variable_stride_per_key` could be `False` even when
+        `self._stride_per_key_per_rank` is not `None`: it may be set to `False`
+        externally/intentionally, usually when `self._stride_per_key_per_rank` is trivial.
 
         Returns:
             bool: whether the KeyedJaggedTensor has variable stride per key.
@@ -2343,13 +2377,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
         start_offset = 0
         _length_per_key = self.length_per_key()
         _offset_per_key = self.offset_per_key()
+        # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
         for segment in segments:
             end = start + segment
             end_offset = _offset_per_key[end]
             keys: List[str] = self._keys[start:end]
             stride_per_key_per_rank = (
-                self.stride_per_key_per_rank()[start:end]
+                _stride_per_key_per_rank[start:end, :]
                 if self.variable_stride_per_key()
+                and _stride_per_key_per_rank is not None
                 else None
             )
             if segment == len(self._keys):
@@ -2514,17 +2551,17 @@ def permute(
 
         length_per_key = self.length_per_key()
         permuted_keys: List[str] = []
-        permuted_stride_per_key_per_rank: List[List[int]] = []
         permuted_length_per_key: List[int] = []
         permuted_length_per_key_sum = 0
         for index in indices:
             key = self.keys()[index]
             permuted_keys.append(key)
             permuted_length_per_key.append(length_per_key[index])
-            if self.variable_stride_per_key():
-                permuted_stride_per_key_per_rank.append(
-                    self.stride_per_key_per_rank()[index]
-                )
+        _stride_per_key_per_rank = self._stride_per_key_per_rank
+        if self.variable_stride_per_key() and _stride_per_key_per_rank is not None:
+            permuted_stride_per_key_per_rank = _stride_per_key_per_rank[indices, :]
+        else:
+            permuted_stride_per_key_per_rank = None
 
         permuted_length_per_key_sum = sum(permuted_length_per_key)
         if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2576,17 +2613,15 @@ def permute(
                 self.weights_or_none(),
                 permuted_length_per_key_sum,
             )
-        stride_per_key_per_rank = (
-            permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
-        )
+
         kjt = KeyedJaggedTensor(
             keys=permuted_keys,
             values=permuted_values,
             weights=permuted_weights,
             lengths=permuted_lengths.view(-1),
             offsets=None,
             stride=self._stride,
-            stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key_per_rank=permuted_stride_per_key_per_rank,
             stride_per_key=None,
             length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
             lengths_offset_per_key=None,
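
For context on the permute() change above, a tiny sketch (illustrative values) of how indexing the 2D stride table with the permutation list replaces the old per-key append loop: rows come back in the permuted key order.

import torch

stride_table = torch.IntTensor([[2, 2], [3, 1], [4, 4]])  # assumed one row per key
indices = [2, 0, 1]                                        # assumed new key order
assert stride_table[indices, :].tolist() == [[4, 4], [2, 2], [3, 1]]
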
@@ -2904,7 +2939,7 @@ def dist_init(
 
         if variable_stride_per_key:
             assert stride_per_rank_per_key is not None
-            stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view(
+            stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view(
                 num_workers, len(keys)
             ).T.cpu()
 
@@ -2941,23 +2976,18 @@ def dist_init(
                 weights,
             )
 
-            stride_per_key_per_rank = torch.jit.annotate(
-                List[List[int]], stride_per_key_per_rank_tensor.tolist()
-            )
+            if stride_per_key_per_rank.numel() == 0:
+                stride_per_key_per_rank = torch.zeros(
+                    (len(keys), 1), device="cpu", dtype=torch.int64
+                )
 
-            if not stride_per_key_per_rank:
-                stride_per_key_per_rank = [[0]] * len(keys)
             if stagger > 1:
-                stride_per_key_per_rank_stagger: List[List[int]] = []
                 local_world_size = num_workers // stagger
-                for i in range(len(keys)):
-                    stride_per_rank_stagger: List[int] = []
-                    for j in range(local_world_size):
-                        stride_per_rank_stagger.extend(
-                            stride_per_key_per_rank[i][j::local_world_size]
-                        )
-                    stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
-                stride_per_key_per_rank = stride_per_key_per_rank_stagger
+                indices = [
+                    list(range(i, num_workers, local_world_size))
+                    for i in range(local_world_size)
+                ]
+                stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
 
             kjt = KeyedJaggedTensor(
                 keys=keys,
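
Finally, a sketch (made-up sizes) of the stagger reordering above: the nested column indices built from range(i, num_workers, local_world_size) visit ranks in the same order as the old row[j::local_world_size] slicing did.

num_workers, local_world_size = 4, 2        # i.e. stagger == 2
indices = [
    list(range(i, num_workers, local_world_size)) for i in range(local_world_size)
]
assert indices == [[0, 2], [1, 3]]

row = [10, 11, 12, 13]                      # one key's strides across ranks
new_order = [row[j] for group in indices for j in group]
old_order = [v for j in range(local_world_size) for v in row[j::local_world_size]]
assert new_order == old_order == [10, 12, 11, 13]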