
Commit b3e19e2

hstonec authored and facebook-github-bot committed
Cache embedding_weights_by_table for EmbeddingFusedOptimizer (#2711)
Summary: Pull Request resolved: #2711

The `split_embedding_weights()` method on the `emb_module` is a time-consuming operation. It is currently called in the constructor of `EmbeddingFusedOptimizer`, so the method runs every time an `EmbeddingFusedOptimizer` instance is created. Since `_gen_named_parameters_by_table_fused` creates `EmbeddingFusedOptimizer` instances **thousands of times in a loop**, a significant amount of time is spent in this method. Hoisting the call out of the loop and passing its result in as a parameter gives a caching effect and saves a lot of time: **CREATE_TRAIN_MODULE.SHARD_MODEL** currently takes approximately **22 seconds** to run, and with this caching it drops to around **15 seconds**. The AI Lab result shows a 6.67 s saving (https://www.internalfb.com/family_of_labs/test_results/689499818).

Reviewed By: dstaay-fb

Differential Revision: D68578829

fbshipit-source-id: 63332203dfaec1f326298d068101cf66885b3394
1 parent 0d827ea

1 file changed (+7, -2)

torchrec/distributed/batched_embedding_kernel.py (+7, -2)
@@ -13,7 +13,6 @@
 import itertools
 import logging
 import tempfile
-from collections import OrderedDict
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -216,6 +215,7 @@ def __init__( # noqa C901
         pg: Optional[dist.ProcessGroup] = None,
         create_for_table: Optional[str] = None,
         param_weight_for_table: Optional[nn.Parameter] = None,
+        embedding_weights_by_table: Optional[List[torch.Tensor]] = None,
     ) -> None:
         """
         Implementation of a FusedOptimizer. Designed as a base class Embedding kernels
@@ -391,7 +391,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
         # that state_dict look identical to no-fused version.
         table_to_shard_params: Dict[str, ShardParams] = {}

-        embedding_weights_by_table = emb_module.split_embedding_weights()
+        embedding_weights_by_table = (
+            embedding_weights_by_table or emb_module.split_embedding_weights()
+        )

         all_optimizer_states = emb_module.get_optimizer_state()
         optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
@@ -674,6 +676,8 @@ def _gen_named_parameters_by_table_fused(
     pg: Optional[dist.ProcessGroup] = None,
 ) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]:
     # TODO: move logic to FBGEMM to avoid accessing fbgemm internals
+    # Cache embedding_weights_by_table
+    embedding_weights_by_table = emb_module.split_embedding_weights()
     for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs):
         table_name = config.embedding_tables[t_idx].name
         if table_name not in table_name_to_count:
@@ -709,6 +713,7 @@ def _gen_named_parameters_by_table_fused(
                 pg=pg,
                 create_for_table=table_name,
                 param_weight_for_table=weight,
+                embedding_weights_by_table=embedding_weights_by_table,
             )
         ]
         yield (table_name, weight)
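Taken together, the diff implements a simple hoist-and-pass caching pattern: compute the expensive result once outside the loop, and let the constructor accept it as an optional argument, falling back to the expensive call when nothing is passed in. The sketch below is a minimal illustration of that pattern; the names `FakeEmbModule`, `FusedOptimizerSketch`, and `gen_optimizers` are hypothetical stand-ins, not the actual TorchRec/FBGEMM types.

```python
# Minimal sketch of the hoist-and-pass caching pattern used in this commit.
from typing import List, Optional


class FakeEmbModule:
    """Hypothetical stand-in for an embedding module with an expensive call."""

    def split_embedding_weights(self) -> List[str]:
        # Pretend this walks and slices large weight tensors.
        return ["weights_table_0", "weights_table_1"]


class FusedOptimizerSketch:
    def __init__(
        self,
        emb_module: FakeEmbModule,
        embedding_weights_by_table: Optional[List[str]] = None,
    ) -> None:
        # Reuse the caller-provided result when available; otherwise fall back
        # to the expensive call (note: `or` also falls back on an empty list).
        self.embedding_weights_by_table = (
            embedding_weights_by_table or emb_module.split_embedding_weights()
        )


def gen_optimizers(emb_module: FakeEmbModule, num_tables: int):
    # Hoist the expensive call out of the loop and thread the result through,
    # so it runs once instead of num_tables times.
    cached = emb_module.split_embedding_weights()
    for _ in range(num_tables):
        yield FusedOptimizerSketch(emb_module, embedding_weights_by_table=cached)
```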
