
Commit 4b731b3

basilwong authored and facebook-github-bot committed
Update batched_embedding_kernel (#2702)
Summary: After this diff stack, EmbeddingKernelConfig now supports adding embedding_table_int32_index_type and embedding_table_int32_offset_type to the fused_params. These are used downstream to determine the indices and offsets types passed to split_table_batched_embeddings_ops_training.py.

Differential Revision: D66919716
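
For illustration only (not part of this commit), here is a minimal sketch of how these dtype hints could be supplied through fused_params. The config-side plumbing referenced in the summary lives elsewhere in the diff stack; the keys shown below are the ones this file reads, each falling back to torch.int64 when absent.

import torch

# Hypothetical fused_params dict, shown only to illustrate the new keys;
# batched_embedding_kernel.py looks these up and defaults each to torch.int64.
fused_params = {
    "embedding_table_index_type": torch.int32,
    "embedding_table_offset_type": torch.int32,
}

With both keys set to torch.int32, the forward paths in the diff below skip the .long() upcast and keep int32 indices and offsets.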
1 parent 52b0749 · commit 4b731b3

File tree: 1 file changed (+44, -6 lines)

torchrec/distributed/batched_embedding_kernel.py (+44, -6)
@@ -759,6 +759,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         # `ShardedEmbeddingTable`.
@@ -800,8 +807,16 @@ def init_parameters(self) -> None:
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
+            indices=(
+                features.values()
+                if self._embedding_table_index_type == torch.int32
+                else features.values().long()
+            ),
+            offsets=(
+                features.offsets().type(dtype=features.values().dtype)
+                if self._embedding_table_offset_type == torch.int32
+                else features.offsets().long()
+            ),
         )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1213,6 +1228,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         # `ShardedEmbeddingTable`.
@@ -1265,15 +1287,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             ),
         ):
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
                 batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
             )
         else:
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
             )
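
The same dtype-selection rule is repeated in each forward path above. The standalone sketch below shows that rule in isolation; the helper name is invented for this illustration and is not part of the commit.

import torch
from typing import Tuple


def _select_indices_and_offsets(
    values: torch.Tensor,
    offsets: torch.Tensor,
    index_type: torch.dtype = torch.int64,
    offset_type: torch.dtype = torch.int64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Indices keep their incoming dtype only when int32 is requested;
    # otherwise they are upcast to int64, matching the previous behavior.
    out_indices = values if index_type == torch.int32 else values.long()
    # Offsets are cast to the values' dtype when int32 is requested, else int64.
    out_offsets = (
        offsets.type(dtype=values.dtype)
        if offset_type == torch.int32
        else offsets.long()
    )
    return out_indices, out_offsets


# int32 inputs are preserved when both hints are torch.int32.
vals = torch.tensor([4, 8, 15], dtype=torch.int32)
offs = torch.tensor([0, 2, 3], dtype=torch.int32)
idx, off = _select_indices_and_offsets(vals, offs, torch.int32, torch.int32)
assert idx.dtype == torch.int32 and off.dtype == torch.int32

# With the defaults (torch.int64), behavior matches the pre-change code.
idx64, off64 = _select_indices_and_offsets(vals, offs)
assert idx64.dtype == torch.int64 and off64.dtype == torch.int64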
