Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sending using lengths to TBE instead of just offsets #2557

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions torchrec/distributed/embedding_sharding.py
Original file line number Diff line number Diff line change
@@ -446,6 +446,15 @@ def _prefetch_and_cached(
)


def _all_tables_are_quant_kernel(
    tables: List[ShardedEmbeddingTable],
) -> bool:
    """
    Tell whether every table in `tables` uses the QUANT compute kernel.

    Args:
        tables: sharded embedding tables to inspect.

    Returns:
        False as soon as one table has a `compute_kernel` other than
        `EmbeddingComputeKernel.QUANT`; True otherwise (including when
        `tables` is empty).
    """
    for sharded_table in tables:
        if sharded_table.compute_kernel != EmbeddingComputeKernel.QUANT:
            return False
    return True


# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -489,6 +498,8 @@ def _group_tables_per_rank(
# Collect groups
groups = defaultdict(list)
grouping_keys = []
# Assumes all compute kernels within tables are the same
is_inference = _all_tables_are_quant_kernel(embedding_tables)
for table in embedding_tables:
bucketer = (
prefetch_cached_dim_bucketer
@@ -499,12 +510,16 @@ def _group_tables_per_rank(
_get_grouping_fused_params(table.fused_params, table.name) or {}
)
grouping_key = (
table.data_type,
table.data_type if not is_inference else None,
table.pooling,
table.has_feature_processor,
tuple(sorted(group_fused_params.items())),
_get_compute_kernel_type(table.compute_kernel),
bucketer.get_bucket(table.local_cols, table.data_type),
# TODO: Unit test to check if table.data_type affects table grouping
bucketer.get_bucket(
table.local_cols,
table.data_type,
),
_prefetch_and_cached(table),
)
# micromanage the order in which we traverse the groups to ensure backwards compatibility
20 changes: 19 additions & 1 deletion torchrec/distributed/fused_params.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional

import torch

@@ -24,6 +24,10 @@
FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode"

# Force lengths to offsets conversion before TBE lookup. Helps with performance
# with certain ways to split models.
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"


class TBEToRegisterMixIn:
def get_tbes_to_register(
@@ -68,6 +72,18 @@ def fused_param_bounds_check_mode(
return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE]


def fused_param_lengths_to_offsets_lookup(
    fused_params: Optional[Dict[str, Any]]
) -> bool:
    """
    Read the lengths-to-offsets lookup flag out of `fused_params`.

    Args:
        fused_params: optional fused-parameter dictionary.

    Returns:
        False when `fused_params` is None or does not contain
        `FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP`; otherwise the value stored
        under that key, returned as-is.
    """
    # Guard first so a missing dictionary never reaches the key lookup.
    if fused_params is None:
        return False
    return fused_params.get(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP, False)


def is_fused_param_quant_state_dict_split_scale_bias(
fused_params: Optional[Dict[str, Any]]
) -> bool:
@@ -93,5 +109,7 @@ def tbe_fused_params(
fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)

return fused_params_for_tbe
196 changes: 131 additions & 65 deletions torchrec/distributed/quant_embedding_kernel.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@
)
from torchrec.distributed.fused_params import (
fused_param_bounds_check_mode,
fused_param_lengths_to_offsets_lookup,
is_fused_param_quant_state_dict_split_scale_bias,
is_fused_param_register_tbe,
tbe_fused_params,
@@ -171,6 +172,19 @@ def _unwrap_kjt_for_cpu(
return indices, offsets, None


@torch.fx.wrap
def _unwrap_kjt_lengths(
    features: KeyedJaggedTensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Split a KeyedJaggedTensor into the pieces passed to a lengths-based lookup.

    Args:
        features: the jagged input to unwrap.

    Returns:
        A 3-tuple of `features.values()` converted with `.int()`,
        `features.lengths()` converted with `.int()`, and
        `features.weights_or_none()` passed through unchanged.
    """
    raw_values, raw_lengths = features.values(), features.lengths()
    unwrapped = (
        raw_values.int(),
        raw_lengths.int(),
        features.weights_or_none(),
    )
    return unwrapped


@torch.fx.wrap
def _unwrap_optional_tensor(
tensor: Optional[torch.Tensor],
@@ -180,6 +194,26 @@ def _unwrap_optional_tensor(
return tensor


class IntNBitTableBatchedEmbeddingBagsCodegenWithLength(
    IntNBitTableBatchedEmbeddingBagsCodegen
):
    """
    Variant of `IntNBitTableBatchedEmbeddingBagsCodegen` whose `forward`
    accepts lengths rather than offsets.

    The lengths are passed through
    `torch.ops.fbgemm.asynchronous_complete_cumsum` and the result is handed
    to the parent `forward` in the offsets position.
    """

    # `**kwargs` is annotated with the type of each individual keyword value,
    # so `Any` (not `Dict[str, Any]`) is the correct annotation here.
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    # pyre-ignore Inconsistent override [14]
    def forward(
        self,
        indices: torch.Tensor,
        lengths: torch.Tensor,
        per_sample_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the parent lookup after converting `lengths` to offsets.

        Args:
            indices: lookup indices, forwarded unchanged.
            lengths: lengths tensor, converted with
                `torch.ops.fbgemm.asynchronous_complete_cumsum`.
            per_sample_weights: optional weights, forwarded unchanged.

        Returns:
            Whatever the parent class's `forward` returns.
        """
        offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
        # Arguments stay positional, exactly as the parent forward is invoked
        # elsewhere with (indices, offsets, per_sample_weights).
        return super().forward(indices, offsets, per_sample_weights)


class QuantBatchedEmbeddingBag(
BaseBatchedEmbeddingBag[
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
@@ -192,6 +226,7 @@ def __init__(
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
data_type_changed: bool = False,
) -> None:
super().__init__(config, pg, device)

@@ -216,40 +251,53 @@ def __init__(
self._runtime_device: torch.device = _get_runtime_device(device, config)
# 16 for CUDA, 1 for others like CPU and MTIA.
self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
embedding_specs = []
for local_rows, local_cols, table, location in zip(
self._local_rows,
self._local_cols,
config.embedding_tables,
managed,
):
embedding_specs.append(
(
table.name,
local_rows,
(
table.name,
local_rows,
(
local_cols
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(config.data_type),
location,
)
for local_rows, local_cols, table, location in zip(
self._local_rows,
self._local_cols,
config.embedding_tables,
managed,
)
],
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
local_cols
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(
# if data_type has changed, we want to default to the up-to-date config.data_type, instead of the embedding_tables which does not have the quantized data type
config.data_type
if data_type_changed
else table.data_type
),
location,
)
)

self.lengths_to_tbe: bool = fused_param_lengths_to_offsets_lookup(fused_params)

if self.lengths_to_tbe:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
else:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
embedding_specs=embedding_specs,
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
)
if device is not None:
self._emb_module.initialize_weights()
@@ -268,44 +316,50 @@ def get_tbes_to_register(
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
return {self._emb_module: self._config}

def _emb_module_forward(
self,
indices: torch.Tensor,
lengths_or_offsets: torch.Tensor,
weights: Optional[torch.Tensor],
) -> torch.Tensor:
kwargs = {"indices": indices}

if self._is_weighted:
kwargs["per_sample_weights"] = _unwrap_optional_tensor(weights)

if self.lengths_to_tbe:
kwargs["lengths"] = lengths_or_offsets
else:
kwargs["offsets"] = lengths_or_offsets

if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
return self._emb_module(**kwargs)
else:
return self._emb_module.forward(**kwargs)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# Important: _unwrap_kjt regex for FX tracing TAGing
lengths, offsets = None, None
if self._runtime_device.type == "cpu":
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
else:
indices, offsets, per_sample_weights = _unwrap_kjt(features)

if self._is_weighted:
weights = _unwrap_optional_tensor(per_sample_weights)
if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
return self.emb_module(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
)
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
else:
if self._emb_module_registered:
return self.emb_module(
indices=indices,
offsets=offsets,
)
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
)
indices, offsets, per_sample_weights = _unwrap_kjt(features)

return self._emb_module_forward(
indices, lengths if lengths is not None else offsets, per_sample_weights
)

def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -359,8 +413,16 @@ def from_float(
)
device = next(iter(state_dict.values())).device

data_type_changed = False
# qconfig data type can be different from the config data type - if we are quantizing already sharded embeddings.
# This means the embedding_tables within GroupedEmbeddingConfig do not have the up-to-date data type - as they have not yet been quantized
if data_type != module.config.data_type:
data_type_changed = True
# We update the config to have the right data_type, sparse_type, and device. This update does not change the embedding_tables data type
config = _copy_config(module.config, data_type, sparse_type, device)
ret = QuantBatchedEmbeddingBag(config=config, device=device)
ret = QuantBatchedEmbeddingBag(
config=config, device=device, data_type_changed=data_type_changed
)

# pyre-ignore
quant_weight_list = _quantize_weight(state_dict, data_type)
@@ -411,7 +473,11 @@ def __init__(
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(config.data_type),
(
data_type_to_sparse_type(config.data_type)
if config.data_type is not None
else data_type_to_sparse_type(table.data_type)
),
location,
)
for local_rows, local_cols, table, location in zip(
15 changes: 15 additions & 0 deletions torchrec/distributed/tests/test_embedding_sharding.py
Original file line number Diff line number Diff line change
@@ -429,6 +429,21 @@ def test_should_not_group_together(
)
return

# If both kernels are quantized, we assume this is inference which we no longer split by data_type
# So if other attributes are the same between the two tables (regardless of data type), we combine them
if (
tables[0].compute_kernel == EmbeddingComputeKernel.QUANT
and tables[1].compute_kernel == EmbeddingComputeKernel.QUANT
and tables[0].pooling == tables[1].pooling
and tables[0].has_feature_processor == tables[1].has_feature_processor
):

self.assertEqual(
sorted(_get_table_names_by_groups(tables)),
[["table_0", "table_1"]],
)
return

self.assertEqual(
sorted(_get_table_names_by_groups(tables)),
[["table_0"], ["table_1"]],
2 changes: 2 additions & 0 deletions torchrec/inference/modules.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fused_params import (
FUSED_PARAM_BOUNDS_CHECK_MODE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
)
@@ -82,6 +83,7 @@ def trim_torch_package_prefix_from_typename(typename: str) -> str:
FUSED_PARAM_REGISTER_TBE_BOOL: True,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False,
}

DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [
81 changes: 81 additions & 0 deletions torchrec/inference/tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -215,3 +215,84 @@ def test_quantize_per_table_dtype(self) -> None:
# 3 TBES (1 FPEBC, 2 EBCs (1 weighted, 1 unweighted))

self.assertEqual(num_tbes, 3)

def test_sharded_quantized_tbe_count(self) -> None:
set_propogate_device(True)

model = TestSparseNN(
tables=self.tables,
weighted_tables=self.weighted_tables,
num_float_features=10,
dense_device=torch.device("cpu"),
sparse_device=torch.device("cpu"),
over_arch_clazz=TestOverArchRegroupModule,
)

per_table_weight_dtypes = {}

for table in self.tables + self.weighted_tables:
# quint4x2 different than int8, which is default
per_table_weight_dtypes[table.name] = torch.quint4x2 if table.name == "table_0" else torch.quint8

model.eval()
_, local_batch = ModelInput.generate(
batch_size=16,
world_size=1,
num_float_features=10,
tables=self.tables,
weighted_tables=self.weighted_tables,
)

# with torch.inference_mode(): # TODO: Why does inference mode fail when using different quant data types
output = model(local_batch[0])

# Quantize the model and collect quantized weights
quantized_model = quantize_inference_model(model, per_table_weight_dtype=per_table_weight_dtypes)
quantized_output = quantized_model(local_batch[0])
table_to_weight = get_table_to_weights_from_tbe(quantized_model)

# Shard the model, all weights are initialized back to 0, so have to reassign weights
sharded_quant_model, _ = shard_quant_model(
quantized_model,
world_size=1,
compute_device="cpu",
sharding_device="cpu",
)
assign_weights_to_tbe(quantized_model, table_to_weight)
sharded_quant_output = sharded_quant_model(local_batch[0])

# When world_size = 1, we should have 1 TBE per sharded, quantized ebc
self.assertTrue(len(sharded_quant_model.sparse.ebc.tbes) == 1)
self.assertTrue(len(sharded_quant_model.sparse.weighted_ebc.tbes) == 1)

# Check the weights are close
self.assertTrue(torch.allclose(output, quantized_output, atol=1e-3))
self.assertTrue(torch.allclose(output, sharded_quant_output, atol=1e-3))

# Check the sizes are correct
expected_num_embeddings = {}

for table in self.tables:
expected_num_embeddings[table.name] = table.num_embeddings

for module in quantized_model.modules():
if module.__class__.__name__ == "IntNBitTableBatchedEmbeddingBagsCodegen":
for i, spec in enumerate(module.embedding_specs):
if spec[0] in expected_num_embeddings:
# We only expect the first table to be quantized to int4 due to test set up
if spec[0] == "table_0":
self.assertEqual(spec[3], SparseType.INT4)
else:
self.assertEqual(spec[3], SparseType.INT8)

# Check sizes are equal
self.assertEqual(
module.split_embedding_weights()[i][0].size(0),
expected_num_embeddings[spec[0]],
)
self.assertEqual(
spec[1],
expected_num_embeddings[spec[0]],
)