|
13 | 13 | import itertools
|
14 | 14 | import logging
|
15 | 15 | import tempfile
|
16 |
| -from collections import OrderedDict |
17 | 16 | from dataclasses import dataclass
|
18 | 17 | from typing import (
|
19 | 18 | Any,
|
@@ -216,6 +215,7 @@ def __init__( # noqa C901
|
216 | 215 | pg: Optional[dist.ProcessGroup] = None,
|
217 | 216 | create_for_table: Optional[str] = None,
|
218 | 217 | param_weight_for_table: Optional[nn.Parameter] = None,
|
| 218 | + embedding_weights_by_table: Optional[List[torch.Tensor]] = None, |
219 | 219 | ) -> None:
|
220 | 220 | """
|
221 | 221 | Implementation of a FusedOptimizer. Designed as a base class Embedding kernels
|
@@ -391,7 +391,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
|
391 | 391 | # that state_dict look identical to no-fused version.
|
392 | 392 | table_to_shard_params: Dict[str, ShardParams] = {}
|
393 | 393 |
|
394 |
| - embedding_weights_by_table = emb_module.split_embedding_weights() |
| 394 | + embedding_weights_by_table = ( |
| 395 | + embedding_weights_by_table or emb_module.split_embedding_weights() |
| 396 | + ) |
395 | 397 |
|
396 | 398 | all_optimizer_states = emb_module.get_optimizer_state()
|
397 | 399 | optimizer_states_keys_by_table: Dict[str, List[torch.Tensor]] = {}
|
@@ -674,6 +676,8 @@ def _gen_named_parameters_by_table_fused(
|
674 | 676 | pg: Optional[dist.ProcessGroup] = None,
|
675 | 677 | ) -> Iterator[Tuple[str, TableBatchedEmbeddingSlice]]:
|
676 | 678 | # TODO: move logic to FBGEMM to avoid accessing fbgemm internals
|
| 679 | + # Cache embedding_weights_by_table |
| 680 | + embedding_weights_by_table = emb_module.split_embedding_weights() |
677 | 681 | for t_idx, (rows, dim, location, _) in enumerate(emb_module.embedding_specs):
|
678 | 682 | table_name = config.embedding_tables[t_idx].name
|
679 | 683 | if table_name not in table_name_to_count:
|
@@ -709,6 +713,7 @@ def _gen_named_parameters_by_table_fused(
|
709 | 713 | pg=pg,
|
710 | 714 | create_for_table=table_name,
|
711 | 715 | param_weight_for_table=weight,
|
| 716 | + embedding_weights_by_table=embedding_weights_by_table, |
712 | 717 | )
|
713 | 718 | ]
|
714 | 719 | yield (table_name, weight)
|
|
0 commit comments