Skip to content

Commit 7593bbd

Browse files
basilwong authored and facebook-github-bot committed
Update batched_embedding_kernel (#2702)
Summary: After this diff stack, EmbeddingKernelConfig now supports adding embedding_table_int32_index_type and embedding_table_int32_offset_type to the fused_params. These are used downstream to determine the dtypes of the indices and offsets passed to split_table_batched_embeddings_ops_training.py. Reviewed By: q10. Differential Revision: D66919716
1 parent 9269e73 commit 7593bbd

File tree

1 file changed

+44
-6
lines changed

1 file changed

+44
-6
lines changed

torchrec/distributed/batched_embedding_kernel.py

+44-6
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,13 @@ def __init__(
764764
self._feature_table_map: List[int] = []
765765
self.table_name_to_count: Dict[str, int] = {}
766766
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
767+
self._fused_params: Dict[str, Any] = config.fused_params or {}
768+
self._embedding_table_index_type: torch.dtype = self._fused_params.get(
769+
"embedding_table_index_type", torch.int64
770+
)
771+
self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
772+
"embedding_table_offset_type", torch.int64
773+
)
767774

768775
# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
769776
# `ShardedEmbeddingTable`.
@@ -805,8 +812,16 @@ def init_parameters(self) -> None:
805812

806813
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
807814
return self.emb_module(
808-
indices=features.values().long(),
809-
offsets=features.offsets().long(),
815+
indices=(
816+
features.values()
817+
if self._embedding_table_index_type == torch.int32
818+
else features.values().long()
819+
),
820+
offsets=(
821+
features.offsets().type(dtype=features.values().dtype)
822+
if self._embedding_table_offset_type == torch.int32
823+
else features.offsets().long()
824+
),
810825
)
811826

812827
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1218,6 +1233,13 @@ def __init__(
12181233
self._lengths_per_emb: List[int] = []
12191234
self.table_name_to_count: Dict[str, int] = {}
12201235
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
1236+
self._fused_params: Dict[str, Any] = config.fused_params or {}
1237+
self._embedding_table_index_type: torch.dtype = self._fused_params.get(
1238+
"embedding_table_index_type", torch.int64
1239+
)
1240+
self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
1241+
"embedding_table_offset_type", torch.int64
1242+
)
12211243

12221244
# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
12231245
# `ShardedEmbeddingTable`.
@@ -1270,15 +1292,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
12701292
),
12711293
):
12721294
return self.emb_module(
1273-
indices=features.values().long(),
1274-
offsets=features.offsets().long(),
1295+
indices=(
1296+
features.values()
1297+
if self._embedding_table_index_type == torch.int32
1298+
else features.values().long()
1299+
),
1300+
offsets=(
1301+
features.offsets().type(dtype=features.values().dtype)
1302+
if self._embedding_table_offset_type == torch.int32
1303+
else features.offsets().long()
1304+
),
12751305
per_sample_weights=weights,
12761306
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
12771307
)
12781308
else:
12791309
return self.emb_module(
1280-
indices=features.values().long(),
1281-
offsets=features.offsets().long(),
1310+
indices=(
1311+
features.values()
1312+
if self._embedding_table_index_type == torch.int32
1313+
else features.values().long()
1314+
),
1315+
offsets=(
1316+
features.offsets().type(dtype=features.values().dtype)
1317+
if self._embedding_table_offset_type == torch.int32
1318+
else features.offsets().long()
1319+
),
12821320
per_sample_weights=weights,
12831321
)
12841322

0 commit comments

Comments (0)