 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
 import inspect
 import logging
 import os
 import threading
 import time
-
 from concurrent.futures import Future, ThreadPoolExecutor
 from functools import lru_cache, wraps
 from typing import Protocol
 from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU
 from forge.observability.metrics import record_metric, Reduce
 
+logger = logging.getLogger(__name__)
+
 # Thread-local memory tracking state
 _local = threading.local()
 
@@ -44,7 +46,6 @@ def _warn_nested_memory_tracking(prefix: str) -> None:
 
 """
 
-
 class Tracer:
 ==========
 """
@@ -150,10 +151,9 @@ def stop(self) -> None:
         if not self._active:
             raise ValueError("Tracer must be started before calling stop")
 
-        # Stop timing (always enabled)
-        # step("end") is dropped from steps, but included in total sum
-        self._timer.step("end")  # pyre-ignore
-        self._record_timing_metrics()
+        # Stop timing
+        durations, stop_step_ms = self._timer.get_all_durations()  # pyre-ignore
+        self._record_timing_metrics(durations, stop_step_ms)
         self._timer = None
 
         # Stop memory tracking
@@ -193,17 +193,15 @@ def _stop_memory_tracking(self) -> None:
             torch.cuda.reset_max_memory_allocated()
         self._memory_started = False
 
-    def _record_timing_metrics(self) -> None:
-        durations = self._timer.get_all_durations()  # pyre-ignore
-
-        # Total: sum all recorded durations (full timeline including end)
-        total_ms = sum(d_ms for name, d_ms in durations)
+    def _record_timing_metrics(
+        self, durations: list[tuple[str, float]], stop_step_ms: float
+    ) -> None:
+        total_ms = sum(d_ms for _, d_ms in durations) + stop_step_ms
         total_s = total_ms / 1000.0
         record_metric(f"{self.prefix}/total_duration_avg_s", total_s, Reduce.MEAN)
         record_metric(f"{self.prefix}/total_duration_max_s", total_s, Reduce.MAX)
 
-        # Steps: record each individually (drop last "end")
-        for name, d_ms in durations[:-1]:
+        for name, d_ms in durations:
             d_s = d_ms / 1000.0
             record_metric(f"{self.prefix}/{name}/duration_avg_s", d_s, Reduce.MEAN)
             record_metric(f"{self.prefix}/{name}/duration_max_s", d_s, Reduce.MAX)
@@ -216,7 +214,7 @@ def start(self) -> None:
     def step(self, name: str) -> None:
         ...
 
-    def get_all_durations(self) -> list[tuple[str, float]]:
+    def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
         ...
 
 
@@ -242,13 +240,27 @@ def step(self, name: str) -> None:
             self._durations.append((name, delta_ms))
         self._chain_start = now
 
-    def get_all_durations(self) -> list[tuple[str, float]]:
-        return self._durations[:]
+    def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
+        """Return the list of (step_name, duration_ms) tuples plus the duration
+        between the last step (or start, if no steps were recorded) and this call."""
+        stop_step_ms = 0.0
+        if self._chain_start is not None:
+            now = time.perf_counter()
+            stop_step_ms = (now - self._chain_start) * 1000
+        return self._durations[:], stop_step_ms
 
 
 class _TimerCUDA(_TimerProtocol):
     """CUDA timing backend with non-blocking events and futures.
     Uses a thread pool to poll CUDA events asynchronously without blocking the main thread.
+
+    Example:
+        timer = _TimerCUDA()
+        timer.start()
+        # torch.mm(a, b)  # ~100 ms of GPU work
+        timer.step("matmul")
+        # torch.mm(c, d)  # ~200 ms of GPU work
+        durations, stop_step_ms = timer.get_all_durations()  # ([("matmul", 100)], 200)
     """
 
     def __init__(self, max_workers: int = 2) -> None:
@@ -277,74 +289,70 @@ def step(self, name: str) -> None:
277289 Args:
278290 name: Label for this segment's duration
279291 """
280- # Submit polling future; chain to next event.
281292 if self ._chain_start is None :
282293 raise ValueError ("Timer must be started before calling step" )
283294
284295 stream = torch .cuda .current_stream ()
285296 end_event = torch .cuda .Event (enable_timing = True )
286297 end_event .record (stream )
287298
288- def _compute_elapsed (start_event , end_event ):
289- # Poll with backoff: starts fast (1ms), grows to cap (50ms) for mixed workloads.
290- sleep_time = 0.001 # Start at 1ms
291- while not end_event .query ():
292- time .sleep (sleep_time )
293- sleep_time = min (sleep_time * 1.5 , 0.05 ) # Backoff, cap at 50ms
294- return start_event .elapsed_time (end_event )
295-
296- future = self ._executor .submit (_compute_elapsed , self ._chain_start , end_event )
299+ future = self ._executor .submit (self ._poll_elapsed , self ._chain_start , end_event )
297300 index = len (self ._futures )
298301 self ._futures .append ((name , future , index ))
299-
300302 if len (self ._futures ) >= 5 : # clean up every 5
301303 self ._collect_completed_futures ()
302304
303305 self ._chain_start = end_event
304306
-    def _collect_completed_futures(self) -> None:
+    def _poll_elapsed(
+        self, start_event: torch.cuda.Event, end_event: torch.cuda.Event
+    ) -> float:
+        """Compute elapsed time after polling with backoff."""
+        # Poll until ready
+        sleep_time = 0.001  # Start at 1ms
+        while not end_event.query():
+            time.sleep(sleep_time)
+            sleep_time = min(sleep_time * 1.5, 0.05)  # Backoff, cap at 50ms
+        return start_event.elapsed_time(end_event)
+
+    def _collect_completed_futures(self, wait_till_done: bool = False) -> None:
         """Drain done futures to avoid memory leak; update durations in submission order."""
-        completed = []
         still_pending = []
         for name, future, idx in self._futures:
-            if future.done():
-                try:
-                    dur = future.result()
-                    completed.append((idx, name, dur))
-                except Exception as e:
-                    raise RuntimeError(f"Timing failed for {name}: {e}") from e
+            if future.done() or wait_till_done:
+                dur = future.result()
+                self._durations.append((name, dur))
             else:
                 still_pending.append((name, future, idx))
 
-        # Sort completed by submission index to preserve order
-        completed.sort(key=lambda x: x[0])
-        for _, name, dur in completed:
-            self._durations.append((name, dur))
-
         self._futures = still_pending
 
-    def get_all_durations(self) -> list[tuple[str, float]]:
-        """Retrieve list of (name, duration) tuples in submission order after waiting for background polls to finish."""
-        # Wait and collect if pendings; return durations.
-        self._collect_completed_futures()
-        completed = []
-        for name, future, idx in self._futures:
-            try:
-                dur = future.result()
-                completed.append((idx, name, dur))
-            except Exception as e:
-                raise RuntimeError(f"Timing failed for {name}: {e}") from e
-
-        # Sort by submission index to preserve order
-        completed.sort(key=lambda x: x[0])
-        for _, name, dur in completed:
-            self._durations.append((name, dur))
+    def get_all_durations(self) -> tuple[list[tuple[str, float]], float]:
+        """Return the list of (step_name, duration_ms) tuples plus the duration
+        between the last step (or start, if no steps) and this call. Tuple order is not guaranteed.
+        """
+        # Final timing since the last step (or start) until this call
+        stop_step = f"_stop_step_{id(self)}"
+        self.step(stop_step)
 
+        # Wait on remaining futures
+        self._collect_completed_futures(wait_till_done=True)
         self._futures.clear()
-        return self._durations[:]
+
+        # Extract stop_step_ms
+        stop_step_ms = 0.0
+        durations = [
+            (name, duration) for name, duration in self._durations if name != stop_step
+        ]
+        for name, duration in self._durations:
+            if name == stop_step:
+                stop_step_ms = duration
+                break
+
+        return durations, stop_step_ms
 
     def __del__(self) -> None:
-        # Fallback cleanup in finalizer; ignores errors to avoid shutdown noise.
+        # Fallback cleanup in finalizer
         try:
             self._executor.shutdown(wait=True)
         except Exception:
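
For context, a minimal usage sketch of the timing flow after this change. It assumes Tracer exposes start/step/stop as suggested by the stop() hunk above; the prefix, workload functions, and durations are illustrative, not part of this diff:

    tracer = Tracer("trainer_perf")   # hypothetical prefix
    tracer.start()
    load_batch()                      # placeholder work
    tracer.step("data_loading")       # -> trainer_perf/data_loading/duration_avg_s / _max_s
    run_forward()                     # placeholder work
    tracer.step("forward")            # -> trainer_perf/forward/duration_avg_s / _max_s
    tracer.stop()                     # time since "forward" (the stop step) is folded into
                                      # trainer_perf/total_duration_avg_s / _max_s only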