From 487fbb7bcf79ebb350db60dc8702015a0c7e4095 Mon Sep 17 00:00:00 2001
From: Yihang Yang
Date: Wed, 29 Jan 2025 13:27:55 -0800
Subject: [PATCH] Log stats to Tensorboard (#2712)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2712

# Context
MPZCH already has pretty nice TensorBoard logging to show its ZCH stats; we want the same thing for ZCH.

# Code Decision
Adopt the code in D62984548. A hypothetical usage sketch follows.
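For illustration only (not part of this diff), enabling the new logging from user code might look like the sketch below. It assumes the class patched in the hunks below is `MCHManagedCollisionModule` from `torchrec.modules.mc_modules` and that its other constructor arguments match current torchrec; the table name, sizes, and eviction policy are placeholders, and actually emitting scalars requires the internal `tensorboard.adhoc` writer.

```
# Hypothetical usage sketch -- placeholder name/sizes; assumes the patched
# class is MCHManagedCollisionModule with its current torchrec signature.
import torch
from torchrec.modules.mc_modules import (
    LFU_EvictionPolicy,
    MCHManagedCollisionModule,
)

mc_module = MCHManagedCollisionModule(
    name="my_table",  # name is required when tb_logging_frequency > 0
    zch_size=4096,
    device=torch.device("cpu"),
    eviction_policy=LFU_EvictionPolicy(),
    eviction_interval=2,
    input_hash_size=2**63 - 1,
    tb_logging_frequency=100,  # log ZCH stats to TensorBoard every 100 steps
)
```

With `tb_logging_frequency=0` (the default) no `ScalarLogger` is constructed, so existing callers are unaffected.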
Reviewed By: zlzhao1104

Differential Revision: D68737755

fbshipit-source-id: 70fc8e20988c1b8f2d4615ec0f290e8939b38d83
---
 torchrec/modules/mc_modules.py | 171 ++++++++++++++++++++++++++++++++-
 1 file changed, 168 insertions(+), 3 deletions(-)

diff --git a/torchrec/modules/mc_modules.py b/torchrec/modules/mc_modules.py
index 4693fd39c..9de0670b5 100644
--- a/torchrec/modules/mc_modules.py
+++ b/torchrec/modules/mc_modules.py
@@ -15,14 +15,157 @@
 import torch
 
-from torch import nn
+from tensorboard.adhoc import Adhoc
+from torch import distributed as dist, nn
 from torchrec.modules.embedding_configs import BaseEmbeddingConfig
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 
-
 logger: Logger = getLogger(__name__)
 
+class ScalarLogger(torch.nn.Module):
+    """
+    A logger to report various metrics related to ZCH.
+    This module is adapted from the ScalarLogger used for multi-probe ZCH.
+
+    Args:
+        name (str): Name of the embedding table.
+        frequency (int): Frequency of reporting metrics, in number of iterations.
+
+    Example::
+        scalar_logger = ScalarLogger(
+            name=name,
+            frequency=tb_logging_frequency,
+        )
+    """
+
+    def __init__(
+        self,
+        name: str,
+        frequency: int,
+    ) -> None:
+        """
+        Initializes the logger.
+
+        Args:
+            name (str): Name of the embedding table.
+            frequency (int): Frequency of reporting metrics, in number of iterations.
+
+        Returns:
+            None
+        """
+        super().__init__()
+        self._name: str = name
+        self._frequency: int = frequency
+
+        # size related metrics
+        self._unused_size: int = 0
+        self._active_size: int = 0
+        self._total_size: int = 0
+
+        # scalar step
+        self._scalar_logger_steps: int = 0
+
+    def should_report(self) -> bool:
+        """
+        Returns whether the logger should report metrics.
+        This function only returns True on rank 0 and every self._frequency steps.
+        """
+        if self._scalar_logger_steps % self._frequency != 0:
+            return False
+        rank: int = -1
+        if dist.is_available() and dist.is_initialized():
+            rank = dist.get_rank()
+        return rank == 0
+
+    def build_metric_name(
+        self,
+        metric: str,
+        run_type: str,
+    ) -> str:
+        """
+        Builds the metric name for reporting.
+
+        Args:
+            metric (str): Name of the metric.
+            run_type (str): Run type of the model, e.g. train, eval, etc.
+
+        Returns:
+            str: Metric name.
+        """
+        return f"mc_zch_stats/{self._name}/{metric}/{run_type}"
+
+    def update_size(
+        self,
+        counts: torch.Tensor,
+    ) -> None:
+        """
+        Updates the size related metrics.
+
+        Args:
+            counts (torch.Tensor): Counts of each id in the embedding table.
+
+        Returns:
+            None
+        """
+        zero_counts = counts == 0
+        self._unused_size = int(torch.sum(zero_counts).item())
+
+        self._total_size = counts.shape[0]
+        self._active_size = self._total_size - self._unused_size
+
+    def forward(
+        self,
+        run_type: str,
+    ) -> None:
+        """
+        Reports various metrics related to ZCH.
+
+        Args:
+            run_type (str): Run type of the model, e.g. train, eval, etc.
+
+        Returns:
+            None
+        """
+        if self.should_report():
+            # small epsilon guards against division by zero if no
+            # update_size() call has populated the sizes yet
+            total_size = self._total_size + 0.001
+            unused_ratio = round(self._unused_size / total_size, 3)
+            active_ratio = round(self._active_size / total_size, 3)
+
+            Adhoc.writer().add_scalar(
+                self.build_metric_name("unused_size", run_type),
+                self._unused_size,
+                self._scalar_logger_steps,
+            )
+            Adhoc.writer().add_scalar(
+                self.build_metric_name("unused_ratio", run_type),
+                unused_ratio,
+                self._scalar_logger_steps,
+            )
+            Adhoc.writer().add_scalar(
+                self.build_metric_name("active_size", run_type),
+                self._active_size,
+                self._scalar_logger_steps,
+            )
+            Adhoc.writer().add_scalar(
+                self.build_metric_name("active_ratio", run_type),
+                active_ratio,
+                self._scalar_logger_steps,
+            )
+
+            logger.info(f"{self._name=}, {run_type=}")
+            logger.info(f"{self._total_size=}")
+            logger.info(f"{self._unused_size=}, {unused_ratio=}")
+            logger.info(f"{self._active_size=}, {active_ratio=}")
+
+            # reset after reporting
+            self._unused_size = 0
+            self._active_size = 0
+            self._total_size = 0
+
+        self._scalar_logger_steps += 1
+
+
 @torch.fx.wrap
 def apply_mc_method_to_jt_dict(
     mc_module: nn.Module,
@@ -983,6 +1126,7 @@ def __init__(
         output_global_offset: int = 0,  # typically not provided by user
         output_segments: Optional[List[int]] = None,  # typically not provided by user
         buckets: int = 1,
+        tb_logging_frequency: int = 0,
     ) -> None:
         if output_segments is None:
             output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1020,6 +1164,16 @@ def __init__(
         self._evicted: bool = False
         self._last_eviction_iter: int = -1
 
+        ## ------ logging ------
+        self._tb_logging_frequency = tb_logging_frequency
+        self._scalar_logger: Optional[ScalarLogger] = None
+        if self._tb_logging_frequency > 0:
+            assert self._name is not None, "name must be provided for logging"
+            self._scalar_logger = ScalarLogger(
+                name=self._name,
+                frequency=self._tb_logging_frequency,
+            )
+
     def _init_buffers(self) -> None:
         self.register_buffer(
             "_mch_sorted_raw_ids",
@@ -1305,16 +1459,26 @@ def profile(
             self._coalesce_history()
             self._last_eviction_iter = self._current_iter
 
+        if self._scalar_logger is not None:
+            self._scalar_logger.update_size(counts=self._mch_metadata["counts"])
+
         return features
 
     def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
-        return _mch_remap(
+        remapped_features = _mch_remap(
             features,
             self._mch_sorted_raw_ids,
             self._mch_remapped_ids_mapping,
             self._output_global_offset + self._zch_size - 1,
         )
 
+        if self._scalar_logger is not None:
+            self._scalar_logger(
+                run_type="train" if self.training else "eval",
+            )
+
+        return remapped_features
+
     def forward(
         self,
         features: Dict[str, JaggedTensor],
@@ -1393,4 +1557,5 @@ def rebuild_with_output_id_range(
             output_global_offset=output_id_range[0],
             output_segments=output_segments,
             buckets=len(output_segments) - 1,
+            tb_logging_frequency=self._tb_logging_frequency,
         )
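A quick way to sanity-check the new bookkeeping, in an environment where `tensorboard.adhoc` is importable (the patched module imports it at top level), is to exercise `update_size` and `build_metric_name` directly. A minimal sketch; it pokes at private fields purely for illustration and never calls `forward()`, which would write through the Adhoc writer.

```
# Minimal smoke-test sketch of ScalarLogger's bookkeeping -- illustrative only.
import torch
from torchrec.modules.mc_modules import ScalarLogger

slogger = ScalarLogger(name="my_table", frequency=100)

# Simulated per-slot id counts: 2 of 5 embedding slots were never hit.
slogger.update_size(counts=torch.tensor([3, 0, 7, 0, 1]))
assert slogger._unused_size == 2
assert slogger._active_size == 3
assert slogger._total_size == 5

# Metric names are namespaced as mc_zch_stats/<table>/<metric>/<run_type>.
print(slogger.build_metric_name("unused_ratio", "train"))
# -> mc_zch_stats/my_table/unused_ratio/train
```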