1 file changed: +15 −0 lines changed
@@ -69,6 +69,7 @@
     ShardedTensor,
     ShardingEnv,
     ShardMetadata,
+    BoundsCheckMode,
 )
 from torchrec.distributed.utils import (
     add_params_from_parameter_sharding,
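
Note on the new import: BoundsCheckMode is FBGEMM's enum describing how out-of-range embedding indices are handled (FATAL, WARNING, IGNORE, NONE). Below is a minimal sketch of the comparison used further down in this diff, assuming the enum is imported straight from fbgemm_gpu; the import path and the hypothetical bounds_check_enabled helper are illustration only, not part of this change:

    # Assumption: BoundsCheckMode lives in fbgemm_gpu's split-TBE common module
    # (older fbgemm_gpu releases exposed it from split_table_batched_embeddings_ops).
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode

    def bounds_check_enabled(bounds_check_mode_int: int) -> bool:
        # The diff runs FBGEMM's bounds check unless the embedding module was
        # configured with BoundsCheckMode.NONE.
        return bounds_check_mode_int != BoundsCheckMode.NONE.value
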
@@ -1090,6 +1091,20 @@ def _dedup_indices(
         for i, input_feature in enumerate(input_feature_splits):
             hash_size_cumsum = self.get_buffer(f"_hash_size_cumsum_tensor_{i}")
             hash_size_offset = self.get_buffer(f"_hash_size_offset_tensor_{i}")
+
+            lookup = self._lookups[i]
+            for emb_module in lookup._emb_modules:
+                emb_module = emb_module._emb_module
+                if emb_module.bounds_check_mode_int != BoundsCheckMode.NONE.value:
+                    torch.ops.fbgemm.bounds_check_indices(
+                        emb_module.rows_per_table,
+                        input_feature.values().long(),
+                        input_feature.offsets().long(),
+                        emb_module.bounds_check_mode_int,
+                        emb_module.bounds_check_warning,
+                        None,
+                    )
+
             (
                 lengths,
                 offsets,
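
The block added above runs FBGEMM's bounds check on each deduplicated feature split before lengths and offsets are rebuilt, using whatever bounds-check mode the underlying TBE module was configured with. Below is a self-contained sketch of the same torch.ops.fbgemm.bounds_check_indices call with the argument order taken from the diff; the table sizes, tensors, and the choice of WARNING mode are made-up illustration data, and fbgemm_gpu must be installed for the op to be registered:

    import torch
    import fbgemm_gpu  # noqa: F401  -- importing registers the torch.ops.fbgemm ops
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import BoundsCheckMode

    # Two hypothetical tables with 10 and 4 rows, batch size 2 per table.
    rows_per_table = torch.tensor([10, 4], dtype=torch.int64)

    # Flattened indices and offsets, as a KeyedJaggedTensor would provide via
    # values()/offsets(); the 99 is deliberately out of bounds for table 0.
    indices = torch.tensor([0, 3, 99, 1], dtype=torch.int64)
    offsets = torch.tensor([0, 2, 3, 4, 4], dtype=torch.int64)

    # Counter the op increments when it encounters an out-of-bounds index.
    warning = torch.zeros(1, dtype=torch.int64)

    # In WARNING mode the op reports the bad index and fixes it in place instead
    # of asserting; exactly how it is fixed is an fbgemm_gpu implementation detail.
    torch.ops.fbgemm.bounds_check_indices(
        rows_per_table,
        indices,
        offsets,
        BoundsCheckMode.WARNING.value,
        warning,
        None,  # optional per-sample weights, unused here (matches the diff)
    )

In the sharded module, the equivalent arguments come from the wrapped TBE module itself (rows_per_table, bounds_check_mode_int, bounds_check_warning) and from the current input_feature's values()/offsets(), as shown in the hunk above.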