Commit ebf6bad

SSYernar authored and facebook-github-bot committed
Fix benchmark pipeline runner (#3236)
Summary:
Pull Request resolved: #3236

Changed the return type of the benchmark pipeline runner from `List[BenchmarkResult]` to a single `BenchmarkResult`, since the runner should return exactly one result for a given configuration. Removed the with/without-`apply_jit` comparison, as the two configurations should each invoke pipeline benchmarking separately.

Reviewed By: jd7-tr

Differential Revision: D78941101

fbshipit-source-id: cf25de1132d6d4a8365b279d83e4b13de1cb1410
1 parent 1ab1381 commit ebf6bad
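
For context, a minimal sketch of what the changed contract implies for callers. All names below (`BenchmarkResult` fields, `run_pipeline_benchmark`) are hypothetical stand-ins, not the repo's actual API: with the runner returning a single result, a JIT-vs-eager comparison becomes two independent invocations instead of an internal loop.

    # Illustrative sketch only -- hypothetical stand-ins for the runner's API.
    from dataclasses import dataclass

    @dataclass
    class BenchmarkResult:  # stand-in; the real type lives in torchrec's benchmark utils
        name: str
        runtime_ms: float

    def run_pipeline_benchmark(apply_jit: bool) -> BenchmarkResult:
        # A real implementation would build the train pipeline with apply_jit,
        # warm it up, and time it; this stub only illustrates the one-config,
        # one-result contract introduced by this commit.
        return BenchmarkResult(name="TrainPipelineSparseDist", runtime_ms=0.0)

    # The with/without-JIT comparison now lives at the call site: two runs,
    # each returning its own single result.
    for apply_jit in (True, False):
        print(apply_jit, run_pipeline_benchmark(apply_jit=apply_jit))
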

File tree

1 file changed (+23, −39 lines)

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 23 additions & 39 deletions
@@ -356,48 +356,32 @@ def _func_to_benchmark(
             except StopIteration:
                 break
 
-    # Run comparison if apply_jit is True, otherwise run single benchmark
-    jit_configs = (
-        [(True, "WithJIT"), (False, "WithoutJIT")]
-        if pipeline_config.apply_jit
-        else [(False, "")]
+    pipeline = generate_pipeline(
+        pipeline_type=pipeline_config.pipeline,
+        emb_lookup_stream=pipeline_config.emb_lookup_stream,
+        model=sharded_model,
+        opt=optimizer,
+        device=ctx.device,
+        apply_jit=pipeline_config.apply_jit,
+    )
+    pipeline.progress(iter(bench_inputs))
+
+    result = benchmark_func(
+        name=type(pipeline).__name__,
+        bench_inputs=bench_inputs,  # pyre-ignore
+        prof_inputs=bench_inputs,  # pyre-ignore
+        num_benchmarks=5,
+        num_profiles=2,
+        profile_dir=run_option.profile,
+        world_size=run_option.world_size,
+        func_to_benchmark=_func_to_benchmark,
+        benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
+        rank=rank,
+        export_stacks=run_option.export_stacks,
     )
-    results = []
-
-    for apply_jit, jit_suffix in jit_configs:
-        pipeline = generate_pipeline(
-            pipeline_type=pipeline_config.pipeline,
-            emb_lookup_stream=pipeline_config.emb_lookup_stream,
-            model=sharded_model,
-            opt=optimizer,
-            device=ctx.device,
-            apply_jit=apply_jit,
-        )
-        pipeline.progress(iter(bench_inputs))
-
-        name = (
-            f"{type(pipeline).__name__}{jit_suffix}"
-            if jit_suffix
-            else type(pipeline).__name__
-        )
-        result = benchmark_func(
-            name=name,
-            bench_inputs=bench_inputs,  # pyre-ignore
-            prof_inputs=bench_inputs,  # pyre-ignore
-            num_benchmarks=5,
-            num_profiles=2,
-            profile_dir=run_option.profile,
-            world_size=run_option.world_size,
-            func_to_benchmark=_func_to_benchmark,
-            benchmark_func_kwargs={"model": sharded_model, "pipeline": pipeline},
-            rank=rank,
-            export_stacks=run_option.export_stacks,
-        )
-        results.append(result)
 
     if rank == 0:
-        for result in results:
-            print(result)
+        print(result)
 
 
 if __name__ == "__main__":
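
Net effect of the hunk above: `generate_pipeline` is called exactly once with `pipeline_config.apply_jit` passed straight through, the pipeline is warmed up once via `pipeline.progress(iter(bench_inputs))` before timing, and rank 0 prints the single result. As a rough, self-contained sketch of that warm-up-then-measure pattern (generic stand-in code, not the repo's `benchmark_func`):

    import time

    def bench(fn, inputs, num_benchmarks: int = 5) -> float:
        # Warm-up pass before timing, analogous to the runner's
        # pipeline.progress(iter(bench_inputs)) call above.
        fn(inputs)
        start = time.perf_counter()
        for _ in range(num_benchmarks):  # the diff uses num_benchmarks=5
            fn(inputs)
        return (time.perf_counter() - start) / num_benchmarks

    # Example usage with a trivial workload:
    print(bench(lambda xs: sum(x * x for x in xs), list(range(10_000))))
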
