37
37
38
38
"""
39
39
40
+ import math
41
+
40
42
import dpctl .tensor as dpt
41
43
import dpctl .utils as dpu
42
44
import numpy
55
57
from .dpnp_utils .dpnp_utils_reduction import dpnp_wrap_reduction_call
56
58
from .dpnp_utils .dpnp_utils_statistics import dpnp_cov , dpnp_median
57
59
60
+ min_ = min # pylint: disable=used-before-assignment
61
+
58
62
__all__ = [
59
63
"amax" ,
60
64
"amin" ,
@@ -457,16 +461,55 @@ def _get_padding(a_size, v_size, mode):
457
461
return l_pad , r_pad
458
462
459
463
460
- def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad ):
464
+ def _choose_conv_method (a , v , rdtype ):
465
+ assert a .size >= v .size
466
+ if rdtype == dpnp .bool :
467
+ return "direct"
468
+
469
+ if v .size < 10 ** 4 or a .size < 10 ** 4 :
470
+ return "direct"
471
+
472
+ if dpnp .issubdtype (rdtype , dpnp .integer ):
473
+ max_a = int (dpnp .max (dpnp .abs (a )))
474
+ sum_v = int (dpnp .sum (dpnp .abs (v )))
475
+ max_value = int (max_a * sum_v )
476
+
477
+ default_float = dpnp .default_float_type (a .sycl_device )
478
+ if max_value > 2 ** numpy .finfo (default_float ).nmant - 1 :
479
+ return "direct"
480
+
481
+ if dpnp .issubdtype (rdtype , dpnp .number ):
482
+ return "fft"
483
+
484
+ raise ValueError (f"Unsupported dtype: { rdtype } " )
485
+
486
+
487
+ def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype ):
461
488
queue = a .sycl_queue
489
+ device = a .sycl_device
490
+
491
+ supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
492
+ supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
462
493
463
- usm_type = dpu .get_coerced_usm_type ([a .usm_type , v .usm_type ])
464
- out_size = l_pad + r_pad + a .size - v .size + 1
494
+ if supported_dtype is None :
495
+ raise ValueError (
496
+ f"function does not support input types "
497
+ f"({ a .dtype .name } , { v .dtype .name } ), "
498
+ "and the inputs could not be coerced to any "
499
+ f"supported types. List of supported types: "
500
+ f"{ [st .name for st in supported_types ]} "
501
+ )
502
+
503
+ a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
504
+ v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
505
+
506
+ usm_type = dpu .get_coerced_usm_type ([a_casted .usm_type , v_casted .usm_type ])
507
+ out_size = l_pad + r_pad + a_casted .size - v_casted .size + 1
465
508
# out type is the same as input type
466
509
out = dpnp .empty_like (a , shape = out_size , usm_type = usm_type )
467
510
468
- a_usm = dpnp .get_usm_ndarray (a )
469
- v_usm = dpnp .get_usm_ndarray (v )
511
+ a_usm = dpnp .get_usm_ndarray (a_casted )
512
+ v_usm = dpnp .get_usm_ndarray (v_casted )
470
513
out_usm = dpnp .get_usm_ndarray (out )
471
514
472
515
_manager = dpu .SequentialOrderManager [queue ]
@@ -484,7 +527,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
484
527
return out
485
528
486
529
487
- def correlate (a , v , mode = "valid" ):
530
+ def _convolve_fft (a , v , l_pad , r_pad , rtype ):
531
+ assert a .size >= v .size
532
+ assert l_pad < v .size
533
+
534
+ # +1 is needed to avoid circular convolution
535
+ padded_size = a .size + r_pad + 1
536
+ fft_size = 2 ** math .ceil (math .log2 (padded_size ))
537
+
538
+ af = dpnp .fft .fft (a , fft_size ) # pylint: disable=no-member
539
+ vf = dpnp .fft .fft (v , fft_size ) # pylint: disable=no-member
540
+
541
+ r = dpnp .fft .ifft (af * vf ) # pylint: disable=no-member
542
+ if dpnp .issubdtype (rtype , dpnp .floating ):
543
+ r = r .real
544
+ elif dpnp .issubdtype (rtype , dpnp .integer ) or rtype == dpnp .bool :
545
+ r = r .real .round ()
546
+
547
+ start = v .size - 1 - l_pad
548
+ end = padded_size - 1
549
+
550
+ return r [start :end ]
551
+
552
+
553
+ def correlate (a , v , mode = "valid" , method = "auto" ):
488
554
r"""
489
555
Cross-correlation of two 1-dimensional sequences.
490
556
@@ -509,6 +575,20 @@ def correlate(a, v, mode="valid"):
509
575
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
510
576
511
577
Default: ``"valid"``.
578
+ method : {'auto', 'direct', 'fft'}, optional
579
+ `'direct'`: The correlation is determined directly from sums.
580
+
581
+ `'fft'`: The Fourier Transform is used to perform the calculations.
582
+ This method is faster for long sequences but can have accuracy issues.
583
+
584
+ `'auto'`: Automatically chooses direct or Fourier method based on
585
+ an estimate of which is faster.
586
+
587
+ Note: Use of the FFT convolution on input containing NAN or INF
588
+ will lead to the entire output being NAN or INF.
589
+ Use method='direct' when your input contains NAN or INF values.
590
+
591
+ Default: ``'auto'``.
512
592
513
593
Notes
514
594
-----
@@ -576,20 +656,14 @@ def correlate(a, v, mode="valid"):
576
656
f"Received shapes: a.shape={ a .shape } , v.shape={ v .shape } "
577
657
)
578
658
579
- supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
659
+ supported_methods = ["auto" , "direct" , "fft" ]
660
+ if method not in supported_methods :
661
+ raise ValueError (
662
+ f"Unknown method: { method } . Supported methods: { supported_methods } "
663
+ )
580
664
581
665
device = a .sycl_device
582
666
rdtype = result_type_for_device ([a .dtype , v .dtype ], device )
583
- supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
584
-
585
- if supported_dtype is None :
586
- raise ValueError (
587
- f"function does not support input types "
588
- f"({ a .dtype .name } , { v .dtype .name } ), "
589
- "and the inputs could not be coerced to any "
590
- f"supported types. List of supported types: "
591
- f"{ [st .name for st in supported_types ]} "
592
- )
593
667
594
668
if dpnp .issubdtype (v .dtype , dpnp .complexfloating ):
595
669
v = dpnp .conj (v )
@@ -601,10 +675,15 @@ def correlate(a, v, mode="valid"):
601
675
602
676
l_pad , r_pad = _get_padding (a .size , v .size , mode )
603
677
604
- a_casted = dpnp . asarray ( a , dtype = supported_dtype , order = "C" )
605
- v_casted = dpnp . asarray ( v , dtype = supported_dtype , order = "C" )
678
+ if method == "auto" :
679
+ method = _choose_conv_method ( a , v , rdtype )
606
680
607
- r = _run_native_sliding_dot_product1d (a_casted , v_casted , l_pad , r_pad )
681
+ if method == "direct" :
682
+ r = _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype )
683
+ elif method == "fft" :
684
+ r = _convolve_fft (a , v [::- 1 ], l_pad , r_pad , rdtype )
685
+ else :
686
+ raise ValueError (f"Unknown method: { method } " )
608
687
609
688
if revert :
610
689
r = r [::- 1 ]
0 commit comments