diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index ce30e3026..a971019ce 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -236,11 +236,6 @@ def validate_batch_size_stages( if len(batch_size_stages) == 0: raise ValueError("Batch size stages should not be empty") - for i in range(len(batch_size_stages) - 1): - if batch_size_stages[i].batch_size >= batch_size_stages[i + 1].batch_size: - raise ValueError( - f"Batch size should be in ascending order. Got {batch_size_stages}" - ) if batch_size_stages[-1].max_iters is not None: raise ValueError( f"Batch size stages last stage should have max_iters = None, but get {batch_size_stages[-1].max_iters}"