@@ -353,142 +353,6 @@ def test_initial_states_rank0_checkpointing(self) -> None:
             lc, entrypoint=self._run_trainer_initial_states_checkpointing
         )()
 
-    def test_empty_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = EmptyMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(metric_module.get_memory_usage(), 0)
-
-    def test_ne_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = DefaultMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # Default NEMetric's dtype is
-        # float64 (8 bytes) * 16 tensors of size 1 = 128 bytes
-        # Tensors in NeMetricComputation:
-        # 8 in _default, 8 specific attributes: 4 attributes, 4 window
-        self.assertEqual(metric_module.get_memory_usage(), 128)
-        metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 160)
-
-    def test_calibration_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = dataclasses.replace(
-            DefaultMetricsConfig,
-            rec_metrics={
-                RecMetricEnum.CALIBRATION: RecMetricDef(
-                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
-                )
-            },
-        )
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # Default calibration metric dtype is
-        # float64 (8 bytes) * 8 tensors, size 1 = 64 bytes
-        # Tensors in CalibrationMetricComputation:
-        # 4 in _default, 4 specific attributes: 2 attributes, 2 window
-        self.assertEqual(metric_module.get_memory_usage(), 64)
-        metric_module.update(gen_test_batch(128))
-        self.assertEqual(metric_module.get_memory_usage(), 80)
-
-    def test_auc_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = dataclasses.replace(
-            DefaultMetricsConfig,
-            rec_metrics={
-                RecMetricEnum.AUC: RecMetricDef(
-                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
-                )
-            },
-        )
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        # 3 (tensors) * 4 (float)
-        self.assertEqual(metric_module.get_memory_usage(), 12)
-        metric_module.update(gen_test_batch(128))
-        # 3 (tensors) * 128 (batch_size) * 4 (float)
-        self.assertEqual(metric_module.get_memory_usage(), 1536)
-
-        # Test that memory usage does not increase unexpectedly over multiple updates; we don't need to force OOM, since verifying that memory grows as expected is enough.
-        for _ in range(10):
-            metric_module.update(gen_test_batch(128))
-
-        # 3 tensors * 128 batch size * 4 float * 11 updates
-        self.assertEqual(metric_module.get_memory_usage(), 16896)
-
-        # Ensure reset frees memory correctly
-        metric_module.reset()
-        self.assertEqual(metric_module.get_memory_usage(), 12)
-
-    def test_check_memory_usage(self) -> None:
-        mock_optimizer = MockOptimizer()
-        config = DefaultMetricsConfig
-        metric_module = generate_metric_module(
-            TestMetricModule,
-            metrics_config=config,
-            batch_size=128,
-            world_size=64,
-            my_rank=0,
-            state_metrics_mapping={StateMetricEnum.OPTIMIZERS: mock_optimizer},
-            device=torch.device("cpu"),
-        )
-        metric_module.update(gen_test_batch(128))
-        with patch("torchrec.metrics.metric_module.logger") as logger_mock:
-            # Memory usage is fine.
-            metric_module.memory_usage_mb_avg = 160 / (10**6)
-            metric_module.check_memory_usage(1000)
-            self.assertEqual(metric_module.oom_count, 0)
-            self.assertEqual(logger_mock.warning.call_count, 0)
-
-            # OOM but memory usage does not exceed avg.
-            metric_module.memory_usage_limit_mb = 0.000001
-            metric_module.memory_usage_mb_avg = 160 / (10**6)
-            metric_module.check_memory_usage(1000)
-            self.assertEqual(metric_module.oom_count, 1)
-            self.assertEqual(logger_mock.warning.call_count, 1)
-
-            # OOM and memory usage exceed avg but warmup is not over.
-            metric_module.memory_usage_mb_avg = 160 / (10**6) / 10
-            metric_module.check_memory_usage(2)
-            self.assertEqual(metric_module.oom_count, 2)
-            self.assertEqual(logger_mock.warning.call_count, 2)
-
-            # OOM and memory usage exceed avg and warmup is over.
-            metric_module.memory_usage_mb_avg = 160 / (10**6) / 1.25
-            metric_module.check_memory_usage(1002)
-            self.assertEqual(metric_module.oom_count, 3)
-            self.assertEqual(logger_mock.warning.call_count, 4)
-
     def test_should_compute(self) -> None:
         metric_module = generate_metric_module(
             TestMetricModule,