@@ -764,32 +764,37 @@ def __init__(
764764 self ._weight_init_mins : List [float ] = []
765765 self ._weight_init_maxs : List [float ] = []
766766 self ._num_embeddings : List [int ] = []
767+ self ._embedding_dims : List [int ] = []
767768 self ._local_cols : List [int ] = []
769+ self ._row_offset : List [int ] = []
770+ self ._col_offset : List [int ] = []
768771 self ._feature_table_map : List [int ] = []
769772 self .table_name_to_count : Dict [str , int ] = {}
770773 self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
771774
772- # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
773- # `ShardedEmbeddingTable`.
774- for idx , config in enumerate (self ._config .embedding_tables ):
775- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
776- self ._local_rows .append (config .local_rows )
777- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
778- # `get_weight_init_min`.
779- self ._weight_init_mins .append (config .get_weight_init_min ())
780- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
781- # `get_weight_init_max`.
782- self ._weight_init_maxs .append (config .get_weight_init_max ())
783- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
784- # `num_embeddings`.
785- self ._num_embeddings .append (config .num_embeddings )
786- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
787- self ._local_cols .append (config .local_cols )
788- self ._feature_table_map .extend ([idx ] * config .num_features ())
789- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
790- if config .name not in self .table_name_to_count :
791- self .table_name_to_count [config .name ] = 0
792- self .table_name_to_count [config .name ] += 1
775+ for idx , table_config in enumerate (self ._config .embedding_tables ):
776+ self ._local_rows .append (table_config .local_rows )
777+ self ._weight_init_mins .append (table_config .get_weight_init_min ())
778+ self ._weight_init_maxs .append (table_config .get_weight_init_max ())
779+ self ._num_embeddings .append (table_config .num_embeddings )
780+ self ._embedding_dims .append (table_config .embedding_dim )
781+ self ._row_offset .append (
782+ table_config .local_metadata .shard_offsets [0 ]
783+ if table_config .local_metadata
784+ and len (table_config .local_metadata .shard_offsets ) > 0
785+ else 0
786+ )
787+ self ._col_offset .append (
788+ table_config .local_metadata .shard_offsets [1 ]
789+ if table_config .local_metadata
790+ and len (table_config .local_metadata .shard_offsets ) > 1
791+ else 0
792+ )
793+ self ._local_cols .append (table_config .local_cols )
794+ self ._feature_table_map .extend ([idx ] * table_config .num_features ())
795+ if table_config .name not in self .table_name_to_count :
796+ self .table_name_to_count [table_config .name ] = 0
797+ self .table_name_to_count [table_config .name ] += 1
793798
794799 def init_parameters (self ) -> None :
795800 # initialize embedding weights
@@ -1080,6 +1085,14 @@ def __init__(
10801085 weights_precision = weights_precision ,
10811086 device = device ,
10821087 table_names = [t .name for t in config .embedding_tables ],
1088+ embedding_shard_info = list (
1089+ zip (
1090+ self ._num_embeddings ,
1091+ self ._embedding_dims ,
1092+ self ._row_offset ,
1093+ self ._col_offset ,
1094+ )
1095+ ),
10831096 ** fused_params ,
10841097 )
10851098 )
@@ -1216,34 +1229,39 @@ def __init__(
12161229 self ._weight_init_mins : List [float ] = []
12171230 self ._weight_init_maxs : List [float ] = []
12181231 self ._num_embeddings : List [int ] = []
1232+ self ._embedding_dims : List [int ] = []
12191233 self ._local_cols : List [int ] = []
1234+ self ._row_offset : List [int ] = []
1235+ self ._col_offset : List [int ] = []
12201236 self ._feature_table_map : List [int ] = []
12211237 self ._emb_names : List [str ] = []
12221238 self ._lengths_per_emb : List [int ] = []
12231239 self .table_name_to_count : Dict [str , int ] = {}
12241240 self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
12251241
1226- # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
1227- # `ShardedEmbeddingTable`.
1228- for idx , config in enumerate (self ._config .embedding_tables ):
1229- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
1230- self ._local_rows .append (config .local_rows )
1231- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1232- # `get_weight_init_min`.
1233- self ._weight_init_mins .append (config .get_weight_init_min ())
1234- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1235- # `get_weight_init_max`.
1236- self ._weight_init_maxs .append (config .get_weight_init_max ())
1237- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1238- # `num_embeddings`.
1239- self ._num_embeddings .append (config .num_embeddings )
1240- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
1241- self ._local_cols .append (config .local_cols )
1242- self ._feature_table_map .extend ([idx ] * config .num_features ())
1243- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
1244- if config .name not in self .table_name_to_count :
1245- self .table_name_to_count [config .name ] = 0
1246- self .table_name_to_count [config .name ] += 1
1242+ for idx , table_config in enumerate (self ._config .embedding_tables ):
1243+ self ._local_rows .append (table_config .local_rows )
1244+ self ._weight_init_mins .append (table_config .get_weight_init_min ())
1245+ self ._weight_init_maxs .append (table_config .get_weight_init_max ())
1246+ self ._num_embeddings .append (table_config .num_embeddings )
1247+ self ._embedding_dims .append (table_config .embedding_dim )
1248+ self ._row_offset .append (
1249+ table_config .local_metadata .shard_offsets [0 ]
1250+ if table_config .local_metadata
1251+ and len (table_config .local_metadata .shard_offsets ) > 0
1252+ else 0
1253+ )
1254+ self ._col_offset .append (
1255+ table_config .local_metadata .shard_offsets [1 ]
1256+ if table_config .local_metadata
1257+ and len (table_config .local_metadata .shard_offsets ) > 1
1258+ else 0
1259+ )
1260+ self ._local_cols .append (table_config .local_cols )
1261+ self ._feature_table_map .extend ([idx ] * table_config .num_features ())
1262+ if table_config .name not in self .table_name_to_count :
1263+ self .table_name_to_count [table_config .name ] = 0
1264+ self .table_name_to_count [table_config .name ] += 1
12471265
12481266 def init_parameters (self ) -> None :
12491267 # initialize embedding weights
@@ -1564,6 +1582,14 @@ def __init__(
15641582 weights_precision = weights_precision ,
15651583 device = device ,
15661584 table_names = [t .name for t in config .embedding_tables ],
1585+ embedding_shard_info = list (
1586+ zip (
1587+ self ._num_embeddings ,
1588+ self ._embedding_dims ,
1589+ self ._row_offset ,
1590+ self ._col_offset ,
1591+ )
1592+ ),
15671593 ** fused_params ,
15681594 )
15691595 )
0 commit comments