Skip to content

Commit

Permalink
2025-01-17 nightly release (33168a1)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Jan 17, 2025
1 parent 29755d8 commit c9cfe23
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 17 deletions.
4 changes: 4 additions & 0 deletions torchrec/distributed/embedding_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
bucketize_pos: bool = False,
block_bucketize_pos: Optional[List[torch.Tensor]] = None,
total_num_blocks: Optional[torch.Tensor] = None,
keep_original_indices: bool = False,
) -> Tuple[
torch.Tensor,
torch.Tensor,
Expand Down Expand Up @@ -159,6 +160,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference(
max_B=_fx_wrap_max_B(kjt),
block_bucketize_pos=block_bucketize_pos,
return_bucket_mapping=True,
keep_orig_idx=keep_original_indices,
)

return (
Expand Down Expand Up @@ -305,6 +307,7 @@ def bucketize_kjt_inference(
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
is_sequence: bool = False,
keep_original_indices: bool = False,
) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets,
Expand Down Expand Up @@ -352,6 +355,7 @@ def bucketize_kjt_inference(
total_num_blocks=total_num_buckets_new_type,
bucketize_pos=bucketize_pos,
block_bucketize_pos=block_bucketize_row_pos,
keep_original_indices=keep_original_indices,
)
else:
(
Expand Down
1 change: 1 addition & 0 deletions torchrec/distributed/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,7 @@ def _create_input_dists(
has_feature_processor=sharding._has_feature_processor,
need_pos=False,
embedding_shard_metadata=emb_sharding,
keep_original_indices=True,
)
self._input_dists.append(input_dist)

Expand Down
4 changes: 4 additions & 0 deletions torchrec/distributed/sharding/rw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,10 +649,12 @@ def __init__(
has_feature_processor: bool = False,
need_pos: bool = False,
embedding_shard_metadata: Optional[List[List[int]]] = None,
keep_original_indices: bool = False,
) -> None:
super().__init__()
logger.info(
f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}"
f", keep_original_indices={keep_original_indices}"
)
self._world_size: int = world_size
self._num_features = num_features
Expand Down Expand Up @@ -683,6 +685,7 @@ def __init__(
self._embedding_shard_metadata: Optional[List[List[int]]] = (
embedding_shard_metadata
)
self._keep_original_indices = keep_original_indices

def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
block_sizes, block_bucketize_row_pos = get_block_sizes_runtime_device(
Expand Down Expand Up @@ -717,6 +720,7 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs:
block_bucketize_row_pos
),
is_sequence=self._is_sequence,
keep_original_indices=self._keep_original_indices,
)
# KJTOneToAll
dist_kjt = self._dist.forward(bucketized_features)
Expand Down
5 changes: 0 additions & 5 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,6 @@ def validate_batch_size_stages(
if len(batch_size_stages) == 0:
raise ValueError("Batch size stages should not be empty")

for i in range(len(batch_size_stages) - 1):
if batch_size_stages[i].batch_size >= batch_size_stages[i + 1].batch_size:
raise ValueError(
f"Batch size should be in ascending order. Got {batch_size_stages}"
)
if batch_size_stages[-1].max_iters is not None:
raise ValueError(
f"Batch size stages last stage should have max_iters = None, but get {batch_size_stages[-1].max_iters}"
Expand Down
38 changes: 29 additions & 9 deletions torchrec/metrics/segmented_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,25 @@ class SegmentedNEMetricComputation(RecMetricComputation):
Args:
include_logloss (bool): return vanilla logloss as one of metrics results, on top of segmented NE.
num_groups (int): number of groups to segment NE by.
grouping_keys (str): name of the tensor containing the label by which results will be segmented. This tensor should be of type torch.int64.
cast_keys_to_int (bool): whether to cast grouping_keys to torch.int64. Only works if grouping_keys is of type torch.float32.
"""

def __init__(
self,
*args: Any,
include_logloss: bool = False, # TODO - include
num_groups: int = 1,
grouping_keys: str = "grouping_keys",
cast_keys_to_int: bool = False,
**kwargs: Any,
) -> None:
self._include_logloss: bool = include_logloss
super().__init__(*args, **kwargs)
self._num_groups = num_groups # would there be checkpointing issues with this? maybe make this state
self._grouping_keys = grouping_keys
self._cast_keys_to_int = cast_keys_to_int
self._add_state(
"cross_entropy_sum",
torch.zeros((self._n_tasks, num_groups), dtype=torch.double),
Expand Down Expand Up @@ -217,21 +224,30 @@ def update(
) -> None:
if predictions is None or weights is None:
raise RecMetricException(
"Inputs 'predictions' and 'weights' and 'grouping_keys' should not be None for NEMetricComputation update"
f"Inputs 'predictions' and 'weights' and '{self._grouping_keys}' should not be None for NEMetricComputation update"
)
elif (
"required_inputs" not in kwargs
or kwargs["required_inputs"].get("grouping_keys") is None
or kwargs["required_inputs"].get(self._grouping_keys) is None
):
raise RecMetricException(
f"Required inputs for SegmentedNEMetricComputation update should contain 'grouping_keys', got kwargs: {kwargs}"
)
elif kwargs["required_inputs"]["grouping_keys"].dtype != torch.int64:
raise RecMetricException(
f"Grouping keys must have type torch.int64, got {kwargs['required_inputs']['grouping_keys'].dtype}."
f"Required inputs for SegmentedNEMetricComputation update should contain {self._grouping_keys}, got kwargs: {kwargs}"
)
elif kwargs["required_inputs"][self._grouping_keys].dtype != torch.int64:
if (
self._cast_keys_to_int
and kwargs["required_inputs"][self._grouping_keys].dtype
== torch.float32
):
kwargs["required_inputs"][self._grouping_keys] = kwargs[
"required_inputs"
][self._grouping_keys].to(torch.int64)
else:
raise RecMetricException(
f"Grouping keys expected to have type torch.int64 or torch.float32 with cast_keys_to_int set to true, got {kwargs['required_inputs'][self._grouping_keys].dtype}."
)

grouping_keys = kwargs["required_inputs"]["grouping_keys"]
grouping_keys = kwargs["required_inputs"][self._grouping_keys]
states = get_segemented_ne_states(
labels,
predictions,
Expand Down Expand Up @@ -325,4 +341,8 @@ def __init__(
process_group=process_group,
**kwargs,
)
self._required_inputs.add("grouping_keys")
if "grouping_keys" not in kwargs:
self._required_inputs.add("grouping_keys")
else:
# pyre-ignore[6]
self._required_inputs.add(kwargs["grouping_keys"])
31 changes: 28 additions & 3 deletions torchrec/metrics/tests/test_segmented_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# pyre-strict

import unittest
from typing import Dict, Iterable, Union
from typing import Any, Dict, Iterable, Union

import torch
from torch import no_grad
Expand All @@ -31,6 +31,8 @@ def _test_segemented_ne_helper(
weights: torch.Tensor,
expected_ne: torch.Tensor,
grouping_keys: torch.Tensor,
grouping_key_tensor_name: str = "grouping_keys",
cast_keys_to_int: bool = False,
) -> None:
num_task = labels.shape[0]
batch_size = labels.shape[0]
Expand All @@ -41,7 +43,7 @@ def _test_segemented_ne_helper(
"weights": {},
}
if grouping_keys is not None:
inputs["required_inputs"] = {"grouping_keys": grouping_keys}
inputs["required_inputs"] = {grouping_key_tensor_name: grouping_keys}
for i in range(num_task):
task_info = RecTaskInfo(
name=f"Task:{i}",
Expand All @@ -64,6 +66,10 @@ def _test_segemented_ne_helper(
tasks=task_list,
# pyre-ignore
num_groups=max(2, torch.unique(grouping_keys)[-1].item() + 1),
# pyre-ignore
grouping_keys=grouping_key_tensor_name,
# pyre-ignore
cast_keys_to_int=cast_keys_to_int,
)
ne.update(**inputs)
actual_ne = ne.compute()
Expand Down Expand Up @@ -95,7 +101,7 @@ def test_grouped_ne(self) -> None:
raise


def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
def generate_model_outputs_cases() -> Iterable[Dict[str, Any]]:
return [
# base condition
{
Expand Down Expand Up @@ -149,4 +155,23 @@ def generate_model_outputs_cases() -> Iterable[Dict[str, torch._tensor.Tensor]]:
), # for this case, both tasks have same groupings
"expected_ne": torch.tensor([[3.1615, 1.6004], [1.0034, 0.4859]]),
},
# Custom grouping key tensor name
{
"labels": torch.tensor([[1, 0, 0, 1, 1]]),
"predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
"weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
"grouping_keys": torch.tensor([0, 1, 0, 1, 1]),
"expected_ne": torch.tensor([[3.1615, 1.6004]]),
"grouping_key_tensor_name": "custom_key",
},
# Cast grouping keys to int32
{
"labels": torch.tensor([[1, 0, 0, 1, 1]]),
"predictions": torch.tensor([[0.2, 0.6, 0.8, 0.4, 0.9]]),
"weights": torch.tensor([[0.13, 0.2, 0.5, 0.8, 0.75]]),
"grouping_keys": torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0]),
"expected_ne": torch.tensor([[3.1615, 1.6004]]),
"grouping_key_tensor_name": "custom_key",
"cast_keys_to_int": True,
},
]

0 comments on commit c9cfe23

Please sign in to comment.