
Commit 96abf2a

atuljangra authored and facebook-github-bot committed
Remove unused memory checks to speed up compute (#2719)
Summary:
Pull Request resolved: #2719

The memory checks here have non-trivial overhead in every compute step, as there are a lot of tensor size calls involved. In our runs, this accounted for around 20% of the time spent in the rec metric compute step. Given that this is not used anymore, let's remove the call.

This diff removes the call from the metric_module. In the next set of diffs, I'll remove the argument from the callsites.

Reviewed By: fegin

Differential Revision: D68995122

fbshipit-source-id: 34e9d2399c3c647c41585a3f0efba2ad404cc8f4
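For context on the overhead: the removed check walked every metric state tensor and summed its storage size on each compute step. Below is a minimal, hypothetical sketch of that kind of accounting (plain PyTorch, not the TorchRec code; `state_bytes` and the example states are made up) showing where the Python-side cost comes from, since every state tensor's metadata is touched on every call:

import torch

def state_bytes(states: dict) -> int:
    # element_size() * nelement() is one tensor's storage footprint in bytes;
    # summing this over all metric states on every compute step is pure
    # Python-side overhead that grows with the number of state tensors.
    return sum(t.element_size() * t.nelement() for t in states.values())

# 16 scalar float64 states -> 16 * 8 = 128 bytes, the same figure asserted in
# the removed test_ne_memory_usage further down this page.
states = {f"state_{i}": torch.zeros(1, dtype=torch.float64) for i in range(16)}
assert state_bytes(states) == 128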
1 parent ba4f182 commit 96abf2a

File tree

2 files changed: +3 -177 lines changed


torchrec/metrics/metric_module.py

+3 -41
@@ -109,9 +109,6 @@
 MODEL_METRIC_LABEL: str = "model"
 
 
-MEMORY_AVG_WARNING_PERCENTAGE = 20
-MEMORY_AVG_WARNING_WARMUP = 100
-
 MetricValue = Union[torch.Tensor, float]
 
 
@@ -146,7 +143,7 @@ class RecMetricModule(nn.Module):
         throughput_metric (Optional[ThroughputMetric]): the ThroughputMetric.
         state_metrics (Optional[Dict[str, StateMetric]]): the dict of StateMetrics.
         compute_interval_steps (int): the intervals between two compute calls in the unit of batch number
-        memory_usage_limit_mb (float): the memory usage limit for OOM check
+        memory_usage_limit_mb (float): [Unused] the memory usage limit for OOM check
 
     Call Args:
         Not supported.
@@ -177,8 +174,6 @@ class RecMetricModule(nn.Module):
     rec_metrics: RecMetricList
     throughput_metric: Optional[ThroughputMetric]
     state_metrics: Dict[str, StateMetric]
-    memory_usage_limit_mb: float
-    memory_usage_mb_avg: float
     oom_count: int
     compute_count: int
     last_compute_time: float
@@ -195,6 +190,7 @@ def __init__(
         compute_interval_steps: int = 100,
         min_compute_interval: float = 0.0,
         max_compute_interval: float = float("inf"),
+        # Unused, but needed for backwards compatibility. TODO: Remove from callsites
         memory_usage_limit_mb: float = 512,
     ) -> None:
         super().__init__()
@@ -205,8 +201,6 @@ def __init__(
         self.trained_batches: int = 0
         self.batch_size = batch_size
         self.world_size = world_size
-        self.memory_usage_limit_mb = memory_usage_limit_mb
-        self.memory_usage_mb_avg = 0.0
         self.oom_count = 0
         self.compute_count = 0
 
@@ -230,37 +224,6 @@ def __init__(
         )
         self.last_compute_time = -1.0
 
-    def get_memory_usage(self) -> int:
-        r"""Total memory of unique RecMetric tensors in bytes"""
-        total = {}
-        for metric in self.rec_metrics.rec_metrics:
-            total.update(metric.get_memory_usage())
-        return sum(total.values())
-
-    def check_memory_usage(self, compute_count: int) -> None:
-        memory_usage_mb = self.get_memory_usage() / (10**6)
-        if memory_usage_mb > self.memory_usage_limit_mb:
-            self.oom_count += 1
-            logger.warning(
-                f"MetricModule is using {memory_usage_mb}MB. "
-                f"This is larger than the limit{self.memory_usage_limit_mb}MB. "
-                f"This is the f{self.oom_count}th OOM."
-            )
-
-        if (
-            compute_count > MEMORY_AVG_WARNING_WARMUP
-            and memory_usage_mb
-            > self.memory_usage_mb_avg * ((100 + MEMORY_AVG_WARNING_PERCENTAGE) / 100)
-        ):
-            logger.warning(
-                f"MetricsModule is using more than {MEMORY_AVG_WARNING_PERCENTAGE}% of "
-                f"the average memory usage. Current usage: {memory_usage_mb}MB."
-            )
-
-        self.memory_usage_mb_avg = (
-            self.memory_usage_mb_avg * (compute_count - 1) + memory_usage_mb
-        ) / compute_count
-
     def _update_rec_metrics(
         self, model_out: Dict[str, torch.Tensor], **kwargs: Any
     ) -> None:
@@ -353,9 +316,8 @@ def compute(self) -> Dict[str, MetricValue]:
         right before logging the metrics results to the data sink.
         """
         self.compute_count += 1
+        ret: Dict[str, MetricValue] = {}
         with record_function("## RecMetricModule:compute ##"):
-            self.check_memory_usage(self.compute_count)
-            ret: Dict[str, MetricValue] = {}
             if self.rec_metrics:
                 self._adjust_compute_interval()
                 ret.update(self.rec_metrics.compute())
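With this hunk, compute() no longer does any memory accounting; ret is initialized up front and the record_function region only runs the metric computations. If the aggregate memory number is still wanted for occasional debugging, it can be rebuilt outside the hot path. A minimal sketch, assuming the per-metric RecMetric.get_memory_usage() that the deleted module-level helper relied on is still available (the helper name below is hypothetical, not part of the module):

def metric_module_memory_bytes(metric_module) -> int:
    # Mirrors the removed RecMetricModule.get_memory_usage: each metric reports
    # a {tensor: size_in_bytes} mapping, and the dict union de-duplicates
    # tensors shared across metrics before summing.
    total = {}
    for metric in metric_module.rec_metrics.rec_metrics:
        total.update(metric.get_memory_usage())
    return sum(total.values())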

torchrec/metrics/tests/test_metric_module.py

-136
@@ -353,142 +353,6 @@ def test_initial_states_rank0_checkpointing(self) -> None:
             lc, entrypoint=self._run_trainer_initial_states_checkpointing
         )()
 
-    def test_empty_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = EmptyMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(metric_module.get_memory_usage(), 0)
-
-    def test_ne_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = DefaultMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # Default NEMetric's dtype is
-        # float64 (8 bytes) * 16 tensors of size 1 = 128 bytes
-        # Tensors in NeMetricComputation:
-        # 8 in _default, 8 specific attributes: 4 attributes, 4 window
-        self.assertEqual(metric_module.get_memory_usage(), 128)
-        metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 160)
-
-    def test_calibration_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = dataclasses.replace(
-            DefaultMetricsConfig,
-            rec_metrics={
-                RecMetricEnum.CALIBRATION: RecMetricDef(
-                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
-                )
-            },
-        )
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # Default calibration metric dtype is
-        # float64 (8 bytes) * 8 tensors, size 1 = 64 bytes
-        # Tensors in CalibrationMetricComputation:
-        # 4 in _default, 4 specific attributes: 2 attribute, 2 window
-        self.assertEqual(metric_module.get_memory_usage(), 64)
-        metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 80)
-
-    def test_auc_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = dataclasses.replace(
-            DefaultMetricsConfig,
-            rec_metrics={
-                RecMetricEnum.AUC: RecMetricDef(
-                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
-                )
-            },
-        )
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # 3 (tensors) * 4 (float)
-        self.assertEqual(metric_module.get_memory_usage(), 12)
-        metric_module.update(gen_test_batch(128))
-        # 3 (tensors) * 128 (batch_size) * 4 (float)
-        self.assertEqual(metric_module.get_memory_usage(), 1536)
-
-        # Test memory usage over multiple updates does not increase unexpectedly, we don't need to force OOM as just knowing if the memory usage is increeasing how we expect is enough
-        for _ in range(10):
-            metric_module.update(gen_test_batch(128))
-
-        # 3 tensors * 128 batch size * 4 float * 11 updates
-        self.assertEqual(metric_module.get_memory_usage(), 16896)
-
-        # Ensure reset frees memory correctly
-        metric_module.reset()
-        self.assertEqual(metric_module.get_memory_usage(), 12)
-
-    def test_check_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = DefaultMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        metric_module.update(gen_test_batch(128))
-        with patch("torchrec.metrics.metric_module.logger") as logger_mock:
-            # Memory usage is fine.
-            metric_module.memory_usage_mb_avg = 160 / (10**6)
-            metric_module.check_memory_usage(1000)
-            self.assertEqual(metric_module.oom_count, 0)
-            self.assertEqual(logger_mock.warning.call_count, 0)
-
-            # OOM but memory usage does not exceed avg.
-            metric_module.memory_usage_limit_mb = 0.000001
-            metric_module.memory_usage_mb_avg = 160 / (10**6)
-            metric_module.check_memory_usage(1000)
-            self.assertEqual(metric_module.oom_count, 1)
-            self.assertEqual(logger_mock.warning.call_count, 1)
-
-            # OOM and memory usage exceed avg but warmup is not over.
-            metric_module.memory_usage_mb_avg = 160 / (10**6) / 10
-            metric_module.check_memory_usage(2)
-            self.assertEqual(metric_module.oom_count, 2)
-            self.assertEqual(logger_mock.warning.call_count, 2)
-
-            # OOM and memory usage exceed avg and warmup is over.
-            metric_module.memory_usage_mb_avg = 160 / (10**6) / 1.25
-            metric_module.check_memory_usage(1002)
-            self.assertEqual(metric_module.oom_count, 3)
-            self.assertEqual(logger_mock.warning.call_count, 4)
-
     def test_should_compute(self) -> None:
         metric_module = generate_metric_module(
             TestMetricModule,
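As a sanity check on the last block of the removed test_check_memory_usage, the thresholds from the deleted check can be reproduced by hand (plain Python using the removed constants MEMORY_AVG_WARNING_WARMUP = 100 and MEMORY_AVG_WARNING_PERCENTAGE = 20; the variable names below are illustrative):

usage_mb = 160 / (10**6)   # module memory after one update, per the test
avg_mb = usage_mb / 1.25   # memory_usage_mb_avg set in the final block
limit_mb = 0.000001        # memory_usage_limit_mb set earlier in the test
warnings = 0
if usage_mb > limit_mb:
    warnings += 1          # OOM warning fires (oom_count 2 -> 3)
if 1002 > 100 and usage_mb > avg_mb * (100 + 20) / 100:
    warnings += 1          # above-average warning fires: 0.00016 > 0.0001536
assert warnings == 2       # matches warning.call_count going from 2 to 4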
