Skip to content

Commit 6eb94b2

Browse files
Correlation via fft implementation
1 parent 91fd41f commit 6eb94b2

File tree

2 files changed

+192
-28
lines changed

2 files changed

+192
-28
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
import math
41+
4042
import dpctl.tensor as dpt
4143
import dpctl.utils as dpu
4244
import numpy
@@ -55,6 +57,8 @@
5557
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
5658
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov, dpnp_median
5759

60+
min_ = min # pylint: disable=used-before-assignment
61+
5862
__all__ = [
5963
"amax",
6064
"amin",
@@ -457,16 +461,55 @@ def _get_padding(a_size, v_size, mode):
457461
return l_pad, r_pad
458462

459463

460-
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
464+
def _choose_conv_method(a, v, rdtype):
465+
assert a.size >= v.size
466+
if rdtype == dpnp.bool:
467+
return "direct"
468+
469+
if v.size < 10**4 or a.size < 10**4:
470+
return "direct"
471+
472+
if dpnp.issubdtype(rdtype, dpnp.integer):
473+
max_a = int(dpnp.max(dpnp.abs(a)))
474+
sum_v = int(dpnp.sum(dpnp.abs(v)))
475+
max_value = int(max_a * sum_v)
476+
477+
default_float = dpnp.default_float_type(a.sycl_device)
478+
if max_value > 2 ** numpy.finfo(default_float).nmant - 1:
479+
return "direct"
480+
481+
if dpnp.issubdtype(rdtype, dpnp.number):
482+
return "fft"
483+
484+
raise ValueError(f"Unsupported dtype: {rdtype}")
485+
486+
487+
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
461488
queue = a.sycl_queue
489+
device = a.sycl_device
490+
491+
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
492+
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
462493

463-
usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
464-
out_size = l_pad + r_pad + a.size - v.size + 1
494+
if supported_dtype is None:
495+
raise ValueError(
496+
f"function does not support input types "
497+
f"({a.dtype.name}, {v.dtype.name}), "
498+
"and the inputs could not be coerced to any "
499+
f"supported types. List of supported types: "
500+
f"{[st.name for st in supported_types]}"
501+
)
502+
503+
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
504+
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
505+
506+
usm_type = dpu.get_coerced_usm_type([a_casted.usm_type, v_casted.usm_type])
507+
out_size = l_pad + r_pad + a_casted.size - v_casted.size + 1
465508
# out type is the same as input type
466509
out = dpnp.empty_like(a, shape=out_size, usm_type=usm_type)
467510

468-
a_usm = dpnp.get_usm_ndarray(a)
469-
v_usm = dpnp.get_usm_ndarray(v)
511+
a_usm = dpnp.get_usm_ndarray(a_casted)
512+
v_usm = dpnp.get_usm_ndarray(v_casted)
470513
out_usm = dpnp.get_usm_ndarray(out)
471514

472515
_manager = dpu.SequentialOrderManager[queue]
@@ -484,7 +527,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
484527
return out
485528

486529

487-
def correlate(a, v, mode="valid"):
530+
def _convolve_fft(a, v, l_pad, r_pad, rtype):
531+
assert a.size >= v.size
532+
assert l_pad < v.size
533+
534+
# +1 is needed to avoid circular convolution
535+
padded_size = a.size + r_pad + 1
536+
fft_size = 2 ** math.ceil(math.log2(padded_size))
537+
538+
af = dpnp.fft.fft(a, fft_size) # pylint: disable=no-member
539+
vf = dpnp.fft.fft(v, fft_size) # pylint: disable=no-member
540+
541+
r = dpnp.fft.ifft(af * vf) # pylint: disable=no-member
542+
if dpnp.issubdtype(rtype, dpnp.floating):
543+
r = r.real
544+
elif dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
545+
r = r.real.round()
546+
547+
start = v.size - 1 - l_pad
548+
end = padded_size - 1
549+
550+
return r[start:end]
551+
552+
553+
def correlate(a, v, mode="valid", method="auto"):
488554
r"""
489555
Cross-correlation of two 1-dimensional sequences.
490556
@@ -509,6 +575,20 @@ def correlate(a, v, mode="valid"):
509575
is ``"valid"``, unlike :obj:`dpnp.convolve`, which uses ``"full"``.
510576
511577
Default: ``"valid"``.
578+
method : {'auto', 'direct', 'fft'}, optional
579+
`'direct'`: The correlation is determined directly from sums.
580+
581+
`'fft'`: The Fourier Transform is used to perform the calculations.
582+
This method is faster for long sequences but can have accuracy issues.
583+
584+
`'auto'`: Automatically chooses direct or Fourier method based on
585+
an estimate of which is faster.
586+
587+
Note: Use of the FFT convolution on input containing NAN or INF
588+
will lead to the entire output being NAN or INF.
589+
Use method='direct' when your input contains NAN or INF values.
590+
591+
Default: ``'auto'``.
512592
513593
Notes
514594
-----
@@ -576,20 +656,14 @@ def correlate(a, v, mode="valid"):
576656
f"Received shapes: a.shape={a.shape}, v.shape={v.shape}"
577657
)
578658

579-
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
659+
supported_methods = ["auto", "direct", "fft"]
660+
if method not in supported_methods:
661+
raise ValueError(
662+
f"Unknown method: {method}. Supported methods: {supported_methods}"
663+
)
580664

581665
device = a.sycl_device
582666
rdtype = result_type_for_device([a.dtype, v.dtype], device)
583-
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
584-
585-
if supported_dtype is None:
586-
raise ValueError(
587-
f"function does not support input types "
588-
f"({a.dtype.name}, {v.dtype.name}), "
589-
"and the inputs could not be coerced to any "
590-
f"supported types. List of supported types: "
591-
f"{[st.name for st in supported_types]}"
592-
)
593667

594668
if dpnp.issubdtype(v.dtype, dpnp.complexfloating):
595669
v = dpnp.conj(v)
@@ -601,10 +675,15 @@ def correlate(a, v, mode="valid"):
601675

602676
l_pad, r_pad = _get_padding(a.size, v.size, mode)
603677

604-
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
605-
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
678+
if method == "auto":
679+
method = _choose_conv_method(a, v, rdtype)
606680

607-
r = _run_native_sliding_dot_product1d(a_casted, v_casted, l_pad, r_pad)
681+
if method == "direct":
682+
r = _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype)
683+
elif method == "fft":
684+
r = _convolve_fft(a, v[::-1], l_pad, r_pad, rdtype)
685+
else:
686+
raise ValueError(f"Unknown method: {method}")
608687

609688
if revert:
610689
r = r[::-1]

dpnp/tests/test_statistics.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -629,26 +629,104 @@ def test_corrcoef_scalar(self):
629629

630630

631631
class TestCorrelate:
632+
def setup_method(self):
633+
numpy.random.seed(0)
634+
632635
@pytest.mark.parametrize(
633636
"a, v", [([1], [1, 2, 3]), ([1, 2, 3], [1]), ([1, 2, 3], [1, 2])]
634637
)
635638
@pytest.mark.parametrize("mode", [None, "full", "valid", "same"])
636639
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
637-
def test_correlate(self, a, v, mode, dtype):
640+
@pytest.mark.parametrize("method", [None, "auto", "direct", "fft"])
641+
def test_correlate(self, a, v, mode, dtype, method):
638642
an = numpy.array(a, dtype=dtype)
639643
vn = numpy.array(v, dtype=dtype)
640644
ad = dpnp.array(an)
641645
vd = dpnp.array(vn)
642646

643-
if mode is None:
644-
expected = numpy.correlate(an, vn)
645-
result = dpnp.correlate(ad, vd)
646-
else:
647-
expected = numpy.correlate(an, vn, mode=mode)
648-
result = dpnp.correlate(ad, vd, mode=mode)
647+
dpnp_kwargs = {}
648+
numpy_kwargs = {}
649+
if mode is not None:
650+
dpnp_kwargs["mode"] = mode
651+
numpy_kwargs["mode"] = mode
652+
if method is not None:
653+
dpnp_kwargs["method"] = method
654+
655+
expected = numpy.correlate(an, vn, **numpy_kwargs)
656+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
649657

650658
assert_dtype_allclose(result, expected)
651659

660+
@pytest.mark.parametrize("a_size", [1, 100, 10000])
661+
@pytest.mark.parametrize("v_size", [1, 100, 10000])
662+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
663+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
664+
@pytest.mark.parametrize("method", ["auto", "direct", "fft"])
665+
def test_correlate_random(self, a_size, v_size, mode, dtype, method):
666+
if dtype == dpnp.bool:
667+
an = numpy.random.rand(a_size) > 0.9
668+
vn = numpy.random.rand(v_size) > 0.9
669+
else:
670+
an = (100 * numpy.random.rand(a_size)).astype(dtype)
671+
vn = (100 * numpy.random.rand(v_size)).astype(dtype)
672+
673+
if dpnp.issubdtype(dtype, dpnp.complexfloating):
674+
an = an + 1j * (100 * numpy.random.rand(a_size)).astype(dtype)
675+
vn = vn + 1j * (100 * numpy.random.rand(v_size)).astype(dtype)
676+
677+
ad = dpnp.array(an)
678+
vd = dpnp.array(vn)
679+
680+
dpnp_kwargs = {}
681+
numpy_kwargs = {}
682+
if mode is not None:
683+
dpnp_kwargs["mode"] = mode
684+
numpy_kwargs["mode"] = mode
685+
if method is not None:
686+
dpnp_kwargs["method"] = method
687+
688+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
689+
expected = numpy.correlate(an, vn, **numpy_kwargs)
690+
691+
rdtype = result.dtype
692+
if dpnp.issubdtype(rdtype, dpnp.integer):
693+
rdtype = dpnp.default_float_type(ad.device)
694+
695+
if method != "fft" and (
696+
dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
697+
):
698+
# For 'direct' and 'auto' methods, we expect exact results for integer types
699+
assert_array_equal(result, expected)
700+
else:
701+
result = result.astype(rdtype)
702+
if method == "direct":
703+
expected = numpy.correlate(an, vn, **numpy_kwargs)
704+
# For 'direct' method we can use standard validation
705+
assert_dtype_allclose(result, expected, factor=30)
706+
else:
707+
rtol = 1e-3
708+
atol = 1e-10
709+
710+
if rdtype == dpnp.float64 or rdtype == dpnp.complex128:
711+
rtol = 1e-6
712+
atol = 1e-12
713+
elif rdtype == dpnp.bool:
714+
result = result.astype(dpnp.int32)
715+
rdtype = result.dtype
716+
717+
expected = expected.astype(rdtype)
718+
719+
diff = numpy.abs(result.asnumpy() - expected)
720+
invalid = diff > atol + rtol * numpy.abs(expected)
721+
722+
# When using the 'fft' method, we might encounter outliers.
723+
# This usually happens when the resulting array contains values close to zero.
724+
# For these outliers, the relative error can be significant.
725+
# We can tolerate a few such outliers.
726+
max_outliers = 8 if expected.size > 1 else 0
727+
if invalid.sum() > max_outliers:
728+
assert_dtype_allclose(result, expected, factor=1000)
729+
652730
def test_correlate_mode_error(self):
653731
a = dpnp.arange(5)
654732
v = dpnp.arange(3)
@@ -689,7 +767,7 @@ def test_correlate_different_sizes(self, size):
689767
vd = dpnp.array(v)
690768

691769
expected = numpy.correlate(a, v)
692-
result = dpnp.correlate(ad, vd)
770+
result = dpnp.correlate(ad, vd, method="direct")
693771

694772
assert_dtype_allclose(result, expected, factor=20)
695773

@@ -700,6 +778,13 @@ def test_correlate_another_sycl_queue(self):
700778
with pytest.raises(ValueError):
701779
dpnp.correlate(a, v)
702780

781+
def test_correlate_unkown_method(self):
782+
a = dpnp.arange(5)
783+
v = dpnp.arange(3)
784+
785+
with pytest.raises(ValueError):
786+
dpnp.correlate(a, v, method="unknown")
787+
703788

704789
@pytest.mark.parametrize(
705790
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)

0 commit comments

Comments
 (0)