Skip to content

Commit 794c276

Browse files
committed
Merge branch 'dev/eglaser-large-scale-incremental' into dev/eglaser-large-scale-incr
2 parents 2edb597 + 7aa42a3 commit 794c276

File tree

3 files changed

+29
-35
lines changed

3 files changed

+29
-35
lines changed

configs/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Configs have the three highest parameter keys:
117117
|:---------------|:--------------|:--------|:------------|
118118
| `algorithm`:`estimator` | None | | Name of measured estimator. |
119119
| `algorithm`:`estimator_params` | Empty `dict` | | Parameters for estimator constructor. |
120+
| `algorithm`:`num_batches`:`training` | 5 | | Number of batches to benchmark `partial_fit` function, using batches the size of number of samples specified (not samples divided by `num_batches`). For incremental estimators only. |
120121
| `algorithm`:`online_inference_mode` | False | | Enables online mode for inference methods of estimator (separate call for each sample). |
121122
| `algorithm`:`sklearn_context` | None | | Parameters for sklearn `config_context` used over estimator. |
122123
| `algorithm`:`sklearnex_context` | None | | Parameters for sklearnex `config_context` used over estimator. Updated by `sklearn_context` if set. |

sklbench/benchmarks/sklearn_estimator.py

+26-31
Original file line numberDiff line numberDiff line change
@@ -324,41 +324,33 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
324324
return acceleration_lines > 0 and fallback_lines == 0
325325

326326

327-
def create_online_function(
328-
estimator_instance, method_instance, data_args, num_batches, batch_size
329-
):
327+
def create_online_function(estimator_instance, method_instance, data_args, num_batches):
330328

331329
if "y" in list(inspect.signature(method_instance).parameters):
332330

333331
def ndarray_function(x, y):
334332
for i in range(num_batches):
335-
method_instance(
336-
x[i * batch_size : (i + 1) * batch_size],
337-
y[i * batch_size : (i + 1) * batch_size],
338-
)
333+
method_instance(x, y)
339334
if hasattr(estimator_instance, "_onedal_finalize_fit"):
340335
estimator_instance._onedal_finalize_fit()
341336

342337
def dataframe_function(x, y):
343338
for i in range(num_batches):
344-
method_instance(
345-
x.iloc[i * batch_size : (i + 1) * batch_size],
346-
y.iloc[i * batch_size : (i + 1) * batch_size],
347-
)
339+
method_instance(x, y)
348340
if hasattr(estimator_instance, "_onedal_finalize_fit"):
349341
estimator_instance._onedal_finalize_fit()
350342

351343
else:
352344

353345
def ndarray_function(x):
354346
for i in range(num_batches):
355-
method_instance(x[i * batch_size : (i + 1) * batch_size])
347+
method_instance(x)
356348
if hasattr(estimator_instance, "_onedal_finalize_fit"):
357349
estimator_instance._onedal_finalize_fit()
358350

359351
def dataframe_function(x):
360352
for i in range(num_batches):
361-
method_instance(x.iloc[i * batch_size : (i + 1) * batch_size])
353+
method_instance(x)
362354
if hasattr(estimator_instance, "_onedal_finalize_fit"):
363355
estimator_instance._onedal_finalize_fit()
364356

@@ -413,28 +405,20 @@ def measure_sklearn_estimator(
413405
data_args = (x_train,)
414406
else:
415407
data_args = (x_test,)
408+
batch_size = get_bench_case_value(
409+
bench_case, f"algorithm:batch_size:{stage}"
410+
)
416411

417412
if method == "partial_fit":
418-
num_batches = get_bench_case_value(bench_case, "data:num_batches")
419-
batch_size = get_bench_case_value(bench_case, "data:batch_size")
420-
421-
if batch_size is None:
422-
if num_batches is None:
423-
num_batches = 5
424-
batch_size = (
425-
data_args[0].shape[0] + num_batches - 1
426-
) // num_batches
427-
if num_batches is None:
428-
num_batches = (
429-
data_args[0].shape[0] + batch_size - 1
430-
) // batch_size
413+
num_batches = get_bench_case_value(
414+
bench_case, f"algorithm:num_batches:{stage}", 5
415+
)
431416

432417
method_instance = create_online_function(
433418
estimator_instance,
434419
method_instance,
435420
data_args,
436-
num_batches,
437-
batch_size,
421+
num_batches
438422
)
439423
# daal4py model builders enabling branch
440424
if enable_modelbuilders and stage == "inference":
@@ -452,6 +436,10 @@ def measure_sklearn_estimator(
452436
metrics[method]["box filter mean[ms]"],
453437
metrics[method]["box filter std[ms]"],
454438
) = measure_case(bench_case, method_instance, *data_args)
439+
if batch_size is not None:
440+
metrics[method]["throughput[samples/ms]"] = (
441+
(data_args[0].shape[0] // batch_size) * batch_size
442+
) / metrics[method]["time[ms]"]
455443
if ensure_sklearnex_patching:
456444
full_method_name = f"{estimator_class.__name__}.{method}"
457445
sklearnex_logging_stream.seek(0)
@@ -559,9 +547,16 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
559547
for stage in estimator_methods.keys():
560548
data_descs[stage].update(
561549
{
562-
"batch_size": get_bench_case_value(
563-
bench_case, f"algorithm:batch_size:{stage}"
564-
)
550+
key: val
551+
for key, val in {
552+
"batch_size": get_bench_case_value(
553+
bench_case, f"algorithm:batch_size:{stage}"
554+
),
555+
"num_batches": get_bench_case_value(
556+
bench_case, f"algorithm:num_batches:{stage}"
557+
)
558+
}.items()
559+
if val is not None
565560
}
566561
)
567562
if "n_classes" in data_description:

sklbench/report/implementation.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
"order",
9898
"n_classes",
9999
"n_clusters",
100+
"num_batches",
100101
"batch_size",
101102
]
102103

@@ -262,10 +263,7 @@ def get_summary_from_df(df: pd.DataFrame, df_name: str) -> pd.DataFrame:
262263
# only relative improvements are included in summary currently
263264
if len(column) > 1 and column[1] == f"{metric_name} relative improvement":
264265
metric_columns.append(column)
265-
if metric_columns:
266-
summary = df[metric_columns].aggregate(geomean_wrapper, axis=0).to_frame().T
267-
else:
268-
summary = pd.DataFrame()
266+
summary = df[metric_columns].aggregate(geomean_wrapper, axis=0).to_frame().T
269267
summary.index = pd.Index([df_name])
270268
return summary
271269

0 commit comments

Comments
 (0)