Skip to content

Commit 44f4bb5

Browse files
SSYernar authored and facebook-github-bot committed
Created run_pipeline API function to get benchmark results (#3237)
Summary: Pull Request resolved: #3237 Created `run_pipeline` API function that runs the pipeline on given configurations and returns the list of BenchmarkResult objects (each BenchmarkResult corresponds to a specific `rank` up to `world_size`). This change is needed for future ServiceLab integration since we will be collecting pipeline benchmarks on different settings. Reviewed By: jd7-tr Differential Revision: D78941384 fbshipit-source-id: 8fc26b8e1ca3a5130508f522bb0e9e3e8cbdc9a1
1 parent ebf6bad commit 44f4bb5

File tree

1 file changed

+33
-11
lines changed

1 file changed

+33
-11
lines changed

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040
TestTowerCollectionSparseNNConfig,
4141
TestTowerSparseNNConfig,
4242
)
43-
from torchrec.distributed.benchmark.benchmark_utils import benchmark_func, cmd_conf
43+
from torchrec.distributed.benchmark.benchmark_utils import (
44+
benchmark_func,
45+
BenchmarkResult,
46+
cmd_conf,
47+
)
4448
from torchrec.distributed.comm import get_local_size
4549
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
4650
from torchrec.distributed.planner import Topology
@@ -201,15 +205,7 @@ def main(
201205
table_config: EmbeddingTablesConfig,
202206
model_selection: ModelSelectionConfig,
203207
pipeline_config: PipelineConfig,
204-
model_config: Optional[
205-
Union[
206-
TestSparseNNConfig,
207-
TestTowerCollectionSparseNNConfig,
208-
TestTowerSparseNNConfig,
209-
DeepFMConfig,
210-
DLRMConfig,
211-
]
212-
] = None,
208+
model_config: Optional[BaseModelConfig] = None,
213209
) -> None:
214210
tables, weighted_tables = generate_tables(
215211
num_unweighted_features=table_config.num_unweighted_features,
@@ -254,6 +250,30 @@ def main(
254250
)
255251

256252

253+
def run_pipeline(
254+
run_option: RunOptions,
255+
table_config: EmbeddingTablesConfig,
256+
pipeline_config: PipelineConfig,
257+
model_config: BaseModelConfig,
258+
) -> List[BenchmarkResult]:
259+
260+
tables, weighted_tables = generate_tables(
261+
num_unweighted_features=table_config.num_unweighted_features,
262+
num_weighted_features=table_config.num_weighted_features,
263+
embedding_feature_dim=table_config.embedding_feature_dim,
264+
)
265+
266+
return run_multi_process_func(
267+
func=runner,
268+
world_size=run_option.world_size,
269+
tables=tables,
270+
weighted_tables=weighted_tables,
271+
run_option=run_option,
272+
model_config=model_config,
273+
pipeline_config=pipeline_config,
274+
)
275+
276+
257277
def runner(
258278
rank: int,
259279
world_size: int,
@@ -262,7 +282,7 @@ def runner(
262282
run_option: RunOptions,
263283
model_config: BaseModelConfig,
264284
pipeline_config: PipelineConfig,
265-
) -> None:
285+
) -> BenchmarkResult:
266286
# Ensure GPUs are available and we have enough of them
267287
assert (
268288
torch.cuda.is_available() and torch.cuda.device_count() >= world_size
@@ -383,6 +403,8 @@ def _func_to_benchmark(
383403
if rank == 0:
384404
print(result)
385405

406+
return result
407+
386408

387409
if __name__ == "__main__":
388410
main()

0 commit comments

Comments (0)