From 8790fc4769683c91736a88e5f82f71d6605d9d02 Mon Sep 17 00:00:00 2001 From: Yue Dong Date: Wed, 21 Feb 2024 22:29:53 -0800 Subject: [PATCH] Enable memory snapshot support upload to manifold and zoomer (#709) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/709 This change adds support for uploading memory snapshots to manifold and showing them in zoomer, with the following changes: 1. Add a zoomer-specific memory snapshot profiler wrapper; 2. Internally call the memory_snapshot API from `unitrace`. Reviewed By: aaronenyeshi Differential Revision: D53997537 fbshipit-source-id: 2af6cec9cba64f43c6321e4efd497b373db10bd3 --- .../callbacks/test_memory_snapshot.py | 7 +-- .../framework/callbacks/memory_snapshot.py | 16 ++----- torchtnt/utils/memory_snapshot_profiler.py | 44 ++++++++++++++----- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/tests/framework/callbacks/test_memory_snapshot.py b/tests/framework/callbacks/test_memory_snapshot.py index e039ea1767..0a1a615673 100644 --- a/tests/framework/callbacks/test_memory_snapshot.py +++ b/tests/framework/callbacks/test_memory_snapshot.py @@ -10,13 +10,14 @@ from torchtnt.framework.callbacks.memory_snapshot import MemorySnapshot from torchtnt.framework.state import EntryPoint +from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfiler class TestMemorySnapshot(unittest.TestCase): def test_on_train_step_end(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: memory_snapshot = MemorySnapshot( - output_dir=temp_dir, + memory_snapshot_profiler=MemorySnapshotProfiler(output_dir=temp_dir), ) memory_snapshot.memory_snapshot_profiler = Mock() @@ -28,7 +29,7 @@ def test_on_train_step_end(self) -> None: def test_on_eval_step_end(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: memory_snapshot = MemorySnapshot( - output_dir=temp_dir, + memory_snapshot_profiler=MemorySnapshotProfiler(output_dir=temp_dir), ) memory_snapshot.memory_snapshot_profiler = Mock() @@ -41,7 +42,7 @@ 
def test_on_eval_step_end(self) -> None: def test_on_predict_step_end(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: memory_snapshot = MemorySnapshot( - output_dir=temp_dir, + memory_snapshot_profiler=MemorySnapshotProfiler(output_dir=temp_dir), ) memory_snapshot.memory_snapshot_profiler = Mock() diff --git a/torchtnt/framework/callbacks/memory_snapshot.py b/torchtnt/framework/callbacks/memory_snapshot.py index 061750b34a..a920885fbe 100644 --- a/torchtnt/framework/callbacks/memory_snapshot.py +++ b/torchtnt/framework/callbacks/memory_snapshot.py @@ -5,15 +5,11 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Optional from torchtnt.framework.callback import Callback from torchtnt.framework.state import State from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit -from torchtnt.utils.memory_snapshot_profiler import ( - MemorySnapshotParams, - MemorySnapshotProfiler, -) +from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotProfilerBase logger: logging.Logger = logging.getLogger(__name__) @@ -25,8 +21,7 @@ class MemorySnapshot(Callback): Uses `Memory Snapshots `. Args: - output_dir: Directory where to save the memory snapshots. - memory_snapshot_params: Instance of MemorySnapshotParams which will be passed to MemorySnapshotProfiler. + memory_snapshot_profiler: Instance of MemorySnapshotProfilerBase, controls when and where to save the memory snapshots. Note: It is recommended to instantiate this callback **as early as possible** in your training/eval/prediction script, ideally before model initialization, to make sure all memory allocation is captured. 
@@ -36,12 +31,9 @@ class MemorySnapshot(Callback): def __init__( self, *, - output_dir: str, - memory_snapshot_params: Optional[MemorySnapshotParams] = None, + memory_snapshot_profiler: MemorySnapshotProfilerBase, ) -> None: - self.memory_snapshot_profiler = MemorySnapshotProfiler( - output_dir=output_dir, memory_snapshot_params=memory_snapshot_params - ) + self.memory_snapshot_profiler = memory_snapshot_profiler def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: self.memory_snapshot_profiler.step() diff --git a/torchtnt/utils/memory_snapshot_profiler.py b/torchtnt/utils/memory_snapshot_profiler.py index 011eef4f70..8fff2cb7df 100644 --- a/torchtnt/utils/memory_snapshot_profiler.py +++ b/torchtnt/utils/memory_snapshot_profiler.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging +from abc import ABC, abstractmethod from dataclasses import dataclass from types import TracebackType from typing import Optional, Type @@ -39,7 +40,36 @@ class MemorySnapshotParams: enable_oom_observer: bool = True -class MemorySnapshotProfiler: +class MemorySnapshotProfilerBase(ABC): + """ + Base class for memory snapshot profiler. + """ + + def __enter__(self) -> None: + self.start() + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + self.stop() + + @abstractmethod + def start(self) -> None: + ... + + @abstractmethod + def stop(self) -> None: + ... + + @abstractmethod + def step(self) -> None: + ... + + +class MemorySnapshotProfiler(MemorySnapshotProfilerBase): """ Records a history of memory allocation and free events, and dumps to a file which can be visualized offline. 
It by default keeps track of @@ -71,6 +101,7 @@ def __init__( output_dir: str, memory_snapshot_params: Optional[MemorySnapshotParams] = None, ) -> None: + super().__init__() self.output_dir: str = output_dir self.params: MemorySnapshotParams = ( memory_snapshot_params or MemorySnapshotParams() @@ -115,17 +146,6 @@ def __init__( f"Created MemorySnapshotProfiler with MemorySnapshotParams={self.params}." ) - def __enter__(self) -> None: - self.start() - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - tb: Optional[TracebackType], - ) -> Optional[bool]: - self.stop() - def start(self) -> None: if not torch.cuda.is_available(): logger.warn("CUDA unavailable. Not recording memory history.")