37
37
38
38
"""
39
39
40
+ import math
41
+
40
42
import dpctl .tensor as dpt
41
43
import dpctl .utils as dpu
42
44
import numpy
59
61
from .dpnp_utils .dpnp_utils_reduction import dpnp_wrap_reduction_call
60
62
from .dpnp_utils .dpnp_utils_statistics import dpnp_cov
61
63
64
+ min_ = min # pylint: disable=used-before-assignment
65
+
62
66
__all__ = [
63
67
"amax" ,
64
68
"amin" ,
@@ -478,17 +482,57 @@ def _get_padding(a_size, v_size, mode):
478
482
return l_pad , r_pad
479
483
480
484
481
- def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad ):
485
+ def _choose_conv_method (a , v , rdtype ):
486
+ assert a .size >= v .size
487
+ if rdtype == dpnp .bool :
488
+ return "direct"
489
+
490
+ if v .size < 10 ** 4 or a .size < 10 ** 4 :
491
+ return "direct"
492
+
493
+ if dpnp .issubdtype (rdtype , dpnp .integer ):
494
+ max_a = int (dpnp .max (dpnp .abs (a )))
495
+ sum_v = int (dpnp .sum (dpnp .abs (v )))
496
+ max_value = int (max_a * sum_v )
497
+
498
+ default_float = dpnp .default_float_type (a .sycl_device )
499
+ if max_value > 2 ** numpy .finfo (default_float ).nmant - 1 :
500
+ return "direct"
501
+
502
+ if dpnp .issubdtype (rdtype , dpnp .number ):
503
+ return "fft"
504
+
505
+ raise ValueError (f"Unsupported dtype: { rdtype } " )
506
+
507
+
508
+ def _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype ):
482
509
queue = a .sycl_queue
510
+ device = a .sycl_device
511
+
512
+ supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
513
+ supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
483
514
484
- usm_type = dpu .get_coerced_usm_type ([a .usm_type , v .usm_type ])
485
- out_size = l_pad + r_pad + a .size - v .size + 1
515
+ if supported_dtype is None :
516
+ raise ValueError (
517
+ f"Unsupported input types ({ a .dtype } , { v .dtype } ), "
518
+ "and the inputs could not be coerced to any "
519
+ f"supported types. List of supported types: { supported_types } "
520
+ )
521
+
522
+ a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
523
+ v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
524
+
525
+ usm_type = dpu .get_coerced_usm_type ([a_casted .usm_type , v_casted .usm_type ])
526
+ out_size = l_pad + r_pad + a_casted .size - v_casted .size + 1
486
527
out = dpnp .empty (
487
- shape = out_size , sycl_queue = queue , dtype = a .dtype , usm_type = usm_type
528
+ shape = out_size ,
529
+ sycl_queue = queue ,
530
+ dtype = supported_dtype ,
531
+ usm_type = usm_type ,
488
532
)
489
533
490
- a_usm = dpnp .get_usm_ndarray (a )
491
- v_usm = dpnp .get_usm_ndarray (v )
534
+ a_usm = dpnp .get_usm_ndarray (a_casted )
535
+ v_usm = dpnp .get_usm_ndarray (v_casted )
492
536
out_usm = dpnp .get_usm_ndarray (out )
493
537
494
538
_manager = dpu .SequentialOrderManager [queue ]
@@ -506,7 +550,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
506
550
return out
507
551
508
552
509
- def correlate (a , v , mode = "valid" ):
553
+ def _convolve_fft (a , v , l_pad , r_pad , rtype ):
554
+ assert a .size >= v .size
555
+ assert l_pad < v .size
556
+
557
+ # +1 is needed to avoid circular convolution
558
+ padded_size = a .size + r_pad + 1
559
+ fft_size = 2 ** math .ceil (math .log2 (padded_size ))
560
+
561
+ af = dpnp .fft .fft (a , fft_size ) # pylint: disable=no-member
562
+ vf = dpnp .fft .fft (v , fft_size ) # pylint: disable=no-member
563
+
564
+ r = dpnp .fft .ifft (af * vf ) # pylint: disable=no-member
565
+ if dpnp .issubdtype (rtype , dpnp .floating ):
566
+ r = r .real
567
+ elif dpnp .issubdtype (rtype , dpnp .integer ) or rtype == dpnp .bool :
568
+ r = r .real .round ()
569
+
570
+ start = v .size - 1 - l_pad
571
+ end = padded_size - 1
572
+
573
+ return r [start :end ]
574
+
575
+
576
+ def correlate (a , v , mode = "valid" , method = "auto" ):
510
577
r"""
511
578
Cross-correlation of two 1-dimensional sequences.
512
579
@@ -531,6 +598,20 @@ def correlate(a, v, mode="valid"):
531
598
is ``'valid'``, unlike :obj:`dpnp.convolve`, which uses ``'full'``.
532
599
533
600
Default: ``'valid'``.
601
+ method : {'auto', 'direct', 'fft'}, optional
602
+ `'direct'`: The correlation is determined directly from sums.
603
+
604
+ `'fft'`: The Fourier Transform is used to perform the calculations.
605
+ This method is faster for long sequences but can have accuracy issues.
606
+
607
+ `'auto'`: Automatically chooses direct or Fourier method based on
608
+ an estimate of which is faster.
609
+
610
+ Note: Use of the FFT convolution on input containing NAN or INF
611
+ will lead to the entire output being NAN or INF.
612
+ Use method='direct' when your input contains NAN or INF values.
613
+
614
+ Default: ``'auto'``.
534
615
535
616
Notes
536
617
-----
@@ -556,7 +637,6 @@ def correlate(a, v, mode="valid"):
556
637
:obj:`dpnp.convolve` : Discrete, linear convolution of two
557
638
one-dimensional sequences.
558
639
559
-
560
640
Examples
561
641
--------
562
642
>>> import dpnp as np
@@ -598,19 +678,14 @@ def correlate(a, v, mode="valid"):
598
678
f"Received shapes: a.shape={ a .shape } , v.shape={ v .shape } "
599
679
)
600
680
601
- supported_types = statistics_ext .sliding_dot_product1d_dtypes ()
681
+ supported_methods = ["auto" , "direct" , "fft" ]
682
+ if method not in supported_methods :
683
+ raise ValueError (
684
+ f"Unknown method: { method } . Supported methods: { supported_methods } "
685
+ )
602
686
603
687
device = a .sycl_device
604
688
rdtype = result_type_for_device ([a .dtype , v .dtype ], device )
605
- supported_dtype = to_supported_dtypes (rdtype , supported_types , device )
606
-
607
- if supported_dtype is None :
608
- raise ValueError (
609
- f"function '{ correlate } ' does not support input types "
610
- f"({ a .dtype } , { v .dtype } ), "
611
- "and the inputs could not be coerced to any "
612
- f"supported types. List of supported types: { supported_types } "
613
- )
614
689
615
690
if dpnp .issubdtype (v .dtype , dpnp .complexfloating ):
616
691
v = dpnp .conj (v )
@@ -622,13 +697,15 @@ def correlate(a, v, mode="valid"):
622
697
623
698
l_pad , r_pad = _get_padding (a .size , v .size , mode )
624
699
625
- a_casted = dpnp .asarray (a , dtype = supported_dtype , order = "C" )
626
- v_casted = dpnp .asarray (v , dtype = supported_dtype , order = "C" )
627
-
628
- if v .size > a .size :
629
- a_casted , v_casted = v_casted , a_casted
700
+ if method == "auto" :
701
+ method = _choose_conv_method (a , v , rdtype )
630
702
631
- r = _run_native_sliding_dot_product1d (a_casted , v_casted , l_pad , r_pad )
703
+ if method == "direct" :
704
+ r = _run_native_sliding_dot_product1d (a , v , l_pad , r_pad , rdtype )
705
+ elif method == "fft" :
706
+ r = _convolve_fft (a , v [::- 1 ], l_pad , r_pad , rdtype )
707
+ else :
708
+ raise ValueError (f"Unknown method: { method } " )
632
709
633
710
if revert :
634
711
r = r [::- 1 ]
0 commit comments