Commit 75f1f1c

levythu authored and facebook-github-bot committed
Propagate proper embedding_shard_info when constructing TBE (#2876)
Summary:
Pull Request resolved: #2876

This passes the correct shard-context information to TBE's new parameter. It also renames the loop variable to avoid name shadowing and removes the now-unneeded pyre suppressions.

Reviewed By: ge0405

Differential Revision: D72421135

fbshipit-source-id: 08ed16cef699ad13d42cc34442cb579d8a110edc
1 parent 05aea06 commit 75f1f1c
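
To make the new argument concrete: the diff below collects, per embedding table, a (num_embeddings, embedding_dim, row_offset, col_offset) tuple and hands the resulting list to TBE as embedding_shard_info. The following is a minimal sketch of that bookkeeping, not the torchrec code itself; TableConfig and LocalShardMetadata are simplified stand-ins for ShardedEmbeddingTable and the shard metadata it carries in local_metadata.

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class LocalShardMetadata:
    # Simplified stand-in for the shard metadata attached to a sharded table:
    # shard_offsets[0] is the row offset, shard_offsets[1] the column offset.
    shard_offsets: List[int]


@dataclass
class TableConfig:
    # Simplified stand-in for ShardedEmbeddingTable.
    num_embeddings: int
    embedding_dim: int
    local_metadata: Optional[LocalShardMetadata] = None


def build_embedding_shard_info(
    tables: List[TableConfig],
) -> List[Tuple[int, int, int, int]]:
    # Mirrors the loop in the diff: one (num_embeddings, embedding_dim,
    # row_offset, col_offset) tuple per table, with offsets defaulting to 0
    # when no local shard metadata is available.
    info: List[Tuple[int, int, int, int]] = []
    for table in tables:
        offsets = table.local_metadata.shard_offsets if table.local_metadata else []
        row_offset = offsets[0] if len(offsets) > 0 else 0
        col_offset = offsets[1] if len(offsets) > 1 else 0
        info.append((table.num_embeddings, table.embedding_dim, row_offset, col_offset))
    return info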

1 file changed: torchrec/distributed/batched_embedding_kernel.py (+68, -42 lines)
@@ -764,32 +764,37 @@ def __init__(
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
+        self._embedding_dims: List[int] = []
         self._local_cols: List[int] = []
+        self._row_offset: List[int] = []
+        self._col_offset: List[int] = []
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

-        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
-        # `ShardedEmbeddingTable`.
-        for idx, config in enumerate(self._config.embedding_tables):
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
-            self._local_rows.append(config.local_rows)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            # `get_weight_init_min`.
-            self._weight_init_mins.append(config.get_weight_init_min())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            # `get_weight_init_max`.
-            self._weight_init_maxs.append(config.get_weight_init_max())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            # `num_embeddings`.
-            self._num_embeddings.append(config.num_embeddings)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
-            self._local_cols.append(config.local_cols)
-            self._feature_table_map.extend([idx] * config.num_features())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
-            if config.name not in self.table_name_to_count:
-                self.table_name_to_count[config.name] = 0
-            self.table_name_to_count[config.name] += 1
+        for idx, table_config in enumerate(self._config.embedding_tables):
+            self._local_rows.append(table_config.local_rows)
+            self._weight_init_mins.append(table_config.get_weight_init_min())
+            self._weight_init_maxs.append(table_config.get_weight_init_max())
+            self._num_embeddings.append(table_config.num_embeddings)
+            self._embedding_dims.append(table_config.embedding_dim)
+            self._row_offset.append(
+                table_config.local_metadata.shard_offsets[0]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 0
+                else 0
+            )
+            self._col_offset.append(
+                table_config.local_metadata.shard_offsets[1]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 1
+                else 0
+            )
+            self._local_cols.append(table_config.local_cols)
+            self._feature_table_map.extend([idx] * table_config.num_features())
+            if table_config.name not in self.table_name_to_count:
+                self.table_name_to_count[table_config.name] = 0
+            self.table_name_to_count[table_config.name] += 1

     def init_parameters(self) -> None:
         # initialize embedding weights
@@ -1080,6 +1085,14 @@ def __init__(
                 weights_precision=weights_precision,
                 device=device,
                 table_names=[t.name for t in config.embedding_tables],
+                embedding_shard_info=list(
+                    zip(
+                        self._num_embeddings,
+                        self._embedding_dims,
+                        self._row_offset,
+                        self._col_offset,
+                    )
+                ),
                 **fused_params,
             )
         )
@@ -1216,34 +1229,39 @@ def __init__(
         self._weight_init_mins: List[float] = []
         self._weight_init_maxs: List[float] = []
         self._num_embeddings: List[int] = []
+        self._embedding_dims: List[int] = []
         self._local_cols: List[int] = []
+        self._row_offset: List[int] = []
+        self._col_offset: List[int] = []
         self._feature_table_map: List[int] = []
         self._emb_names: List[str] = []
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}

-        # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
-        # `ShardedEmbeddingTable`.
-        for idx, config in enumerate(self._config.embedding_tables):
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
-            self._local_rows.append(config.local_rows)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            # `get_weight_init_min`.
-            self._weight_init_mins.append(config.get_weight_init_min())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            # `get_weight_init_max`.
-            self._weight_init_maxs.append(config.get_weight_init_max())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
-            # `num_embeddings`.
-            self._num_embeddings.append(config.num_embeddings)
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
-            self._local_cols.append(config.local_cols)
-            self._feature_table_map.extend([idx] * config.num_features())
-            # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
-            if config.name not in self.table_name_to_count:
-                self.table_name_to_count[config.name] = 0
-            self.table_name_to_count[config.name] += 1
+        for idx, table_config in enumerate(self._config.embedding_tables):
+            self._local_rows.append(table_config.local_rows)
+            self._weight_init_mins.append(table_config.get_weight_init_min())
+            self._weight_init_maxs.append(table_config.get_weight_init_max())
+            self._num_embeddings.append(table_config.num_embeddings)
+            self._embedding_dims.append(table_config.embedding_dim)
+            self._row_offset.append(
+                table_config.local_metadata.shard_offsets[0]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 0
+                else 0
+            )
+            self._col_offset.append(
+                table_config.local_metadata.shard_offsets[1]
+                if table_config.local_metadata
+                and len(table_config.local_metadata.shard_offsets) > 1
+                else 0
+            )
+            self._local_cols.append(table_config.local_cols)
+            self._feature_table_map.extend([idx] * table_config.num_features())
+            if table_config.name not in self.table_name_to_count:
+                self.table_name_to_count[table_config.name] = 0
+            self.table_name_to_count[table_config.name] += 1

     def init_parameters(self) -> None:
         # initialize embedding weights
@@ -1564,6 +1582,14 @@ def __init__(
                 weights_precision=weights_precision,
                 device=device,
                 table_names=[t.name for t in config.embedding_tables],
+                embedding_shard_info=list(
+                    zip(
+                        self._num_embeddings,
+                        self._embedding_dims,
+                        self._row_offset,
+                        self._col_offset,
+                    )
+                ),
                 **fused_params,
             )
         )
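
For illustration only (hypothetical table sizes and sharding), running the sketch from above on a 1,000,000 x 128 table whose local shard starts at row 500,000 under row-wise sharding, plus a table with no local shard metadata, produces the tuples that the diff zips into embedding_shard_info:

tables = [
    # Local shard of a row-wise sharded table starting at row 500_000 on this rank.
    TableConfig(
        num_embeddings=1_000_000,
        embedding_dim=128,
        local_metadata=LocalShardMetadata(shard_offsets=[500_000, 0]),
    ),
    # Table without local shard metadata: both offsets fall back to 0.
    TableConfig(num_embeddings=50_000, embedding_dim=64),
]

print(build_embedding_shard_info(tables))
# [(1000000, 128, 500000, 0), (50000, 64, 0, 0)]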
