
Commit 4b5a8a3

aporialia authored and facebook-github-bot committed
Remove QuantizedEBC grouping by data type (#2571)
Summary:
Pull Request resolved: #2571

FBGEMM supports initializing a TBE with a list of data types, so we no longer have to split tables whose quantization data types differ into separate TBEs. This was already implemented for the ShardedQuantizedEBC (see the base diff), which helped bring TorchRec eager-mode inference to on-par QPS with non-eager-mode inference. This diff applies the same grouping to the QuantizedEBC for consistency, even though it is not strictly needed there.

Reviewed By: PaulZhang12

Differential Revision: D63861064

fbshipit-source-id: 376adbce6cde5d9c45def0836c78c1455993d419
1 parent 50889bd commit 4b5a8a3
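
The core of the change is the grouping key: tables are now bucketed by pooling type alone, and each table's data type travels with its spec into the TBE. Below is a minimal sketch of that rule using made-up table configs; import paths may vary across TorchRec versions.

from collections import defaultdict

from torchrec.modules.embedding_configs import (
    DataType,
    EmbeddingBagConfig,
    PoolingType,
    data_type_to_sparse_type,
)

# Two hypothetical tables that share a pooling type but use
# different quantization data types.
tables = [
    EmbeddingBagConfig(
        num_embeddings=100, embedding_dim=16, name="t0", feature_names=["f0"],
        pooling=PoolingType.SUM, data_type=DataType.INT8,
    ),
    EmbeddingBagConfig(
        num_embeddings=100, embedding_dim=16, name="t1", feature_names=["f1"],
        pooling=PoolingType.SUM, data_type=DataType.INT4,
    ),
]

# Old key: (table.pooling, table.data_type) -> two groups, hence two TBEs.
# New key: table.pooling alone -> one group; the data type stays per-table.
key_to_tables = defaultdict(list)
for table in tables:
    key_to_tables[table.pooling].append(table)

for pooling, group in key_to_tables.items():
    specs = [
        (t.name, t.num_embeddings, t.embedding_dim,
         data_type_to_sparse_type(t.data_type))
        for t in group
    ]
    print(pooling, specs)  # one group despite the INT8/INT4 mix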

File tree

2 files changed: +38 −5 lines


torchrec/inference/tests/test_inference.py

Lines changed: 33 additions & 0 deletions
@@ -13,6 +13,7 @@
 
 import torch
 from fbgemm_gpu.split_embedding_configs import SparseType
+from torchrec import PoolingType
 from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
 from torchrec.distributed.global_settings import set_propogate_device
 from torchrec.distributed.test_utils.test_model import (
@@ -298,3 +299,35 @@ def test_sharded_quantized_tbe_count(self) -> None:
                 spec[1],
                 expected_num_embeddings[spec[0]],
             )
+
+    def test_quantized_tbe_count_different_pooling(self) -> None:
+        set_propogate_device(True)
+
+        self.tables[0].pooling = PoolingType.MEAN
+        model = TestSparseNN(
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            num_float_features=10,
+            dense_device=torch.device("cpu"),
+            sparse_device=torch.device("cpu"),
+            over_arch_clazz=TestOverArchRegroupModule,
+        )
+
+        model.eval()
+        _, local_batch = ModelInput.generate(
+            batch_size=16,
+            world_size=1,
+            num_float_features=10,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+        )
+
+        model(local_batch[0])
+
+        # Quantize the model and collect quantized weights
+        quantized_model = quantize_inference_model(model)
+        # We should have 2 TBEs for unweighted ebc as the 2 tables here have different pooling types
+        self.assertTrue(len(quantized_model.sparse.ebc.tbes) == 2)
+        self.assertTrue(len(quantized_model.sparse.weighted_ebc.tbes) == 1)
+        # Changing this back
+        self.tables[0].pooling = PoolingType.SUM
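
To run just the new test in isolation (assuming a standard pytest setup for the repo):

python -m pytest torchrec/inference/tests/test_inference.py -k test_quantized_tbe_count_different_pooling

Note what the assertions check: the two unweighted tables still produce two TBEs because their pooling types differ. Pooling mode remains a per-module argument in FBGEMM, whereas the data type is now carried per-table.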

torchrec/quant/embedding_modules.py

Lines changed: 5 additions & 5 deletions
@@ -382,15 +382,14 @@ def __init__(
             if table.name in table_names:
                 raise ValueError(f"Duplicate table name {table.name}")
             table_names.add(table.name)
-            key = (table.pooling, table.data_type)
-            self._key_to_tables[key].append(table)
+            # pyre-ignore
+            self._key_to_tables[table.pooling].append(table)
 
         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
         )
 
-        for key, emb_configs in self._key_to_tables.items():
-            (pooling, data_type) = key
+        for pooling, emb_configs in self._key_to_tables.items():
             embedding_specs = []
             weight_lists: Optional[
                 List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -409,7 +408,7 @@ def __init__(
                             else table.num_embeddings
                         ),
                         table.embedding_dim,
-                        data_type_to_sparse_type(data_type),
+                        data_type_to_sparse_type(table.data_type),
                         location,
                     )
                 )
@@ -421,6 +420,7 @@ def __init__(
 
             emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=embedding_specs,
+                # pyre-ignore
                 pooling_mode=pooling_type_to_pooling_mode(pooling),
                 weight_lists=weight_lists,
                 device=device,
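
With grouping keyed on pooling alone, the loop above hands FBGEMM a single embedding_specs list in which the SparseType is resolved per table. A hedged sketch of the resulting mixed-precision TBE (the table names are invented, and the fbgemm_gpu import paths differ between releases):

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
    EmbeddingLocation,
    PoolingMode,
)
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

# One TBE holding an INT8 table and an INT4 table, both SUM-pooled on CPU.
emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        ("t0", 100, 16, SparseType.INT8, EmbeddingLocation.HOST),
        ("t1", 100, 16, SparseType.INT4, EmbeddingLocation.HOST),
    ],
    pooling_mode=PoolingMode.SUM,  # pooling is still one value per TBE
    device=torch.device("cpu"),
)
emb_module.fill_random_weights()  # allocate quantized weight buffers for the sketch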
