11
11
clips_at_regular_timestamps ,
12
12
)
13
13
14
+ DEFAULT_VIDEO_PATH = Path (__file__ ).parent / "../../test/resources/nasa_13013.mp4"
15
+ DEFAULT_NUM_EXP = 30
14
16
15
- def bench (f , * args , num_exp = 100 , warmup = 0 , ** kwargs ):
17
+
18
+ def bench (f , * args , num_exp , warmup = 0 , seed , ** kwargs ):
16
19
17
20
for _ in range (warmup ):
18
21
f (* args , ** kwargs )
19
22
20
23
num_frames = None
21
24
times = []
22
25
for _ in range (num_exp ):
26
+ if seed is not None :
27
+ torch .manual_seed (seed )
23
28
start = perf_counter_ns ()
24
29
clips = f (* args , ** kwargs )
25
30
end = perf_counter_ns ()
@@ -54,8 +59,7 @@ def sample(decoder, sampler, **kwargs):
54
59
)
55
60
56
61
57
- def run_sampler_benchmarks (device , video ):
58
- NUM_EXP = 30
62
+ def run_sampler_benchmarks (device , video , num_experiments , torch_seed ):
59
63
60
64
for num_clips in (1 , 50 ):
61
65
print ("-" * 10 )
@@ -68,8 +72,9 @@ def run_sampler_benchmarks(device, video):
68
72
decoder ,
69
73
clips_at_random_indices ,
70
74
num_clips = num_clips ,
71
- num_exp = NUM_EXP ,
75
+ num_exp = num_experiments ,
72
76
warmup = 2 ,
77
+ seed = torch_seed ,
73
78
)
74
79
report_stats (times , num_frames , unit = "ms" )
75
80
@@ -79,8 +84,9 @@ def run_sampler_benchmarks(device, video):
79
84
decoder ,
80
85
clips_at_regular_indices ,
81
86
num_clips = num_clips ,
82
- num_exp = NUM_EXP ,
87
+ num_exp = num_experiments ,
83
88
warmup = 2 ,
89
+ seed = torch_seed ,
84
90
)
85
91
report_stats (times , num_frames , unit = "ms" )
86
92
@@ -90,8 +96,9 @@ def run_sampler_benchmarks(device, video):
90
96
decoder ,
91
97
clips_at_random_timestamps ,
92
98
num_clips = num_clips ,
93
- num_exp = NUM_EXP ,
99
+ num_exp = num_experiments ,
94
100
warmup = 2 ,
101
+ seed = torch_seed ,
95
102
)
96
103
report_stats (times , num_frames , unit = "ms" )
97
104
@@ -102,19 +109,23 @@ def run_sampler_benchmarks(device, video):
102
109
decoder ,
103
110
clips_at_regular_timestamps ,
104
111
seconds_between_clip_starts = seconds_between_clip_starts ,
105
- num_exp = NUM_EXP ,
112
+ num_exp = num_experiments ,
106
113
warmup = 2 ,
114
+ seed = torch_seed ,
107
115
)
108
116
report_stats (times , num_frames , unit = "ms" )
109
117
110
118
111
119
def main ():
112
- DEFAULT_VIDEO_PATH = Path (__file__ ).parent / "../../test/resources/nasa_13013.mp4"
113
120
parser = argparse .ArgumentParser ()
114
121
parser .add_argument ("--device" , type = str , default = "cpu" )
115
122
parser .add_argument ("--video" , type = str , default = str (DEFAULT_VIDEO_PATH ))
123
+ parser .add_argument ("--num_experiments" , type = int , default = DEFAULT_NUM_EXP )
124
+ parser .add_argument ("--torch_seed" , type = int )
116
125
args = parser .parse_args ()
117
- run_sampler_benchmarks (args .device , args .video )
126
+ run_sampler_benchmarks (
127
+ args .device , args .video , args .num_experiments , args .torch_seed
128
+ )
118
129
119
130
120
131
if __name__ == "__main__" :
0 commit comments