diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 04afb8fd9..38bb0dd4b 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -132,6 +132,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( bucketize_pos: bool = False, block_bucketize_pos: Optional[List[torch.Tensor]] = None, total_num_blocks: Optional[torch.Tensor] = None, + keep_original_indices: bool = False, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -159,6 +160,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( max_B=_fx_wrap_max_B(kjt), block_bucketize_pos=block_bucketize_pos, return_bucket_mapping=True, + keep_orig_idx=keep_original_indices, ) return ( @@ -305,6 +307,7 @@ def bucketize_kjt_inference( bucketize_pos: bool = False, block_bucketize_row_pos: Optional[List[torch.Tensor]] = None, is_sequence: bool = False, + keep_original_indices: bool = False, ) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets, @@ -352,6 +355,7 @@ def bucketize_kjt_inference( total_num_blocks=total_num_buckets_new_type, bucketize_pos=bucketize_pos, block_bucketize_pos=block_bucketize_row_pos, + keep_original_indices=keep_original_indices, ) else: ( diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index b85e6f9c3..63dbb7a13 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -1171,6 +1171,7 @@ def _create_input_dists( has_feature_processor=sharding._has_feature_processor, need_pos=False, embedding_shard_metadata=emb_sharding, + keep_original_indices=True, ) self._input_dists.append(input_dist) diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index deac8359b..f61ea0bd8 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -649,10 +649,12 @@ def __init__( has_feature_processor: bool = False, need_pos: bool = False, embedding_shard_metadata: Optional[List[List[int]]] = None, + keep_original_indices: bool = False, ) -> None: super().__init__() logger.info( f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}" + f", keep_original_indices={keep_original_indices}" ) self._world_size: int = world_size self._num_features = num_features @@ -683,6 +685,7 @@ def __init__( self._embedding_shard_metadata: Optional[List[List[int]]] = ( embedding_shard_metadata ) + self._keep_original_indices = keep_original_indices def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device( @@ -717,6 +720,7 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: block_bucketize_row_pos ), is_sequence=self._is_sequence, + keep_original_indices=self._keep_original_indices, ) # KJTOneToAll dist_kjt = self._dist.forward(bucketized_features) diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index ce30e3026..a971019ce 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -236,11 +236,6 @@ def validate_batch_size_stages( if len(batch_size_stages) == 0: raise ValueError("Batch size stages should not be empty") - for i in range(len(batch_size_stages) - 1): - if batch_size_stages[i].batch_size >= batch_size_stages[i + 1].batch_size: - raise ValueError( - f"Batch size should be in ascending order. Got {batch_size_stages}" - ) if batch_size_stages[-1].max_iters is not None: raise ValueError( f"Batch size stages last stage should have max_iters = None, but get {batch_size_stages[-1].max_iters}" diff --git a/torchrec/metrics/segmented_ne.py b/torchrec/metrics/segmented_ne.py index cf0f01a3d..e7fc3c7d6 100644 --- a/torchrec/metrics/segmented_ne.py +++ b/torchrec/metrics/segmented_ne.py @@ -165,6 +165,9 @@ class SegmentedNEMetricComputation(RecMetricComputation): Args: include_logloss (bool): return vanilla logloss as one of metrics results, on top of segmented NE. + num_groups (int): number of groups to segment NE by. + grouping_keys (str): name of the tensor containing the label by which results will be segmented. This tensor should be of type torch.int64. + cast_keys_to_int (bool): whether to cast grouping_keys to torch.int64. Only works if grouping_keys is of type torch.float32. """ def __init__( @@ -172,11 +175,15 @@ def __init__( *args: Any, include_logloss: bool = False, # TODO - include num_groups: int = 1, + grouping_keys: str = "grouping_keys", + cast_keys_to_int: bool = False, **kwargs: Any, ) -> None: self._include_logloss: bool = include_logloss super().__init__(*args, **kwargs) self._num_groups = num_groups # would there be checkpointing issues with this? maybe make this state + self._grouping_keys = grouping_keys + self._cast_keys_to_int = cast_keys_to_int self._add_state( "cross_entropy_sum", torch.zeros((self._n_tasks, num_groups), dtype=torch.double), @@ -217,21 +224,30 @@ def update( ) -> None: if predictions is None or weights is None: raise RecMetricException( - "Inputs 'predictions' and 'weights' and 'grouping_keys' should not be None for NEMetricComputation update" + f"Inputs 'predictions' and 'weights' and '{self._grouping_keys}' should not be None for NEMetricComputation update" ) elif ( "required_inputs" not in kwargs - or kwargs["required_inputs"].get("grouping_keys") is None + or kwargs["required_inputs"].get(self._grouping_keys) is None ): raise RecMetricException( - f"Required inputs for SegmentedNEMetricComputation update should contain 'grouping_keys', got kwargs: {kwargs}" - ) - elif kwargs["required_inputs"]["grouping_keys"].dtype != torch.int64: - raise RecMetricException( - f"Grouping keys must have type torch.int64, got {kwargs['required_inputs']['grouping_keys'].dtype}." + f"Required inputs for SegmentedNEMetricComputation update should contain {self._grouping_keys}, got kwargs: {kwargs}" ) + elif kwargs["required_inputs"][self._grouping_keys].dtype != torch.int64: + if ( + self._cast_keys_to_int + and kwargs["required_inputs"][self._grouping_keys].dtype + == torch.float32 + ): + kwargs["required_inputs"][self._grouping_keys] = kwargs[ + "required_inputs" + ][self._grouping_keys].to(torch.int64) + else: + raise RecMetricException( + f"Grouping keys expected to have type torch.int64 or torch.float32 with cast_keys_to_int set to true, got {kwargs['required_inputs'][self._grouping_keys].dtype}." + ) - grouping_keys = kwargs["required_inputs"]["grouping_keys"] + grouping_keys = kwargs["required_inputs"][self._grouping_keys] states = get_segemented_ne_states( labels, predictions, @@ -325,4 +341,8 @@ def __init__( process_group=process_group, **kwargs, ) - self._required_inputs.add("grouping_keys") + if "grouping_keys" not in kwargs: + self._required_inputs.add("grouping_keys") + else: + # pyre-ignore[6] + self._required_inputs.add(kwargs["grouping_keys"]) diff --git a/torchrec/metrics/tests/test_segmented_ne.py b/torchrec/metrics/tests/test_segmented_ne.py index 91a70d9e9..507a7cc8f 100644 --- a/torchrec/metrics/tests/test_segmented_ne.py +++ b/torchrec/metrics/tests/test_segmented_ne.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import Dict, Iterable, Union +from typing import Any, Dict, Iterable, Union import torch from torch import no_grad @@ -31,6 +31,8 @@ def _test_segemented_ne_helper( weights: torch.Tensor, expected_ne: torch.Tensor, grouping_keys: torch.Tensor, + grouping_key_tensor_name: str = "grouping_keys", + cast_keys_to_int: bool = False, ) -> None: num_task = labels.shape[0] batch_size = labels.shape[0] @@ -41,7 +43,7 @@ def _test_segemented_ne_helper( "weights": {}, } if grouping_keys is not None: - inputs["required_inputs"] = {"grouping_keys": grouping_keys} + inputs["required_inputs"] = {grouping_key_tensor_name: grouping_keys} for i in range(num_task): task_info = RecTaskInfo( name=f"Task:{i}", @@ -64,6 +66,10 @@ def _test_segemented_ne_helper( tasks=task_list, # pyre-ignore num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1), + # pyre-ignore + grouping_keys=grouping_key_tensor_name, + # pyre-ignore + cast_keys_to_int=cast_keys_to_int, ) ne.update(**inputs) actual_ne = ne.compute() @@ -95,7 +101,7 @@ def test_grouped_ne(self) -> None: raise -def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: +def generate_model_outputs_cases() -> Iterable[Dict[str, Any]]: return [ # base condition { @@ -149,4 +155,23 @@ def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]: ), # for this case, both tasks have same groupings "expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]), }, + # Custom grouping key tensor name + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0, 1, 0, 1, 1]), + "expected_ne": torch.tensor([[3.1615, 1.6004]]), + "grouping_key_tensor_name": "custom_key", + }, + # Cast grouping keys to int32 + { + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]), + "weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]), + "grouping_keys": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]), + "expected_ne": torch.tensor([[3.1615, 1.6004]]), + "grouping_key_tensor_name": "custom_key", + "cast_keys_to_int": True, + }, ]