
Commit d945603

emlin authored and facebook-github-bot committed
fix EBC optimizer size setting for virtual table (#3239)
Summary:
Pull Request resolved: #3239

Add the missing optimizer state processing logic, similar to what EC does here: https://fburl.com/code/hvzzs3t4, so that the optimizer state does not use the default metadata, which reflects the virtual table size rather than the actual tensor size.

Reviewed By: EddyLXJ

Differential Revision: D78950717

fbshipit-source-id: 45eca79bbff1fe498e2707b51a6845eb603bbdfd
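To make the size mismatch concrete, here is a small illustrative sketch that is not part of the commit; all names and numbers are hypothetical. A virtual table advertises a very large row space to the planner, so metadata derived from the sharding spec records that virtual size, while the tensor that is actually checkpointed, and its optimizer state, only spans the rows materialized in the kv-store backend.

```python
# Illustrative only: contrasts the "virtual" size the planner sees with the size
# of the tensor that is actually materialized and checkpointed. All values are
# hypothetical.
VIRTUAL_NUM_EMBEDDINGS = 2**50          # planner-visible virtual row space
EMBEDDING_DIM = 128
MATERIALIZED_ROWS_ON_THIS_RANK = 4_096  # rows the kv-store backend actually holds

# Metadata built from the sharding spec would claim the virtual size ...
spec_based_size = (VIRTUAL_NUM_EMBEDDINGS, EMBEDDING_DIM)

# ... while the local tensor (and its optimizer state) is only this large.
actual_local_size = (MATERIALIZED_ROWS_ON_THIS_RANK, EMBEDDING_DIM)

print(f"spec-based metadata size: {spec_based_size}")
print(f"actual local tensor size: {actual_local_size}")
```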
1 parent 69acf48 commit d945603

File tree

1 file changed (+56, -23 lines)

torchrec/distributed/embeddingbag.py

Lines changed: 56 additions & 23 deletions
@@ -1081,33 +1081,66 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no
             # note: at this point kvstore backed tensors don't own valid snapshots, so no read
             # access is allowed on them.
 
-            # create ShardedTensor from local shards and metadata avoding all_gather collective
-            sharding_spec = none_throws(
-                self.module_sharding_plan[table_name].sharding_spec
-            )
-
-            tensor_properties = TensorProperties(
-                dtype=(
-                    data_type_to_dtype(
-                        self._table_name_to_config[table_name].data_type
+            if self._table_name_to_config[table_name].use_virtual_table:
+                # virtual table size will be recalculated before checkpointing. Here we cannot
+                # use sharding spec to build tensor metadata which will exceed the checkpoint capacity limit
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        (
+                            [
+                                # assuming virtual table only supports rw sharding for now
+                                # When backend return whole row, need to respect dim(1)
+                                # otherwise will see shard dim exceeded tensor dim error
+                                (
+                                    0
+                                    if dim == 0
+                                    else (
+                                        local_shards[0].metadata.shard_sizes[1]
+                                        if dim == 1
+                                        else dim_size
+                                    )
+                                )
+                                for dim, dim_size in enumerate(
+                                    self._name_to_table_size[table_name]
+                                )
+                            ]
+                        ),
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
                     )
-                ),
-            )
+                )
+            else:
+                # create ShardedTensor from local shards and metadata avoding all_gather collective
+                sharding_spec = none_throws(
+                    self.module_sharding_plan[table_name].sharding_spec
+                )
 
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards_and_global_metadata(
-                    local_shards=local_shards,
-                    sharded_tensor_metadata=sharding_spec.build_metadata(
-                        tensor_sizes=self._name_to_table_size[table_name],
-                        tensor_properties=tensor_properties,
-                    ),
-                    process_group=(
-                        self._env.sharding_pg
-                        if isinstance(self._env, ShardingEnv2D)
-                        else self._env.process_group
+                tensor_properties = TensorProperties(
+                    dtype=(
+                        data_type_to_dtype(
+                            self._table_name_to_config[table_name].data_type
+                        )
                     ),
                 )
-            )
+
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards_and_global_metadata(
+                        local_shards=local_shards,
+                        sharded_tensor_metadata=sharding_spec.build_metadata(
+                            tensor_sizes=self._name_to_table_size[table_name],
+                            tensor_properties=tensor_properties,
+                        ),
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
 
         def extract_sharded_kvtensors(
             module: ShardedEmbeddingBagCollection,
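For reference, the standalone sketch below mirrors the size calculation the new virtual-table branch performs before calling ShardedTensor._init_from_local_shards; the helper name and example sizes are mine, not torchrec's. Dim 0 is zeroed because the virtual row count is recalculated before checkpointing, and dim 1 follows the local shard width so that backends returning whole rows do not trigger a shard-dim-exceeds-tensor-dim error.

```python
# Standalone sketch (hypothetical helper, not torchrec API): derive the global
# size passed to ShardedTensor._init_from_local_shards for a virtual table.
from typing import List, Sequence


def virtual_table_global_size(
    table_size: Sequence[int],         # planner's virtual table size, e.g. (2**50, 128)
    local_shard_sizes: Sequence[int],  # shard_sizes of the first local shard, e.g. (1024, 128)
) -> List[int]:
    # Assumes row-wise sharding, as the added code's comment does: dim 0 is
    # zeroed out (recomputed before checkpointing), dim 1 respects the local
    # shard width, and any further dims keep the planner's size.
    return [
        0 if dim == 0 else (local_shard_sizes[1] if dim == 1 else dim_size)
        for dim, dim_size in enumerate(table_size)
    ]


if __name__ == "__main__":
    print(virtual_table_global_size((2**50, 128), (1024, 128)))  # -> [0, 128]
```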
