
Commit f059a49

duduyi2013 authored and facebook-github-bot committed
add ssd-emo checkpoint support for sEC (#2650)
Summary: Pull Request resolved: #2650. As the title says, ssd-emo previously only supported sEBC; this diff adds support for sEC as well.

Reviewed By: sarckk, jiayulu

Differential Revision: D67183728

fbshipit-source-id: 8d5d32e33520405bc838dc9bc7759d2540593618
1 parent 75307b1 commit f059a49

File tree: 5 files changed (+148, -41 lines)

torchrec/distributed/batched_embedding_kernel.py
torchrec/distributed/embedding.py
torchrec/distributed/embedding_lookup.py
torchrec/distributed/embeddingbag.py
torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 47 additions & 10 deletions
@@ -900,7 +900,7 @@ def __init__(
             pg,
         )
         self._param_per_table: Dict[str, nn.Parameter] = dict(
-            _gen_named_parameters_by_table_ssd(
+            _gen_named_parameters_by_table_ssd_pmt(
                 emb_module=self._emb_module,
                 table_name_to_count=self.table_name_to_count.copy(),
                 config=self._config,
@@ -933,11 +933,31 @@ def state_dict(
         destination: Optional[Dict[str, Any]] = None,
         prefix: str = "",
         keep_vars: bool = False,
+        no_snapshot: bool = True,
     ) -> Dict[str, Any]:
-        if destination is None:
-            destination = OrderedDict()
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. this argument controls wether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle
+        """
+        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()

-        return destination
+        emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret

     def named_parameters(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -950,14 +970,16 @@ def named_parameters(
         ):
             # hack before we support optimizer on sharded parameter level
             # can delete after PEA deprecation
+            # pyre-ignore [6]
             param = nn.Parameter(tensor)
             # pyre-ignore
             param._in_backward_optimizers = [EmptyFusedOptimizer()]
             yield name, param

+    # pyre-ignore [15]
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
-    ) -> Iterator[Tuple[str, torch.Tensor]]:
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
         assert (
             remove_duplicate
         ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
@@ -968,6 +990,21 @@ def named_split_embedding_weights(
             key = append_prefix(prefix, f"{config.name}.weight")
             yield key, tensor

+    def get_named_split_embedding_weights_snapshot(
+        self, prefix: str = ""
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+        """
+        Return an iterator over embedding tables, yielding both the table name as well as the embedding
+        table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
+        RocksDB snapshot to support windowed access.
+        """
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights(no_snapshot=False),
+        ):
+            key = append_prefix(prefix, f"{config.name}")
+            yield key, tensor
+
     def flush(self) -> None:
         """
         Flush the embeddings in cache back to SSD. Should be pretty expensive.
@@ -982,11 +1019,11 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)

-    def split_embedding_weights(self) -> List[torch.Tensor]:
-        """
-        Return fake tensors.
-        """
-        return [param.data for param in self._param_per_table.values()]
+    # pyre-ignore [15]
+    def split_embedding_weights(
+        self, no_snapshot: bool = True
+    ) -> List[PartiallyMaterializedTensor]:
+        return self.emb_module.split_embedding_weights(no_snapshot)


 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
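
Taken together, the kernel-level changes give the ssd-emo backed KeyValueEmbedding a checkpoint-friendly surface: split_embedding_weights(no_snapshot=...), state_dict(..., no_snapshot=...), and get_named_split_embedding_weights_snapshot(). Below is a minimal sketch of how that surface could be consumed; kv_emb is a hypothetical, already-constructed KeyValueEmbedding instance (construction needs table configs and a process group and is not shown here), and only methods visible in this diff are used.

from typing import Dict

import torch


def materialize_ssd_tables(kv_emb) -> Dict[str, torch.Tensor]:
    # Sketch only: `kv_emb` is assumed to be an ssd-emo backed KeyValueEmbedding.
    #
    # Without a snapshot the returned PartiallyMaterializedTensors describe the
    # tables but are not readable.
    _ = kv_emb.split_embedding_weights(no_snapshot=True)

    # Reads need a flush plus a snapshot; in the sharded module the flush is
    # normally done by the _pre_state_dict_hook.
    kv_emb.flush()
    tables: Dict[str, torch.Tensor] = {}
    for name, pmt in kv_emb.get_named_split_embedding_weights_snapshot():
        # full_tensor() materializes the whole table from the RocksDB snapshot.
        tables[name] = pmt.full_tensor()
    return tables

With no_snapshot=False each PartiallyMaterializedTensor owns a RocksDB snapshot handle, so windowed reads are possible without pulling every table into host memory at once.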

torchrec/distributed/embedding.py

Lines changed: 45 additions & 7 deletions
@@ -744,7 +744,7 @@ def _initialize_torch_state(self) -> None: # noqa
         self._model_parallel_name_to_shards_wrapper = OrderedDict()
         self._model_parallel_name_to_sharded_tensor = OrderedDict()
         self._model_parallel_name_to_dtensor = OrderedDict()
-        model_parallel_name_to_compute_kernel: Dict[str, str] = {}
+        _model_parallel_name_to_compute_kernel: Dict[str, str] = {}
         for (
             table_name,
             parameter_sharding,
@@ -755,7 +755,7 @@ def _initialize_torch_state(self) -> None: # noqa
             self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
                 [("local_tensors", []), ("local_offsets", [])]
             )
-            model_parallel_name_to_compute_kernel[table_name] = (
+            _model_parallel_name_to_compute_kernel[table_name] = (
                 parameter_sharding.compute_kernel
             )

@@ -813,18 +813,17 @@ def _initialize_torch_state(self) -> None: # noqa
                 "weight", nn.Parameter(torch.empty(0))
             )
             if (
-                model_parallel_name_to_compute_kernel[table_name]
+                _model_parallel_name_to_compute_kernel[table_name]
                 != EmbeddingComputeKernel.DENSE.value
             ):
                 self.embeddings[table_name].weight._in_backward_optimizers = [
                     EmptyFusedOptimizer()
                 ]

-            if model_parallel_name_to_compute_kernel[table_name] in {
-                EmbeddingComputeKernel.KEY_VALUE.value
-            }:
-                continue
             if self._output_dtensor:
+                assert _model_parallel_name_to_compute_kernel[table_name] not in {
+                    EmbeddingComputeKernel.KEY_VALUE.value
+                }
                 if shards_wrapper_map["local_tensors"]:
                     self._model_parallel_name_to_dtensor[table_name] = (
                         DTensor.from_local(
@@ -853,6 +852,8 @@ def _initialize_torch_state(self) -> None: # noqa
                     )
             else:
                 # created ShardedTensors once in init, use in post_state_dict_hook
+                # note: at this point kvstore backed tensors don't own valid snapshots, so no read
+                # access is allowed on them.
                 self._model_parallel_name_to_sharded_tensor[table_name] = (
                     ShardedTensor._init_from_local_shards(
                         local_shards,
@@ -861,6 +862,21 @@ def _initialize_torch_state(self) -> None: # noqa
                     )
                 )

+        def extract_sharded_kvtensors(
+            module: ShardedEmbeddingCollection,
+        ) -> OrderedDict[str, ShardedTensor]:
+            # retrieve all kvstore backed tensors
+            ret = OrderedDict()
+            for (
+                table_name,
+                sharded_t,
+            ) in module._model_parallel_name_to_sharded_tensor.items():
+                if _model_parallel_name_to_compute_kernel[table_name] in {
+                    EmbeddingComputeKernel.KEY_VALUE.value
+                }:
+                    ret[table_name] = sharded_t
+            return ret
+
         def post_state_dict_hook(
             module: ShardedEmbeddingCollection,
             destination: Dict[str, torch.Tensor],
@@ -881,6 +897,28 @@ def _initialize_torch_state(self) -> None: # noqa
                 destination_key = f"{prefix}embeddings.{table_name}.weight"
                 destination[destination_key] = d_tensor

+            # kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
+            # snapshot for read access.
+            sharded_kvtensors = extract_sharded_kvtensors(module)
+            if len(sharded_kvtensors) == 0:
+                return
+
+            sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
+            for lookup, sharding_type in zip(
+                module._lookups, module._sharding_type_to_sharding.keys()
+            ):
+                if sharding_type != ShardingType.DATA_PARALLEL.value:
+                    # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
+                    for key, v in lookup.get_named_split_embedding_weights_snapshot():
+                        assert key in sharded_kvtensors_copy
+                        sharded_kvtensors_copy[key].local_shards()[0].tensor = v
+            for (
+                table_name,
+                sharded_kvtensor,
+            ) in sharded_kvtensors_copy.items():
+                destination_key = f"{prefix}embeddings.{table_name}.weight"
+                destination[destination_key] = sharded_kvtensor
+
         self.register_state_dict_pre_hook(self._pre_state_dict_hook)
         self._register_state_dict_hook(post_state_dict_hook)
         self._register_load_state_dict_pre_hook(
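
With the pre and post state_dict hooks registered above, a checkpoint consumer only sees the end result: kvstore backed tables come back as ShardedTensors whose local shards wrap snapshot-backed PartiallyMaterializedTensors. A rough sketch of that consumer view follows; sharded_ec stands in for an already-sharded EmbeddingCollection that uses the KEY_VALUE (ssd-emo) compute kernel, and building it (plan, process group, wrapping) is out of scope here.

from torch.distributed._shard.sharded_tensor import ShardedTensor

# Sketch only: `sharded_ec` is a hypothetical ShardedEmbeddingCollection handle.
state = sharded_ec.state_dict()  # pre-hook flushes; post-hook attaches snapshots
for key, value in state.items():
    # Keys look like "embeddings.<table_name>.weight".
    if isinstance(value, ShardedTensor):
        for shard in value.local_shards():
            # For kvstore backed tables, shard.tensor is a
            # PartiallyMaterializedTensor holding a valid RocksDB snapshot, so
            # reads such as shard.tensor.full_tensor() are allowed here.
            print(key, shard.metadata.shard_offsets)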

torchrec/distributed/embedding_lookup.py

Lines changed: 12 additions & 0 deletions
@@ -348,6 +348,18 @@ def named_parameters_by_table(
         ) in embedding_kernel.named_parameters_by_table():
             yield (table_name, tbe_slice)

+    def get_named_split_embedding_weights_snapshot(
+        self,
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+        """
+        Return an iterator over embedding tables, yielding both the table name as well as the embedding
+        table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
+        RocksDB snapshot to support windowed access.
+        """
+        for emb_module in self._emb_modules:
+            if isinstance(emb_module, KeyValueEmbedding):
+                yield from emb_module.get_named_split_embedding_weights_snapshot()
+
     def flush(self) -> None:
         for emb_module in self._emb_modules:
             emb_module.flush()
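
The lookup-level method only filters for KeyValueEmbedding kernels and delegates, so tables backed by other kernels never show up in the iterator. A consumer can flatten it into a mapping in one line; in this sketch, lookup is a hypothetical GroupedEmbeddingsLookup whose grouped configs include ssd-emo kernels.

# Sketch only: `lookup` is assumed to exist already.
snapshot_by_table = dict(lookup.get_named_split_embedding_weights_snapshot())
# Keys are table names; values are snapshot-backed PartiallyMaterializedTensors.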

torchrec/distributed/embeddingbag.py

Lines changed: 9 additions & 10 deletions
@@ -825,7 +825,7 @@ def _initialize_torch_state(self) -> None: # noqa
         self._model_parallel_name_to_sharded_tensor = OrderedDict()
         self._model_parallel_name_to_dtensor = OrderedDict()

-        self._model_parallel_name_to_compute_kernel: Dict[str, str] = {}
+        _model_parallel_name_to_compute_kernel: Dict[str, str] = {}
         for (
             table_name,
             parameter_sharding,
@@ -836,7 +836,7 @@ def _initialize_torch_state(self) -> None: # noqa
             self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
                 [("local_tensors", []), ("local_offsets", [])]
             )
-            self._model_parallel_name_to_compute_kernel[table_name] = (
+            _model_parallel_name_to_compute_kernel[table_name] = (
                 parameter_sharding.compute_kernel
             )

@@ -892,15 +892,15 @@ def _initialize_torch_state(self) -> None: # noqa
                 "weight", nn.Parameter(torch.empty(0))
             )
             if (
-                self._model_parallel_name_to_compute_kernel[table_name]
+                _model_parallel_name_to_compute_kernel[table_name]
                 != EmbeddingComputeKernel.DENSE.value
             ):
                 self.embedding_bags[table_name].weight._in_backward_optimizers = [
                     EmptyFusedOptimizer()
                 ]

             if self._output_dtensor:
-                assert self._model_parallel_name_to_compute_kernel[table_name] not in {
+                assert _model_parallel_name_to_compute_kernel[table_name] not in {
                     EmbeddingComputeKernel.KEY_VALUE.value
                 }
                 if shards_wrapper_map["local_tensors"]:
@@ -954,7 +954,7 @@ def extract_sharded_kvtensors(
                 table_name,
                 sharded_t,
             ) in module._model_parallel_name_to_sharded_tensor.items():
-                if self._model_parallel_name_to_compute_kernel[table_name] in {
+                if _model_parallel_name_to_compute_kernel[table_name] in {
                     EmbeddingComputeKernel.KEY_VALUE.value
                 }:
                     ret[table_name] = sharded_t
@@ -983,15 +983,14 @@ def post_state_dict_hook(
             # kvstore backed tensors do not have a valid backing snapshot at this point. Fill in a valid
             # snapshot for read access.
             sharded_kvtensors = extract_sharded_kvtensors(module)
+            if len(sharded_kvtensors) == 0:
+                return
+
             sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
             for lookup, sharding in zip(module._lookups, module._embedding_shardings):
-                if isinstance(sharding, DpPooledEmbeddingSharding):
-                    # unwrap DDP
-                    lookup = lookup.module
-                else:
+                if not isinstance(sharding, DpPooledEmbeddingSharding):
                     # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
                     for key, v in lookup.get_named_split_embedding_weights_snapshot():
-                        destination_key = f"{prefix}embedding_bags.{key}.weight"
                         assert key in sharded_kvtensors_copy
                         sharded_kvtensors_copy[key].local_shards()[0].tensor = v
             for (
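
The sEBC path now mirrors the sEC one above: the hook deep-copies the kvstore backed ShardedTensors, and only the copies written into destination carry snapshot-backed shards, so the tensors registered on the module stay snapshot-free. A consumer-side sketch, with sharded_ebc standing in for a hypothetical ShardedEmbeddingBagCollection on the KEY_VALUE kernel:

# Sketch only: constructing `sharded_ebc` is not shown.
state = sharded_ebc.state_dict()
# kvstore backed entries appear under keys like "embedding_bags.<table>.weight"
# as ShardedTensors whose local shards hold readable, snapshot-backed
# PartiallyMaterializedTensors.
for key in sorted(state):
    print(key)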

torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py

Lines changed: 35 additions & 14 deletions
@@ -111,11 +111,23 @@ def _copy_ssd_emb_modules(
                 "SSDEmbeddingBag or SSDEmbeddingBag."
             )

-            weights = emb_module1.emb_module.debug_split_embedding_weights()
-            # need to set emb_module1 as well, since otherwise emb_module1 would
-            # produce a random debug_split_embedding_weights everytime
-            _load_split_embedding_weights(emb_module1, weights)
-            _load_split_embedding_weights(emb_module2, weights)
+            emb1_kv = dict(
+                emb_module1.get_named_split_embedding_weights_snapshot()
+            )
+            for (
+                k,
+                v,
+            ) in emb_module2.get_named_split_embedding_weights_snapshot():
+                v1 = emb1_kv.get(k)
+                v1_full_tensor = v1.full_tensor()
+
+                # write value into ssd for both emb module for later comparison
+                v.wrapped.set_range(
+                    0, 0, v1_full_tensor.size(0), v1_full_tensor
+                )
+                v1.wrapped.set_range(
+                    0, 0, v1_full_tensor.size(0), v1_full_tensor
+                )

         # purge after loading. This is needed, since we pass a batch
         # through dmp when instantiating them.
@@ -141,10 +153,12 @@ def _copy_ssd_emb_modules(
         sharding_type=st.sampled_from(
             [
                 ShardingType.TABLE_WISE.value,
-                ShardingType.COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
                 ShardingType.TABLE_ROW_WISE.value,
-                ShardingType.TABLE_COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
             ]
         ),
         is_training=st.booleans(),
@@ -220,10 +234,12 @@ def test_ssd_load_state_dict(
         sharding_type=st.sampled_from(
             [
                 ShardingType.TABLE_WISE.value,
-                ShardingType.COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
                 ShardingType.TABLE_ROW_WISE.value,
-                ShardingType.TABLE_COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
             ]
         ),
         is_training=st.booleans(),
@@ -344,10 +360,12 @@ def test_ssd_tbe_numerical_accuracy(
         sharding_type=st.sampled_from(
            [
                 ShardingType.TABLE_WISE.value,
-                ShardingType.COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
                 ShardingType.TABLE_ROW_WISE.value,
-                ShardingType.TABLE_COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
             ]
         ),
         is_training=st.booleans(),
@@ -455,10 +473,12 @@ def test_ssd_fused_optimizer(
         sharding_type=st.sampled_from(
             [
                 ShardingType.TABLE_WISE.value,
-                ShardingType.COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
                 ShardingType.TABLE_ROW_WISE.value,
-                ShardingType.TABLE_COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
             ]
         ),
         is_training=st.booleans(),
@@ -682,7 +702,8 @@ def _copy_ssd_emb_modules(
         sharding_type=st.sampled_from(
             [
                 ShardingType.TABLE_WISE.value,
-                ShardingType.COLUMN_WISE.value,
+                # TODO: uncomment when ssd ckpt support cw sharding
+                # ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
             ]
         ),
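
After the rewritten copy helper writes the same rows into both modules' SSD backends via wrapped.set_range, a check along the following lines (a sketch reusing the helper's emb_module1 and emb_module2) could confirm that the two modules now hold identical tables.

import torch


def assert_ssd_tables_equal(emb_module1, emb_module2) -> None:
    # Both iterators yield (table_name, PartiallyMaterializedTensor) pairs that
    # own a valid RocksDB snapshot, so full_tensor() reads are allowed.
    kv1 = dict(emb_module1.get_named_split_embedding_weights_snapshot())
    for name, pmt2 in emb_module2.get_named_split_embedding_weights_snapshot():
        assert torch.allclose(kv1[name].full_tensor(), pmt2.full_tensor()), (
            f"table {name} differs between the two modules"
        )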
