From fd2662f42f82aa48830f5e38bf19bd3998230327 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Fri, 13 Dec 2024 04:04:08 -0800 Subject: [PATCH 1/2] Add runner argument to describe datasets --- sklbench/runner/arguments.py | 6 ++++++ sklbench/runner/implementation.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sklbench/runner/arguments.py b/sklbench/runner/arguments.py index 1ba47daa..3f114dd0 100644 --- a/sklbench/runner/arguments.py +++ b/sklbench/runner/arguments.py @@ -130,6 +130,12 @@ def add_runner_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentPa action="store_true", help="Load all requested datasets in parallel before running benchmarks.", ) + parser.add_argument( + "--describe-datasets", + default=False, + action="store_true", + help="Load all requested datasets in parallel and show their parameters.", + ) # workflow control parser.add_argument( "--exit-on-error", diff --git a/sklbench/runner/implementation.py b/sklbench/runner/implementation.py index 2375e4b7..7e546b4f 100644 --- a/sklbench/runner/implementation.py +++ b/sklbench/runner/implementation.py @@ -16,7 +16,9 @@ import argparse +import gc import json +import sys from multiprocessing import Pool from typing import Dict, List, Tuple, Union @@ -94,7 +96,7 @@ def run_benchmarks(args: argparse.Namespace) -> int: bench_cases = early_filtering(bench_cases, param_filters) # prefetch datasets - if args.prefetch_datasets: + if args.prefetch_datasets or args.describe_datasets: # trick: get unique dataset names only to avoid loading of same dataset # by different cases/processes dataset_cases = {get_data_name(case): case for case in bench_cases} @@ -102,7 +104,14 @@ def run_benchmarks(args: argparse.Namespace) -> int: n_proc = min([16, cpu_count(), len(dataset_cases)]) logger.info(f"Prefetching datasets with {n_proc} processes") with Pool(n_proc) as pool: - pool.map(load_data, dataset_cases.values()) + datasets = pool.map(load_data, dataset_cases.values()) + if args.describe_datasets: + for ((data, data_description), data_name) in zip(datasets, dataset_cases.keys()): + print(f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}") + sys.exit(0) + # free memory used by prefetched datasets + del datasets + gc.collect() # run bench_cases return_code, result = call_benchmarks( From 4412505a5a8d0d5778fb3a35af5b26943813d58e Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Fri, 13 Dec 2024 04:05:22 -0800 Subject: [PATCH 2/2] Linting --- sklbench/runner/implementation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklbench/runner/implementation.py b/sklbench/runner/implementation.py index 7e546b4f..47b10962 100644 --- a/sklbench/runner/implementation.py +++ b/sklbench/runner/implementation.py @@ -106,8 +106,12 @@ def run_benchmarks(args: argparse.Namespace) -> int: with Pool(n_proc) as pool: datasets = pool.map(load_data, dataset_cases.values()) if args.describe_datasets: - for ((data, data_description), data_name) in zip(datasets, dataset_cases.keys()): - print(f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}") + for (data, data_description), data_name in zip( + datasets, dataset_cases.keys() + ): + print( + f"{data_name}:\n\tshape: {data['x'].shape}\n\tparameters: {data_description}" + ) sys.exit(0) # free memory used by prefetched datasets del datasets