@@ -764,6 +764,13 @@ def __init__(
764
764
self ._feature_table_map : List [int ] = []
765
765
self .table_name_to_count : Dict [str , int ] = {}
766
766
self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
767
+ self ._fused_params : Dict [str , Any ] = config .fused_params or {}
768
+ self ._embedding_table_index_type : torch .dtype = self ._fused_params .get (
769
+ "embedding_table_index_type" , torch .int64
770
+ )
771
+ self ._embedding_table_offset_type : torch .dtype = self ._fused_params .get (
772
+ "embedding_table_offset_type" , torch .int64
773
+ )
767
774
768
775
# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
769
776
# `ShardedEmbeddingTable`.
@@ -805,8 +812,16 @@ def init_parameters(self) -> None:
805
812
806
813
def forward (self , features : KeyedJaggedTensor ) -> torch .Tensor :
807
814
return self .emb_module (
808
- indices = features .values ().long (),
809
- offsets = features .offsets ().long (),
815
+ indices = (
816
+ features .values ()
817
+ if self ._embedding_table_index_type == torch .int32
818
+ else features .values ().long ()
819
+ ),
820
+ offsets = (
821
+ features .offsets ().type (dtype = features .values ().dtype )
822
+ if self ._embedding_table_offset_type == torch .int32
823
+ else features .offsets ().long ()
824
+ ),
810
825
)
811
826
812
827
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1218,6 +1233,13 @@ def __init__(
1218
1233
self ._lengths_per_emb : List [int ] = []
1219
1234
self .table_name_to_count : Dict [str , int ] = {}
1220
1235
self ._param_per_table : Dict [str , TableBatchedEmbeddingSlice ] = {}
1236
+ self ._fused_params : Dict [str , Any ] = config .fused_params or {}
1237
+ self ._embedding_table_index_type : torch .dtype = self ._fused_params .get (
1238
+ "embedding_table_index_type" , torch .int64
1239
+ )
1240
+ self ._embedding_table_offset_type : torch .dtype = self ._fused_params .get (
1241
+ "embedding_table_offset_type" , torch .int64
1242
+ )
1221
1243
1222
1244
# pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
1223
1245
# `ShardedEmbeddingTable`.
@@ -1270,15 +1292,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
1270
1292
),
1271
1293
):
1272
1294
return self .emb_module (
1273
- indices = features .values ().long (),
1274
- offsets = features .offsets ().long (),
1295
+ indices = (
1296
+ features .values ()
1297
+ if self ._embedding_table_index_type == torch .int32
1298
+ else features .values ().long ()
1299
+ ),
1300
+ offsets = (
1301
+ features .offsets ().type (dtype = features .values ().dtype )
1302
+ if self ._embedding_table_offset_type == torch .int32
1303
+ else features .offsets ().long ()
1304
+ ),
1275
1305
per_sample_weights = weights ,
1276
1306
batch_size_per_feature_per_rank = features .stride_per_key_per_rank (),
1277
1307
)
1278
1308
else :
1279
1309
return self .emb_module (
1280
- indices = features .values ().long (),
1281
- offsets = features .offsets ().long (),
1310
+ indices = (
1311
+ features .values ()
1312
+ if self ._embedding_table_index_type == torch .int32
1313
+ else features .values ().long ()
1314
+ ),
1315
+ offsets = (
1316
+ features .offsets ().type (dtype = features .values ().dtype )
1317
+ if self ._embedding_table_offset_type == torch .int32
1318
+ else features .offsets ().long ()
1319
+ ),
1282
1320
per_sample_weights = weights ,
1283
1321
)
1284
1322
0 commit comments