@@ -74,7 +74,7 @@ def get_estimator(library_name: str, estimator_name: str):
74
74
def get_estimator_methods (bench_case : BenchCase ) -> Dict [str , List [str ]]:
75
75
# default estimator methods
76
76
estimator_methods = {
77
- "training" : ["fit" ],
77
+ "training" : ["partial_fit" , " fit" ],
78
78
"inference" : ["predict" , "predict_proba" , "transform" ],
79
79
}
80
80
for stage in estimator_methods .keys ():
@@ -334,7 +334,9 @@ def verify_patching(stream: io.StringIO, function_name) -> bool:
334
334
return acceleration_lines > 0 and fallback_lines == 0
335
335
336
336
337
- def create_online_function (method_instance , data_args , batch_size ):
337
+ def create_online_function (
338
+ estimator_instance , method_instance , data_args , num_batches , batch_size
339
+ ):
338
340
n_batches = data_args [0 ].shape [0 ] // batch_size
339
341
340
342
if "y" in list (inspect .signature (method_instance ).parameters ):
@@ -345,23 +347,27 @@ def ndarray_function(x, y):
345
347
x [i * batch_size : (i + 1 ) * batch_size ],
346
348
y [i * batch_size : (i + 1 ) * batch_size ],
347
349
)
350
+ estimator_instance ._onedal_finalize_fit ()
348
351
349
352
def dataframe_function (x , y ):
350
353
for i in range (n_batches ):
351
354
method_instance (
352
355
x .iloc [i * batch_size : (i + 1 ) * batch_size ],
353
356
y .iloc [i * batch_size : (i + 1 ) * batch_size ],
354
357
)
358
+ estimator_instance ._onedal_finalize_fit ()
355
359
356
360
else :
357
361
358
362
def ndarray_function (x ):
359
363
for i in range (n_batches ):
360
364
method_instance (x [i * batch_size : (i + 1 ) * batch_size ])
365
+ estimator_instance ._onedal_finalize_fit ()
361
366
362
367
def dataframe_function (x ):
363
368
for i in range (n_batches ):
364
369
method_instance (x .iloc [i * batch_size : (i + 1 ) * batch_size ])
370
+ estimator_instance ._onedal_finalize_fit ()
365
371
366
372
if "ndarray" in str (type (data_args [0 ])):
367
373
return ndarray_function
@@ -414,12 +420,28 @@ def measure_sklearn_estimator(
414
420
data_args = (x_train ,)
415
421
else :
416
422
data_args = (x_test ,)
417
- batch_size = get_bench_case_value (
418
- bench_case , f"algorithm:batch_size:{ stage } "
419
- )
420
- if batch_size is not None :
423
+
424
+ if method == "partial_fit" :
425
+ num_batches = get_bench_case_value (bench_case , "data:num_batches" )
426
+ batch_size = get_bench_case_value (bench_case , "data:batch_size" )
427
+
428
+ if batch_size is None :
429
+ if num_batches is None :
430
+ num_batches = 5
431
+ batch_size = (
432
+ data_args [0 ].shape [0 ] + num_batches - 1
433
+ ) // num_batches
434
+ if num_batches is None :
435
+ num_batches = (
436
+ data_args [0 ].shape [0 ] + batch_size - 1
437
+ ) // batch_size
438
+
421
439
method_instance = create_online_function (
422
- method_instance , data_args , batch_size
440
+ estimator_instance ,
441
+ method_instance ,
442
+ data_args ,
443
+ num_batches ,
444
+ batch_size ,
423
445
)
424
446
# daal4py model builders enabling branch
425
447
if enable_modelbuilders and stage == "inference" :
0 commit comments