@@ -759,6 +759,13 @@ def __init__(
         self._feature_table_map: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
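The two new attributes are read from the group's `fused_params`, defaulting to `torch.int64` when the keys are absent. A minimal sketch of that lookup, assuming a plain dict stands in for `config.fused_params` (the `GroupedEmbeddingConfig` plumbing is omitted here):

```python
import torch
from typing import Any, Dict, Optional

# Hypothetical stand-in for config.fused_params; in the real module this dict
# comes from the sharding configuration.
fused_params: Optional[Dict[str, Any]] = {"embedding_table_index_type": torch.int32}

params: Dict[str, Any] = fused_params or {}
index_type: torch.dtype = params.get("embedding_table_index_type", torch.int64)
offset_type: torch.dtype = params.get("embedding_table_offset_type", torch.int64)

assert index_type == torch.int32   # explicitly requested
assert offset_type == torch.int64  # falls back to the int64 default
```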
@@ -800,8 +807,16 @@ def init_parameters(self) -> None:
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
+            indices=(
+                features.values()
+                if self._embedding_table_index_type == torch.int32
+                else features.values().long()
+            ),
+            offsets=(
+                features.offsets().type(dtype=features.values().dtype)
+                if self._embedding_table_offset_type == torch.int32
+                else features.offsets().long()
+            ),
         )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -1213,6 +1228,13 @@ def __init__(
         self._lengths_per_emb: List[int] = []
         self.table_name_to_count: Dict[str, int] = {}
         self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = {}
+        self._fused_params: Dict[str, Any] = config.fused_params or {}
+        self._embedding_table_index_type: torch.dtype = self._fused_params.get(
+            "embedding_table_index_type", torch.int64
+        )
+        self._embedding_table_offset_type: torch.dtype = self._fused_params.get(
+            "embedding_table_offset_type", torch.int64
+        )
 
         # pyre-fixme[9]: config has type `GroupedEmbeddingConfig`; used as
         #  `ShardedEmbeddingTable`.
@@ -1265,15 +1287,31 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             ),
         ):
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
                 batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
            )
         else:
             return self.emb_module(
-                indices=features.values().long(),
-                offsets=features.offsets().long(),
+                indices=(
+                    features.values()
+                    if self._embedding_table_index_type == torch.int32
+                    else features.values().long()
+                ),
+                offsets=(
+                    features.offsets().type(dtype=features.values().dtype)
+                    if self._embedding_table_offset_type == torch.int32
+                    else features.offsets().long()
+                ),
                 per_sample_weights=weights,
             )
 
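Across all three call sites the pattern is the same: the `.long()` upcast is skipped only when int32 is explicitly requested via `fused_params`, and offsets are then cast to the values' dtype so both tensors reach the embedding kernel with matching integer width. A standalone sketch of that selection logic, using a hypothetical helper name that is not part of the module:

```python
import torch
from typing import Tuple


def select_indices_and_offsets(
    values: torch.Tensor,
    offsets: torch.Tensor,
    index_type: torch.dtype = torch.int64,
    offset_type: torch.dtype = torch.int64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the conditionals added to forward(): int32 values pass through
    # untouched; anything else keeps the historical int64 upcast.
    indices = values if index_type == torch.int32 else values.long()
    new_offsets = (
        offsets.type(dtype=values.dtype)
        if offset_type == torch.int32
        else offsets.long()
    )
    return indices, new_offsets


values = torch.tensor([4, 8, 15], dtype=torch.int32)
offsets = torch.tensor([0, 1, 3], dtype=torch.int64)
idx, off = select_indices_and_offsets(values, offsets, torch.int32, torch.int32)
assert idx.dtype == torch.int32 and off.dtype == torch.int32
```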