47
47
from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
48
48
from torchrec .streamable import Multistreamable
49
49
50
+
50
51
torch .fx .wrap ("len" )
51
52
52
53
CACHE_LOAD_FACTOR_STR : str = "cache_load_factor"
@@ -61,6 +62,15 @@ def _fx_wrap_tensor_to_device_dtype(
61
62
return t .to (device = tensor_device_dtype .device , dtype = tensor_device_dtype .dtype )
62
63
63
64
65
@torch.fx.wrap
def _fx_wrap_optional_tensor_to_device_dtype(
    t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor
) -> Optional[torch.Tensor]:
    """Cast an optional tensor to the device and dtype of a reference tensor.

    Registered with ``torch.fx.wrap`` so that FX symbolic tracing records a
    call to this function instead of tracing through the ``None`` check.

    Args:
        t: tensor to convert, or ``None`` when no tensor was supplied.
        tensor_device_dtype: reference tensor; only its ``device`` and
            ``dtype`` attributes are read.

    Returns:
        ``None`` if ``t`` is ``None``; otherwise the result of ``t.to(...)``
        with the reference tensor's device and dtype.
    """
    if t is None:
        # Nothing to convert; propagate the absence unchanged.
        return None
    return t.to(
        device=tensor_device_dtype.device,
        dtype=tensor_device_dtype.dtype,
    )
72
+
73
+
64
74
@torch .fx .wrap
65
75
def _fx_wrap_batch_size_per_feature (kjt : KeyedJaggedTensor ) -> Optional [torch .Tensor ]:
66
76
return (
@@ -121,6 +131,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
121
131
block_sizes : torch .Tensor ,
122
132
bucketize_pos : bool = False ,
123
133
block_bucketize_pos : Optional [List [torch .Tensor ]] = None ,
134
+ total_num_blocks : Optional [torch .Tensor ] = None ,
124
135
) -> Tuple [
125
136
torch .Tensor ,
126
137
torch .Tensor ,
@@ -142,6 +153,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
142
153
bucketize_pos = bucketize_pos ,
143
154
sequence = True ,
144
155
block_sizes = block_sizes ,
156
+ total_num_blocks = total_num_blocks ,
145
157
my_size = num_buckets ,
146
158
weights = kjt .weights_or_none (),
147
159
max_B = _fx_wrap_max_B (kjt ),
@@ -289,6 +301,7 @@ def bucketize_kjt_inference(
289
301
kjt : KeyedJaggedTensor ,
290
302
num_buckets : int ,
291
303
block_sizes : torch .Tensor ,
304
+ total_num_buckets : Optional [torch .Tensor ] = None ,
292
305
bucketize_pos : bool = False ,
293
306
block_bucketize_row_pos : Optional [List [torch .Tensor ]] = None ,
294
307
is_sequence : bool = False ,
@@ -303,6 +316,7 @@ def bucketize_kjt_inference(
303
316
Args:
304
317
num_buckets (int): number of buckets to bucketize the values into.
305
318
block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
319
+ total_num_buckets: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization; passed on as `total_num_blocks` to the underlying bucketization call.
306
320
bucketize_pos (bool): output the changed position of the bucketized values or
307
321
not.
308
322
block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature.
@@ -318,6 +332,9 @@ def bucketize_kjt_inference(
318
332
f"Expecting block sizes for { num_features } features, but { block_sizes .numel ()} received." ,
319
333
)
320
334
block_sizes_new_type = _fx_wrap_tensor_to_device_dtype (block_sizes , kjt .values ())
335
+ total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype (
336
+ total_num_buckets , kjt .values ()
337
+ )
321
338
unbucketize_permute = None
322
339
bucket_mapping = None
323
340
if is_sequence :
@@ -332,6 +349,7 @@ def bucketize_kjt_inference(
332
349
kjt ,
333
350
num_buckets = num_buckets ,
334
351
block_sizes = block_sizes_new_type ,
352
+ total_num_blocks = total_num_buckets_new_type ,
335
353
bucketize_pos = bucketize_pos ,
336
354
block_bucketize_pos = block_bucketize_row_pos ,
337
355
)
0 commit comments