Skip to content

Commit fca41a9

Browse files
committed
logic updates
1 parent fc29011 commit fca41a9

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

sklbench/benchmarks/sklearn_estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
516516
bench_case, "algorithm:estimator_params", dict()
517517
)
518518
# logger.debug("estimator params: " + str(estimator_params))
519-
if "DBSCAN" in str(estimator_name):
519+
if "DBSCAN" in str(estimator_name) and get_bench_case_value(bench_case, "data:distributed_split", None) != "rank_based":
520520
if "min_samples" in estimator_params:
521521
from mpi4py import MPI
522522

sklbench/datasets/transformer.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,11 @@ def split_and_transform_data(bench_case, data, data_description):
109109
y_train, y_test = None, None
110110

111111
distributed_split = get_bench_case_value(bench_case, "data:distributed_split", None)
112-
knn_split_train = (
113-
"KNeighbors" in get_bench_case_value(bench_case, "algorithm:estimator", "")
114-
and int(get_bench_case_value(bench_case, "bench:mpi_params:n", 1)) > 1
115-
)
116-
if distributed_split == "rank_based" or knn_split_train:
112+
# knn_split_train = (
113+
# "KNeighbors" in get_bench_case_value(bench_case, "algorithm:estimator", "")
114+
# and int(get_bench_case_value(bench_case, "bench:mpi_params:n", 1)) > 1
115+
# )
116+
if distributed_split == "rank_based":
117117
from mpi4py import MPI
118118

119119
comm = MPI.COMM_WORLD

sklbench/utils/measurement.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def measure_time(
7979
t0 = timeit.default_timer()
8080
func_return_value = func(*args, **kwargs)
8181
t1 = timeit.default_timer()
82-
if hasattr(func.__self__, "_n_inner_iter"):
82+
if hasattr(func, "__self__") and hasattr(func.__self__, "_n_inner_iter"):
8383
inners.append(func.__self__._n_inner_iter)
8484
iters.append(func.__self__.n_iter_)
8585
if enable_itt and itt_is_available:

0 commit comments

Comments
 (0)