Skip to content

Commit

Permalink
Log stats to Tensorboard (#2712)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2712

# Context
MPZCH has pretty nice TensorBoard logging that shows ZCH stats; we want the same logging for the sorted ZCH module.

# Code Decision
Adopt code in D62984548.

Reviewed By: zlzhao1104

Differential Revision: D68737755

fbshipit-source-id: 70fc8e20988c1b8f2d4615ec0f290e8939b38d83
  • Loading branch information
Yihang Yang authored and facebook-github-bot committed Jan 29, 2025
1 parent c6f41aa commit 487fbb7
Showing 1 changed file with 168 additions and 3 deletions.
171 changes: 168 additions & 3 deletions torchrec/modules/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,157 @@

import torch

from torch import nn
from tensorboard.adhoc import Adhoc
from torch import distributed as dist, nn
from torchrec.modules.embedding_configs import BaseEmbeddingConfig
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


logger: Logger = getLogger(__name__)


class ScalarLogger(torch.nn.Module):
    """
    A logger to report various metrics related to ZCH.
    This module is adapted from ScalarLogger for multi-probe ZCH.

    Args:
        name (str): Name of the embedding table.
        frequency (int): Frequency of reporting metrics in number of iterations.

    Example::
        scalar_logger = ScalarLogger(
            name=name,
            frequency=tb_logging_frequency,
        )
    """

    def __init__(
        self,
        name: str,
        frequency: int,
    ) -> None:
        """
        Initializes the logger.

        Args:
            name (str): Name of the embedding table.
            frequency (int): Frequency of reporting metrics in number of iterations.

        Returns:
            None
        """
        super().__init__()
        self._name: str = name
        self._frequency: int = frequency

        # size related metrics, refreshed by update_size() and reset after
        # every successful report
        self._unused_size: int = 0
        self._active_size: int = 0
        self._total_size: int = 0

        # number of forward() calls so far; used both to throttle reporting
        # to every self._frequency steps and as the TensorBoard global step
        self._scalar_logger_steps: int = 0

    def should_report(self) -> bool:
        """
        Returns whether the logger should report metrics.
        This function only returns True for rank 0 and every self._frequency steps.
        """
        if self._scalar_logger_steps % self._frequency != 0:
            return False

        # rank stays -1 when the process group is unavailable or not yet
        # initialized, in which case no reporting happens
        rank: int = -1
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        return rank == 0

    def build_metric_name(
        self,
        metric: str,
        run_type: str,
    ) -> str:
        """
        Builds the metric name for reporting.

        Args:
            metric (str): Name of the metric.
            run_type (str): Run type of the model, e.g. train, eval, etc.

        Returns:
            str: Metric name.
        """
        return f"mc_zch_stats/{self._name}/{metric}/{run_type}"

    def update_size(
        self,
        counts: torch.Tensor,
    ) -> None:
        """
        Updates the size related metrics.

        Args:
            counts (torch.Tensor): Counts of each id in the embedding table.

        Returns:
            None
        """
        # slots with a zero count are considered unused
        zero_counts = counts == 0
        self._unused_size = int(torch.sum(zero_counts).item())

        self._total_size = counts.shape[0]
        self._active_size = self._total_size - self._unused_size

    def forward(
        self,
        run_type: str,
    ) -> None:
        """
        Reports various metrics related to ZCH.

        Args:
            run_type (str): Run type of the model, e.g. train, eval, etc.

        Returns:
            None
        """
        if self.should_report():
            # epsilon guards against division by zero when update_size()
            # has not run yet (or the table is empty)
            total_size = self._total_size + 0.001
            unused_ratio = round(self._unused_size / total_size, 3)
            active_ratio = round(self._active_size / total_size, 3)

            # fetch the writer once instead of once per scalar
            writer = Adhoc.writer()
            for metric, value in (
                ("unused_size", self._unused_size),
                ("unused_ratio", unused_ratio),
                ("active_size", self._active_size),
                ("active_ratio", active_ratio),
            ):
                writer.add_scalar(
                    self.build_metric_name(metric, run_type),
                    value,
                    self._scalar_logger_steps,
                )

            logger.info(f"{self._name=}, {run_type=}")
            logger.info(f"{self._total_size=}")
            logger.info(f"{self._unused_size=}, {unused_ratio=}")
            logger.info(f"{self._active_size=}, {active_ratio=}")

            # reset after reporting so each report covers one window
            self._unused_size = 0
            self._active_size = 0
            self._total_size = 0

        self._scalar_logger_steps += 1


@torch.fx.wrap
def apply_mc_method_to_jt_dict(
mc_module: nn.Module,
Expand Down Expand Up @@ -983,6 +1126,7 @@ def __init__(
output_global_offset: int = 0, # typically not provided by user
output_segments: Optional[List[int]] = None, # typically not provided by user
buckets: int = 1,
tb_logging_frequency: int = 0,
) -> None:
if output_segments is None:
output_segments = [output_global_offset, output_global_offset + zch_size]
Expand Down Expand Up @@ -1020,6 +1164,16 @@ def __init__(
self._evicted: bool = False
self._last_eviction_iter: int = -1

## ------ logging ------
self._tb_logging_frequency = tb_logging_frequency
self._scalar_logger: Optional[ScalarLogger] = None
if self._tb_logging_frequency > 0:
assert self._name is not None, "name must be provided for logging"
self._scalar_logger = ScalarLogger(
name=self._name,
frequency=self._tb_logging_frequency,
)

def _init_buffers(self) -> None:
self.register_buffer(
"_mch_sorted_raw_ids",
Expand Down Expand Up @@ -1305,16 +1459,26 @@ def profile(
self._coalesce_history()
self._last_eviction_iter = self._current_iter

if self._scalar_logger is not None:
self._scalar_logger.update_size(counts=self._mch_metadata["counts"])

return features

def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
    """
    Remaps raw feature ids to their managed-collision slots.

    Args:
        features (Dict[str, JaggedTensor]): Input features keyed by name.

    Returns:
        Dict[str, JaggedTensor]: Features with remapped ids.
    """
    # highest id in this shard's output range, passed to _mch_remap as the
    # fallback slot for ids without a match — TODO confirm against _mch_remap
    fallback_id = self._output_global_offset + self._zch_size - 1
    result = _mch_remap(
        features,
        self._mch_sorted_raw_ids,
        self._mch_remapped_ids_mapping,
        fallback_id,
    )

    # emit TensorBoard stats when logging is enabled for this module
    scalar_logger = self._scalar_logger
    if scalar_logger is not None:
        run_type = "train" if self.training else "eval"
        scalar_logger(run_type=run_type)

    return result

def forward(
self,
features: Dict[str, JaggedTensor],
Expand Down Expand Up @@ -1393,4 +1557,5 @@ def rebuild_with_output_id_range(
output_global_offset=output_id_range[0],
output_segments=output_segments,
buckets=len(output_segments) - 1,
tb_logging_frequency=self._tb_logging_frequency,
)

0 comments on commit 487fbb7

Please sign in to comment.