Skip to content

Commit 0cb476e

Browse files
author
lixiaoguang12
committed
add bounds check before indice unique
1 parent 9269e73 commit 0cb476e

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

torchrec/distributed/embedding.py

+15
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
ShardedTensor,
7070
ShardingEnv,
7171
ShardMetadata,
72+
BoundsCheckMode,
7273
)
7374
from torchrec.distributed.utils import (
7475
add_params_from_parameter_sharding,
@@ -1090,6 +1091,20 @@ def _dedup_indices(
10901091
for i, input_feature in enumerate(input_feature_splits):
10911092
hash_size_cumsum = self.get_buffer(f"_hash_size_cumsum_tensor_{i}")
10921093
hash_size_offset = self.get_buffer(f"_hash_size_offset_tensor_{i}")
1094+
1095+
lookup = self._lookups[i]
1096+
for emb_module in lookup._emb_modules:
1097+
emb_module = emb_module._emb_module
1098+
if emb_module.bounds_check_mode_int != BoundsCheckMode.NONE.value:
1099+
torch.ops.fbgemm.bounds_check_indices(
1100+
emb_module.rows_per_table,
1101+
input_feature.values().long(),
1102+
input_feature.offsets().long(),
1103+
emb_module.bounds_check_mode_int,
1104+
emb_module.bounds_check_warning,
1105+
None,
1106+
)
1107+
10931108
(
10941109
lengths,
10951110
offsets,

0 commit comments

Comments
 (0)