
Commit 5012ccc

felipemello1 (Felipe Mello) and a co-author authored

Metric Logging updates 3/N (meta-pytorch#359)

* commit
* commit
* update backend role typehints and enum
* update where we check FORGE_DISABLE_METRICS
* remove protected import
* protect import
* record_metric uses dataclass Metric
* commit
* revert
* docs/names

---------

Co-authored-by: Felipe Mello <[email protected]>
1 parent 2d1cc85 commit 5012ccc
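
One item in the commit message, "record_metric uses dataclass Metric", refers to the metrics API rather than the two files shown below; only the call pattern is visible in this diff. As a minimal sketch of that call pattern (the Metric dataclass itself is not part of this diff, so its fields are deliberately left out):

    from forge.observability.metrics import Reduce, record_metric

    # A metric name, a float value, and a reduction mode, as called from
    # perf_tracker.py below. Per the commit message, the value is wrapped in a
    # Metric dataclass internally before being recorded.
    record_metric("my_tracer/total_duration_avg_s", 0.123, Reduce.MEAN)
    record_metric("my_tracer/total_duration_max_s", 0.123, Reduce.MAX)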

File tree: 2 files changed, +84 −66 lines


src/forge/observability/perf_tracker.py

Lines changed: 66 additions & 58 deletions
@@ -3,12 +3,12 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
 import inspect
 import logging
 import os
 import threading
 import time
-
 from concurrent.futures import Future, ThreadPoolExecutor
 from functools import lru_cache, wraps
 from typing import Protocol
@@ -18,6 +18,8 @@
 from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU
 from forge.observability.metrics import record_metric, Reduce

+logger = logging.getLogger(__name__)
+
 # Thread-local memory tracking state
 _local = threading.local()

@@ -44,7 +46,6 @@ def _warn_nested_memory_tracking(prefix: str) -> None:

 """

-
 class Tracer:
 ==========
 """
@@ -150,10 +151,9 @@ def stop(self) -> None:
         if not self._active:
             raise ValueError("Tracer must be started before calling stop")

-        # Stop timing (always enabled)
-        # step("end") is dropped from steps, but included in total sum
-        self._timer.step("end")  # pyre-ignore
-        self._record_timing_metrics()
+        # Stop timing
+        durations, stop_step_ms = self._timer.get_all_durations()  # pyre-ignore
+        self._record_timing_metrics(durations, stop_step_ms)
         self._timer = None

         # Stop memory tracking
@@ -193,17 +193,15 @@ def _stop_memory_tracking(self) -> None:
             torch.cuda.reset_max_memory_allocated()
         self._memory_started = False

-    def _record_timing_metrics(self) -> None:
-        durations = self._timer.get_all_durations()  # pyre-ignore
-
-        # Total: sum all recorded durations (full timeline including end)
-        total_ms = sum(d_ms for name, d_ms in durations)
+    def _record_timing_metrics(
+        self, durations: list[tuple[str, float]], stop_step_ms: float
+    ) -> None:
+        total_ms = sum(d_ms for _, d_ms in durations) + stop_step_ms
         total_s = total_ms / 1000.0
         record_metric(f"{self.prefix}/total_duration_avg_s", total_s, Reduce.MEAN)
         record_metric(f"{self.prefix}/total_duration_max_s", total_s, Reduce.MAX)

-        # Steps: record each individually (drop last "end")
-        for name, d_ms in durations[:-1]:
+        for name, d_ms in durations:
             d_s = d_ms / 1000.0
             record_metric(f"{self.prefix}/{name}/duration_avg_s", d_s, Reduce.MEAN)
             record_metric(f"{self.prefix}/{name}/duration_max_s", d_s, Reduce.MAX)
@@ -216,7 +214,7 @@ def start(self) -> None:
     def step(self, name: str) -> None:
         ...

-    def get_all_durations(self) -> list[tuple[str, float]]:
+    def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
         ...


@@ -242,13 +240,27 @@ def step(self, name: str) -> None:
         self._durations.append((name, delta_ms))
         self._chain_start = now

-    def get_all_durations(self) -> list[tuple[str, float]]:
-        return self._durations[:]
+    def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
+        """Retrieve list of (step_name, duration) tuples and last step duration
+        between tracer.stop and the last step (or start if none)."""
+        stop_step_ms = 0.0
+        if self._chain_start is not None:
+            now = time.perf_counter()
+            stop_step_ms = (now - self._chain_start) * 1000
+        return self._durations[:], stop_step_ms


 class _TimerCUDA(_TimerProtocol):
     """CUDA timing backend with non-blocking events and futures.
     Uses a thread pool to poll CUDA events asynchronously without blocking the main thread.
+
+    Example:
+        timer = _TimerCUDA()
+        timer.start()
+        # torch.mm(a, b)  # ~100ms GPU
+        timer.step("matmul")
+        # torch.mm(c, d)  # ~200ms
+        durs_steps, stop_step_ms = timer.get_all_durations()  # ([( "matmul", 100 )], 200)
     """

     def __init__(self, max_workers: int = 2) -> None:
@@ -277,74 +289,70 @@ def step(self, name: str) -> None:
         Args:
             name: Label for this segment's duration
         """
-        # Submit polling future; chain to next event.
         if self._chain_start is None:
             raise ValueError("Timer must be started before calling step")

         stream = torch.cuda.current_stream()
         end_event = torch.cuda.Event(enable_timing=True)
         end_event.record(stream)

-        def _compute_elapsed(start_event, end_event):
-            # Poll with backoff: starts fast (1ms), grows to cap (50ms) for mixed workloads.
-            sleep_time = 0.001  # Start at 1ms
-            while not end_event.query():
-                time.sleep(sleep_time)
-                sleep_time = min(sleep_time * 1.5, 0.05)  # Backoff, cap at 50ms
-            return start_event.elapsed_time(end_event)
-
-        future = self._executor.submit(_compute_elapsed, self._chain_start, end_event)
+        future = self._executor.submit(self._poll_elapsed, self._chain_start, end_event)
         index = len(self._futures)
         self._futures.append((name, future, index))
-
         if len(self._futures) >= 5:  # clean up every 5
             self._collect_completed_futures()

         self._chain_start = end_event

-    def _collect_completed_futures(self) -> None:
+    def _poll_elapsed(
+        self, start_event: torch.cuda.Event, end_event: torch.cuda.Event
+    ) -> float:
+        """Compute elapsed time after polling with backoff."""
+        # Poll until ready
+        sleep_time = 0.001  # Start at 1ms
+        while not end_event.query():
+            time.sleep(sleep_time)
+            sleep_time = min(sleep_time * 1.5, 0.05)  # Backoff, cap at 50ms
+        return start_event.elapsed_time(end_event)
+
+    def _collect_completed_futures(self, wait_till_done: bool = False) -> None:
         """Drain done futures to avoid memory leak; update durations in submission order."""
-        completed = []
         still_pending = []
         for name, future, idx in self._futures:
-            if future.done():
-                try:
-                    dur = future.result()
-                    completed.append((idx, name, dur))
-                except Exception as e:
-                    raise RuntimeError(f"Timing failed for {name}: {e}") from e
+            if future.done() or wait_till_done:
+                dur = future.result()
+                self._durations.append((name, dur))
             else:
                 still_pending.append((name, future, idx))

-        # Sort completed by submission index to preserve order
-        completed.sort(key=lambda x: x[0])
-        for _, name, dur in completed:
-            self._durations.append((name, dur))
-
         self._futures = still_pending

-    def get_all_durations(self) -> list[tuple[str, float]]:
-        """Retrieve list of (name, duration) tuples in submission order after waiting for background polls to finish."""
-        # Wait and collect if pendings; return durations.
-        self._collect_completed_futures()
-        completed = []
-        for name, future, idx in self._futures:
-            try:
-                dur = future.result()
-                completed.append((idx, name, dur))
-            except Exception as e:
-                raise RuntimeError(f"Timing failed for {name}: {e}") from e
-
-        # Sort by submission index to preserve order
-        completed.sort(key=lambda x: x[0])
-        for _, name, dur in completed:
-            self._durations.append((name, dur))
+    def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
+        """Retrieve list of (step_name, duration) tuples and last step duration
+        between tracer.stop and the last step (or start if none). Order of tuples is random.
+        """
+        # Final timing since last step (or start) until this function is called
+        stop_step = f"_stop_step_{id(self)}"
+        self.step(stop_step)

+        # Wait on remaining futures
+        self._collect_completed_futures(wait_till_done=True)
         self._futures.clear()
-        return self._durations[:]
+
+        # Extract stop_step_ms
+        stop_step_ms = 0.0
+        durations = [
+            (name, duration) for name, duration in self._durations if name != stop_step
+        ]
+        for name, duration in self._durations:
+            if name == stop_step:
+                stop_step_ms = duration
+                break
+
+        return durations, stop_step_ms

     def __del__(self) -> None:
-        # Fallback cleanup in finalizer; ignores errors to avoid shutdown noise.
+        # Fallback cleanup in finalizer
         try:
             self._executor.shutdown(wait=True)
         except Exception:
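
The net effect of the perf_tracker.py changes above: stop() no longer records an artificial "end" step; instead get_all_durations() returns the named step durations plus the trailing stop segment separately, and _record_timing_metrics() adds that segment back into the total. A minimal sketch of the accounting, with made-up numbers (100, 60, and 40 ms are illustrative only):

    # (step_name, duration_ms) pairs as returned by get_all_durations()
    durations = [("matmul", 100.0), ("allreduce", 60.0)]
    stop_step_ms = 40.0  # time between the last step() call and stop()

    # Mirrors Tracer._record_timing_metrics in the diff above: the total still
    # covers the full timeline, but per-step metrics no longer include an
    # artificial "end" step.
    total_ms = sum(d_ms for _, d_ms in durations) + stop_step_ms  # 200.0
    total_s = total_ms / 1000.0  # 0.2 s, reported via Reduce.MEAN and Reduce.MAX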

tests/unit_tests/observability/test_perf_tracker.py

Lines changed: 18 additions & 8 deletions
@@ -309,29 +309,39 @@ def test_tracer_and_timer_reuse(self, mock_record_metric_calls):
         cpu_timer.start()
         time.sleep(0.005)
         cpu_timer.step("cpu_step1")
-        durations1 = cpu_timer.get_all_durations()
+        cpu_durations_list1, cpu_final_ms1 = cpu_timer.get_all_durations()

         cpu_timer.start()
         time.sleep(0.005)
         cpu_timer.step("cpu_step2")
-        durations2 = cpu_timer.get_all_durations()
+        cpu_durations_list2, cpu_final_ms2 = cpu_timer.get_all_durations()

-        assert len(durations1) == 1 and durations1[0][0] == "cpu_step1"
-        assert len(durations2) == 1 and durations2[0][0] == "cpu_step2"
+        assert (
+            len(cpu_durations_list1) == 1 and cpu_durations_list1[0][0] == "cpu_step1"
+        )
+        assert (
+            len(cpu_durations_list2) == 1 and cpu_durations_list2[0][0] == "cpu_step2"
+        )

         # Test CUDA timer reuse (if available)
         if torch.cuda.is_available():
             cuda_timer = _TimerCUDA()
             cuda_timer.start()
             cuda_timer.step("cuda_step1")
-            cuda_durations1 = cuda_timer.get_all_durations()
+            cuda_durations_list1, cuda_final_ms1 = cuda_timer.get_all_durations()

             cuda_timer.start()
             cuda_timer.step("cuda_step2")
-            cuda_durations2 = cuda_timer.get_all_durations()
+            cuda_durations_list2, cuda_final_ms2 = cuda_timer.get_all_durations()

-            assert len(cuda_durations1) == 1 and cuda_durations1[0][0] == "cuda_step1"
-            assert len(cuda_durations2) == 1 and cuda_durations2[0][0] == "cuda_step2"
+            assert (
+                len(cuda_durations_list1) == 1
+                and cuda_durations_list1[0][0] == "cuda_step1"
+            )
+            assert (
+                len(cuda_durations_list2) == 1
+                and cuda_durations_list2[0][0] == "cuda_step2"
+            )

     def test_exception_handling_context_manager(self, mock_record_metric_calls):
         """Test context manager properly cleans up on exception."""
