Commit 5c8e5e2

SSYernar authored and facebook-github-bot committed
Added CPU Memory Stats to Benchmarks (#3231)
Summary: Pull Request resolved: #3231

* Introduced functionality to capture and report peak CPU RSS memory usage statistics in the benchmarking utilities.
* CPU Resident Set Size (RSS) measures the portion of a process's memory that is held in RAM, including the process's code, data, and stack, as well as any dynamically allocated memory. Monitoring RSS helps identify memory-intensive operations and optimize resource usage to prevent memory-related issues.
* Enhanced the `BenchmarkResult` class to include CPU memory metrics alongside the existing GPU metrics.
* Updated the relevant files for compatibility with the expanded `BenchmarkResult` class.
* This enhancement gives users insight into peak CPU memory usage during benchmarking, aiding in the identification of memory bottlenecks.

Example metrics for FBGEMM operators:

| Operator | CPU Runtime | GPU Runtime | GPU Peak Memory Alloc | GPU Peak Memory Reserved | CPU Peak RSS |
|---|---|---|---|---|---|
| **[pytorch generic] fallback** | 4.85 ms | 2.50 ms | 1.01 GB | 1.53 GB | 1.38 GB |
| **[Prod] KeyedTensor.regroup** | 8.49 ms | 2.76 ms | 1.52 GB | 2.04 GB | 1.40 GB |
| **[Module] KTRegroupAsDict** | 0.52 ms | 0.72 ms | 1.01 GB | 2.04 GB | 1.40 GB |
| **[2 Ops] permute_multi_embs** | 3.11 ms | 1.81 ms | 1.01 GB | 2.04 GB | 1.41 GB |
| **[1 Op] KT_regroup** | 2.17 ms | 1.56 ms | 1.01 GB | 2.04 GB | 1.42 GB |
| **[Old Prod] permute_pooled_embs** | 4.00 ms | 2.69 ms | 1.52 GB | 2.04 GB | 1.43 GB |
| **[pytorch generic] fallback_dup** | 4.84 ms | 2.41 ms | 1.01 GB | 2.04 GB | 1.44 GB |
| **[Prod] KeyedTensor.regroup_dup** | 3.86 ms | 2.57 ms | 1.01 GB | 2.04 GB | 1.44 GB |
| **[Module] KTRegroupAsDict_dup** | 0.15 ms | 0.76 ms | 1.01 GB | 2.04 GB | 1.44 GB |
| **[2 Ops] permute_multi_embs_dup** | 0.86 ms | 1.47 ms | 1.01 GB | 2.04 GB | 1.44 GB |
| **[1 Op] KT_regroup_dup** | 1.01 ms | 1.60 ms | 1.01 GB | 2.04 GB | 1.44 GB |

Reviewed By: aliafzal

Differential Revision: D78860097

fbshipit-source-id: d60ab6b9886d75019fafe7ba0b2645f80922ab9a
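For reference, a minimal standalone sketch of the measurement this change builds on: reading the process's peak RSS through Python's `resource` module. The `peak_rss_mb` helper name is illustrative (not part of the diff), and the kilobyte-to-megabyte conversion assumes Linux semantics for `ru_maxrss` (macOS reports bytes).

```python
import resource


def peak_rss_mb() -> int:
    # ru_maxrss is the peak resident set size of this process; on Linux it is
    # reported in kilobytes (on macOS it is bytes), so this assumes Linux.
    peak_rss_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return peak_rss_kb // 1024


if __name__ == "__main__":
    buf = [0] * 10_000_000  # ~80 MB of list storage so the peak is visible
    print(f"CPU Peak RSS: {peak_rss_mb() / 1000:.2f} GB")
```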
1 parent 44f4bb5 · commit 5c8e5e2

File tree: 3 files changed, +77 -26 lines changed


torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 66 additions & 22 deletions
@@ -18,6 +18,7 @@
 import json
 import logging
 import os
+import resource
 import time
 import timeit
 from dataclasses import dataclass, fields, is_dataclass, MISSING
@@ -108,14 +109,14 @@ class CompileMode(Enum):


 @dataclass
-class MemoryStats:
+class GPUMemoryStats:
     rank: int
     malloc_retries: int
     max_mem_allocated_mbs: int
     max_mem_reserved_mbs: int

     @classmethod
-    def for_device(cls, rank: int) -> "MemoryStats":
+    def for_device(cls, rank: int) -> "GPUMemoryStats":
         stats = torch.cuda.memory_stats(rank)
         alloc_retries = stats.get("num_alloc_retries", 0)
         max_allocated = stats.get("allocated_bytes.all.peak", 0)
@@ -131,13 +132,31 @@ def __str__(self) -> str:
         return f"Rank {self.rank}: retries={self.malloc_retries}, allocated={self.max_mem_allocated_mbs:7}mb, reserved={self.max_mem_reserved_mbs:7}mb"


+@dataclass
+class CPUMemoryStats:
+    rank: int
+    peak_rss_mbs: int
+
+    @classmethod
+    def for_process(cls, rank: int) -> "CPUMemoryStats":
+        # Peak RSS from resource.getrusage (in KB on CentOS/Linux)
+        peak_rss_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+        peak_rss_mb = peak_rss_kb // 1024
+
+        return cls(rank, peak_rss_mb)
+
+    def __str__(self) -> str:
+        return f"Rank {self.rank}: CPU Memory Peak RSS: {self.peak_rss_mbs/1000:.2f} GB"
+
+
 @dataclass
 class BenchmarkResult:
     "Class for holding results of benchmark runs"
     short_name: str
     gpu_elapsed_time: torch.Tensor  # milliseconds
     cpu_elapsed_time: torch.Tensor  # milliseconds
-    mem_stats: List[MemoryStats]  # memory stats per rank
+    gpu_mem_stats: List[GPUMemoryStats]  # GPU memory stats per rank
+    cpu_mem_stats: List[CPUMemoryStats]  # CPU memory stats per rank
     rank: int = -1

     def __str__(self) -> str:
@@ -147,14 +166,16 @@ def __str__(self) -> str:
         cpu_runtime = (
             f"CPU Runtime (P90): {self.runtime_percentile(90, device='cpu'):.2f} ms"
         )
-        if len(self.mem_stats) == 0:
-            return f"{self.short_name: <{35}} | {gpu_runtime} | {cpu_runtime}"
-        mem_alloc = (
-            f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
-        )
-        mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2f} GB"
+        cpu_mem = f"CPU Peak RSS (P90): {self.cpu_mem_percentile(90)/1000:.2f} GB"
+
+        if len(self.gpu_mem_stats) == 0:
+            return (
+                f"{self.short_name: <{35}} | {gpu_runtime} | {cpu_runtime} | {cpu_mem}"
+            )
+        mem_alloc = f"GPU Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
+        mem_reserved = f"GPU Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2f} GB"
         malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50)} / {self.mem_retries(90)} / {self.mem_retries(100)}"
-        return f"{self.short_name: <{35}} | {malloc_retries} | {gpu_runtime} | {cpu_runtime} | {mem_alloc} | {mem_reserved}"
+        return f"{self.short_name: <{35}} | {malloc_retries} | {gpu_runtime} | {cpu_runtime} | {mem_alloc} | {mem_reserved} | {cpu_mem}"

     def runtime_percentile(
         self,
@@ -199,15 +220,28 @@ def mem_retries(

     def _mem_percentile(
         self,
-        mem_selector: Callable[[MemoryStats], int],
+        mem_selector: Callable[[GPUMemoryStats], int],
         percentile: int = 50,
         interpolation: str = "nearest",
     ) -> torch.Tensor:
         mem_data = torch.tensor(
-            [mem_selector(mem_stat) for mem_stat in self.mem_stats], dtype=torch.float
+            [mem_selector(mem_stat) for mem_stat in self.gpu_mem_stats],
+            dtype=torch.float,
         )
         return torch.quantile(mem_data, percentile / 100.0, interpolation=interpolation)

+    def cpu_mem_percentile(
+        self, percentile: int = 50, interpolation: str = "nearest"
+    ) -> torch.Tensor:
+        """Return the CPU memory percentile for peak RSS."""
+        cpu_mem_data = torch.tensor(
+            [cpu_stat.peak_rss_mbs for cpu_stat in self.cpu_mem_stats],
+            dtype=torch.float,
+        )
+        return torch.quantile(
+            cpu_mem_data, percentile / 100.0, interpolation=interpolation
+        )
+

 class ECWrapper(torch.nn.Module):
     """
@@ -437,8 +471,11 @@ def write_report(
         qps_gpu = int(num_requests / avg_dur_s_gpu)

         mem_str = ""
-        for memory_stats in benchmark_res.mem_stats:
-            mem_str += f"{memory_stats}\n"
+        for gpu_memory_stats in benchmark_res.gpu_mem_stats:
+            mem_str += f"{gpu_memory_stats}\n"
+
+        for cpu_memory_stats in benchmark_res.cpu_mem_stats:
+            mem_str += f"{cpu_memory_stats}\n"

         report_str += (
             f"{benchmark_res.short_name:40} "
@@ -816,13 +853,16 @@ def _run_benchmark_core(
         gpu_elapsed_time = cpu_elapsed_time.clone()

     # Memory statistics collection
-    mem_stats: List[MemoryStats] = []
+    gpu_mem_stats: List[GPUMemoryStats] = []
+    cpu_mem_stats = [CPUMemoryStats.for_process(rank)]
+
     if device_type == "cuda":
         if rank == -1:
             for di in range(world_size):
-                mem_stats.append(MemoryStats.for_device(di))
+                gpu_mem_stats.append(GPUMemoryStats.for_device(di))
         else:
-            mem_stats.append(MemoryStats.for_device(rank))
+            gpu_mem_stats.append(GPUMemoryStats.for_device(rank))
+    # CPU memory stats are collected for both GPU and CPU-only runs

     # Optional detailed profiling
     if output_dir and profile_iter_fn and device_type == "cuda":
@@ -868,7 +908,8 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
         short_name=name,
         gpu_elapsed_time=gpu_elapsed_time,
         cpu_elapsed_time=cpu_elapsed_time,
-        mem_stats=mem_stats,
+        gpu_mem_stats=gpu_mem_stats,
+        cpu_mem_stats=cpu_mem_stats,
         rank=rank,
     )

@@ -1139,7 +1180,8 @@ def setUp() -> None:
         res = qq.get()

         benchmark_res_per_rank.append(res)
-        assert len(res.mem_stats) == 1
+        assert len(res.gpu_mem_stats) == 1
+        assert len(res.cpu_mem_stats) == 1

     for p in processes:
         p.join()
@@ -1149,13 +1191,15 @@ def setUp() -> None:
         short_name=benchmark_res_per_rank[0].short_name,
         gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time,
         cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time,
-        mem_stats=[MemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
+        gpu_mem_stats=[GPUMemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
+        cpu_mem_stats=[CPUMemoryStats(rank, 0) for rank in range(world_size)],
         rank=0,
     )

     for res in benchmark_res_per_rank:
-        # Each rank's BenchmarkResult contains 1 memory measurement
-        total_benchmark_res.mem_stats[res.rank] = res.mem_stats[0]
+        # Each rank's BenchmarkResult contains 1 GPU and 1 CPU memory measurement
+        total_benchmark_res.gpu_mem_stats[res.rank] = res.gpu_mem_stats[0]
+        total_benchmark_res.cpu_mem_stats[res.rank] = res.cpu_mem_stats[0]

     return total_benchmark_res
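To make the new surface concrete, here is a minimal usage sketch (not part of the diff) built against the classes shown above; the two-rank setup, timing values, and zeroed GPU stats are made up for illustration.

```python
import torch

from torchrec.distributed.benchmark.benchmark_utils import (
    BenchmarkResult,
    CPUMemoryStats,
    GPUMemoryStats,
)

# Hypothetical two-rank result: zeroed GPU stats, measured CPU peak RSS.
result = BenchmarkResult(
    short_name="toy_op",
    gpu_elapsed_time=torch.tensor([1.2, 1.3]),  # milliseconds
    cpu_elapsed_time=torch.tensor([2.5, 2.6]),  # milliseconds
    gpu_mem_stats=[GPUMemoryStats(rank, 0, 0, 0) for rank in range(2)],
    cpu_mem_stats=[CPUMemoryStats.for_process(rank) for rank in range(2)],
)

# __str__ now appends "CPU Peak RSS (P90): ... GB" to the summary line.
print(result)
print(f"P90 CPU peak RSS: {result.cpu_mem_percentile(90).item() / 1000:.2f} GB")
```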

torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 4 additions & 2 deletions
@@ -18,7 +18,8 @@
 from torchrec.distributed.benchmark.benchmark_utils import (
     benchmark,
     BenchmarkResult,
-    MemoryStats,
+    CPUMemoryStats,
+    GPUMemoryStats,
 )
 from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import (
@@ -109,7 +110,8 @@ def wrapped_func(
         short_name=name,
         gpu_elapsed_time=torch.tensor(times) * 1e3,
         cpu_elapsed_time=torch.tensor(times) * 1e3,
-        mem_stats=[MemoryStats(0, 0, 0, 0)],
+        gpu_mem_stats=[GPUMemoryStats(0, 0, 0, 0)],
+        cpu_mem_stats=[CPUMemoryStats.for_process(0)],
     )

     print(

torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py

Lines changed: 7 additions & 2 deletions
@@ -20,7 +20,11 @@
 # Otherwise will get error
 # NotImplementedError: fbgemm::permute_1D_sparse_data: We could not find the abstract impl for this operator.
 from fbgemm_gpu import sparse_ops  # noqa: F401, E402
-from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats
+from torchrec.distributed.benchmark.benchmark_utils import (
+    BenchmarkResult,
+    CPUMemoryStats,
+    GPUMemoryStats,
+)
 from torchrec.distributed.dist_data import _get_recat

 from torchrec.distributed.test_utils.test_model import ModelInput
@@ -229,7 +233,8 @@ def benchmark_kjt(
         short_name=f"{test_name}-{transform_type.name}",
         gpu_elapsed_time=torch.tensor(times),
         cpu_elapsed_time=torch.tensor(times),
-        mem_stats=[MemoryStats(0, 0, 0, 0)],
+        gpu_mem_stats=[GPUMemoryStats(0, 0, 0, 0)],
+        cpu_mem_stats=[CPUMemoryStats.for_process(0)],
     )

     p50_runtime = result.runtime_percentile(50, interpolation="linear").item()
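For context on the percentile calls used throughout these benchmarks (`runtime_percentile` above, and `cpu_mem_percentile` / `_mem_percentile` in the first file), the underlying pattern is presumably the same: a `torch.quantile` over the collected measurements. A tiny standalone sketch with made-up timings:

```python
import torch

times_ms = torch.tensor([1.9, 2.1, 2.4, 2.6, 3.0])  # made-up per-iteration timings

# Same pattern as the helpers above: quantile over the measurements,
# with a configurable interpolation mode.
p50 = torch.quantile(times_ms, 0.50, interpolation="linear")
p90 = torch.quantile(times_ms, 0.90, interpolation="nearest")
print(f"P50: {p50.item():.2f} ms, P90: {p90.item():.2f} ms")
```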
