Skip to content

Commit d016842

Browse files
authored
Resolve sampler benchmark variability with setting random seed (#340)
1 parent 3c2c129 commit d016842

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

benchmarks/samplers/benchmark_samplers.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,20 @@
1111
clips_at_regular_timestamps,
1212
)
1313

14+
DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
15+
DEFAULT_NUM_EXP = 30
1416

15-
def bench(f, *args, num_exp=100, warmup=0, **kwargs):
17+
18+
def bench(f, *args, num_exp, warmup=0, seed, **kwargs):
1619

1720
for _ in range(warmup):
1821
f(*args, **kwargs)
1922

2023
num_frames = None
2124
times = []
2225
for _ in range(num_exp):
26+
if seed is not None:
27+
torch.manual_seed(seed)
2328
start = perf_counter_ns()
2429
clips = f(*args, **kwargs)
2530
end = perf_counter_ns()
@@ -54,8 +59,7 @@ def sample(decoder, sampler, **kwargs):
5459
)
5560

5661

57-
def run_sampler_benchmarks(device, video):
58-
NUM_EXP = 30
62+
def run_sampler_benchmarks(device, video, num_experiments, torch_seed):
5963

6064
for num_clips in (1, 50):
6165
print("-" * 10)
@@ -68,8 +72,9 @@ def run_sampler_benchmarks(device, video):
6872
decoder,
6973
clips_at_random_indices,
7074
num_clips=num_clips,
71-
num_exp=NUM_EXP,
75+
num_exp=num_experiments,
7276
warmup=2,
77+
seed=torch_seed,
7378
)
7479
report_stats(times, num_frames, unit="ms")
7580

@@ -79,8 +84,9 @@ def run_sampler_benchmarks(device, video):
7984
decoder,
8085
clips_at_regular_indices,
8186
num_clips=num_clips,
82-
num_exp=NUM_EXP,
87+
num_exp=num_experiments,
8388
warmup=2,
89+
seed=torch_seed,
8490
)
8591
report_stats(times, num_frames, unit="ms")
8692

@@ -90,8 +96,9 @@ def run_sampler_benchmarks(device, video):
9096
decoder,
9197
clips_at_random_timestamps,
9298
num_clips=num_clips,
93-
num_exp=NUM_EXP,
99+
num_exp=num_experiments,
94100
warmup=2,
101+
seed=torch_seed,
95102
)
96103
report_stats(times, num_frames, unit="ms")
97104

@@ -102,19 +109,23 @@ def run_sampler_benchmarks(device, video):
102109
decoder,
103110
clips_at_regular_timestamps,
104111
seconds_between_clip_starts=seconds_between_clip_starts,
105-
num_exp=NUM_EXP,
112+
num_exp=num_experiments,
106113
warmup=2,
114+
seed=torch_seed,
107115
)
108116
report_stats(times, num_frames, unit="ms")
109117

110118

111119
def main():
112-
DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
113120
parser = argparse.ArgumentParser()
114121
parser.add_argument("--device", type=str, default="cpu")
115122
parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH))
123+
parser.add_argument("--num_experiments", type=int, default=DEFAULT_NUM_EXP)
124+
parser.add_argument("--torch_seed", type=int)
116125
args = parser.parse_args()
117-
run_sampler_benchmarks(args.device, args.video)
126+
run_sampler_benchmarks(
127+
args.device, args.video, args.num_experiments, args.torch_seed
128+
)
118129

119130

120131
if __name__ == "__main__":

0 commit comments

Comments
 (0)