@@ -764,32 +764,37 @@ def __init__(
764
764
self ._weight_init_mins : List [float ] = []
765
765
self ._weight_init_maxs : List [float ] = []
766
766
self ._num_embeddings : List [int ] = []
767
+ self ._embedding_dims : List [int ] = []
767
768
self ._local_cols : List [int ] = []
769
+ self ._row_offset : List [int ] = []
770
+ self ._col_offset : List [int ] = []
768
771
self ._feature_table_map : List [int ] = []
769
772
self .table_name_to_count : Dict [str , int ] = {}
770
773
self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
771
774
772
- # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
773
- # `ShardedEmbeddingTable`.
774
- for idx , config in enumerate (self ._config .embedding_tables ):
775
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
776
- self ._local_rows .append (config .local_rows )
777
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
778
- # `get_weight_init_min`.
779
- self ._weight_init_mins .append (config .get_weight_init_min ())
780
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
781
- # `get_weight_init_max`.
782
- self ._weight_init_maxs .append (config .get_weight_init_max ())
783
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
784
- # `num_embeddings`.
785
- self ._num_embeddings .append (config .num_embeddings )
786
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
787
- self ._local_cols .append (config .local_cols )
788
- self ._feature_table_map .extend ([idx ] * config .num_features ())
789
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
790
- if config .name not in self .table_name_to_count :
791
- self .table_name_to_count [config .name ] = 0
792
- self .table_name_to_count [config .name ] += 1
775
+ for idx , table_config in enumerate (self ._config .embedding_tables ):
776
+ self ._local_rows .append (table_config .local_rows )
777
+ self ._weight_init_mins .append (table_config .get_weight_init_min ())
778
+ self ._weight_init_maxs .append (table_config .get_weight_init_max ())
779
+ self ._num_embeddings .append (table_config .num_embeddings )
780
+ self ._embedding_dims .append (table_config .embedding_dim )
781
+ self ._row_offset .append (
782
+ table_config .local_metadata .shard_offsets [0 ]
783
+ if table_config .local_metadata
784
+ and len (table_config .local_metadata .shard_offsets ) > 0
785
+ else 0
786
+ )
787
+ self ._col_offset .append (
788
+ table_config .local_metadata .shard_offsets [1 ]
789
+ if table_config .local_metadata
790
+ and len (table_config .local_metadata .shard_offsets ) > 1
791
+ else 0
792
+ )
793
+ self ._local_cols .append (table_config .local_cols )
794
+ self ._feature_table_map .extend ([idx ] * table_config .num_features ())
795
+ if table_config .name not in self .table_name_to_count :
796
+ self .table_name_to_count [table_config .name ] = 0
797
+ self .table_name_to_count [table_config .name ] += 1
793
798
794
799
def init_parameters (self ) -> None :
795
800
# initialize embedding weights
@@ -1080,6 +1085,14 @@ def __init__(
1080
1085
weights_precision = weights_precision ,
1081
1086
device = device ,
1082
1087
table_names = [t .name for t in config .embedding_tables ],
1088
+ embedding_shard_info = list (
1089
+ zip (
1090
+ self ._num_embeddings ,
1091
+ self ._embedding_dims ,
1092
+ self ._row_offset ,
1093
+ self ._col_offset ,
1094
+ )
1095
+ ),
1083
1096
** fused_params ,
1084
1097
)
1085
1098
)
@@ -1216,34 +1229,39 @@ def __init__(
1216
1229
self ._weight_init_mins : List [float ] = []
1217
1230
self ._weight_init_maxs : List [float ] = []
1218
1231
self ._num_embeddings : List [int ] = []
1232
+ self ._embedding_dims : List [int ] = []
1219
1233
self ._local_cols : List [int ] = []
1234
+ self ._row_offset : List [int ] = []
1235
+ self ._col_offset : List [int ] = []
1220
1236
self ._feature_table_map : List [int ] = []
1221
1237
self ._emb_names : List [str ] = []
1222
1238
self ._lengths_per_emb : List [int ] = []
1223
1239
self .table_name_to_count : Dict [str , int ] = {}
1224
1240
self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
1225
1241
1226
- # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
1227
- # `ShardedEmbeddingTable`.
1228
- for idx , config in enumerate (self ._config .embedding_tables ):
1229
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_rows`.
1230
- self ._local_rows .append (config .local_rows )
1231
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1232
- # `get_weight_init_min`.
1233
- self ._weight_init_mins .append (config .get_weight_init_min ())
1234
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1235
- # `get_weight_init_max`.
1236
- self ._weight_init_maxs .append (config .get_weight_init_max ())
1237
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute
1238
- # `num_embeddings`.
1239
- self ._num_embeddings .append (config .num_embeddings )
1240
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `local_cols`.
1241
- self ._local_cols .append (config .local_cols )
1242
- self ._feature_table_map .extend ([idx ] * config .num_features ())
1243
- # pyre-fixme[16]: `GroupedEmbeddingConfig` has no attribute `name`.
1244
- if config .name not in self .table_name_to_count :
1245
- self .table_name_to_count [config .name ] = 0
1246
- self .table_name_to_count [config .name ] += 1
1242
+ for idx , table_config in enumerate (self ._config .embedding_tables ):
1243
+ self ._local_rows .append (table_config .local_rows )
1244
+ self ._weight_init_mins .append (table_config .get_weight_init_min ())
1245
+ self ._weight_init_maxs .append (table_config .get_weight_init_max ())
1246
+ self ._num_embeddings .append (table_config .num_embeddings )
1247
+ self ._embedding_dims .append (table_config .embedding_dim )
1248
+ self ._row_offset .append (
1249
+ table_config .local_metadata .shard_offsets [0 ]
1250
+ if table_config .local_metadata
1251
+ and len (table_config .local_metadata .shard_offsets ) > 0
1252
+ else 0
1253
+ )
1254
+ self ._col_offset .append (
1255
+ table_config .local_metadata .shard_offsets [1 ]
1256
+ if table_config .local_metadata
1257
+ and len (table_config .local_metadata .shard_offsets ) > 1
1258
+ else 0
1259
+ )
1260
+ self ._local_cols .append (table_config .local_cols )
1261
+ self ._feature_table_map .extend ([idx ] * table_config .num_features ())
1262
+ if table_config .name not in self .table_name_to_count :
1263
+ self .table_name_to_count [table_config .name ] = 0
1264
+ self .table_name_to_count [table_config .name ] += 1
1247
1265
1248
1266
def init_parameters (self ) -> None :
1249
1267
# initialize embedding weights
@@ -1564,6 +1582,14 @@ def __init__(
1564
1582
weights_precision = weights_precision ,
1565
1583
device = device ,
1566
1584
table_names = [t .name for t in config .embedding_tables ],
1585
+ embedding_shard_info = list (
1586
+ zip (
1587
+ self ._num_embeddings ,
1588
+ self ._embedding_dims ,
1589
+ self ._row_offset ,
1590
+ self ._col_offset ,
1591
+ )
1592
+ ),
1567
1593
** fused_params ,
1568
1594
)
1569
1595
)
0 commit comments