Skip to content

Commit 5f429be

Browse files
Correlation via fft implementation
1 parent 6fc7166 commit 5f429be

File tree

2 files changed

+194
-32
lines changed

2 files changed

+194
-32
lines changed

dpnp/dpnp_iface_statistics.py

+101-24
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
import math
41+
4042
import dpctl.tensor as dpt
4143
import dpctl.utils as dpu
4244
import numpy
@@ -59,6 +61,8 @@
5961
from .dpnp_utils.dpnp_utils_reduction import dpnp_wrap_reduction_call
6062
from .dpnp_utils.dpnp_utils_statistics import dpnp_cov
6163

64+
min_ = min # pylint: disable=used-before-assignment
65+
6266
__all__ = [
6367
"amax",
6468
"amin",
@@ -478,17 +482,57 @@ def _get_padding(a_size, v_size, mode):
478482
return l_pad, r_pad
479483

480484

481-
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
485+
def _choose_conv_method(a, v, rdtype):
486+
assert a.size >= v.size
487+
if rdtype == dpnp.bool:
488+
return "direct"
489+
490+
if v.size < 10**4 or a.size < 10**4:
491+
return "direct"
492+
493+
if dpnp.issubdtype(rdtype, dpnp.integer):
494+
max_a = int(dpnp.max(dpnp.abs(a)))
495+
sum_v = int(dpnp.sum(dpnp.abs(v)))
496+
max_value = int(max_a * sum_v)
497+
498+
default_float = dpnp.default_float_type(a.sycl_device)
499+
if max_value > 2 ** numpy.finfo(default_float).nmant - 1:
500+
return "direct"
501+
502+
if dpnp.issubdtype(rdtype, dpnp.number):
503+
return "fft"
504+
505+
raise ValueError(f"Unsupported dtype: {rdtype}")
506+
507+
508+
def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype):
482509
queue = a.sycl_queue
510+
device = a.sycl_device
511+
512+
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
513+
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
483514

484-
usm_type = dpu.get_coerced_usm_type([a.usm_type, v.usm_type])
485-
out_size = l_pad + r_pad + a.size - v.size + 1
515+
if supported_dtype is None:
516+
raise ValueError(
517+
f"Unsupported input types ({a.dtype}, {v.dtype}), "
518+
"and the inputs could not be coerced to any "
519+
f"supported types. List of supported types: {supported_types}"
520+
)
521+
522+
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
523+
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
524+
525+
usm_type = dpu.get_coerced_usm_type([a_casted.usm_type, v_casted.usm_type])
526+
out_size = l_pad + r_pad + a_casted.size - v_casted.size + 1
486527
out = dpnp.empty(
487-
shape=out_size, sycl_queue=queue, dtype=a.dtype, usm_type=usm_type
528+
shape=out_size,
529+
sycl_queue=queue,
530+
dtype=supported_dtype,
531+
usm_type=usm_type,
488532
)
489533

490-
a_usm = dpnp.get_usm_ndarray(a)
491-
v_usm = dpnp.get_usm_ndarray(v)
534+
a_usm = dpnp.get_usm_ndarray(a_casted)
535+
v_usm = dpnp.get_usm_ndarray(v_casted)
492536
out_usm = dpnp.get_usm_ndarray(out)
493537

494538
_manager = dpu.SequentialOrderManager[queue]
@@ -506,7 +550,30 @@ def _run_native_sliding_dot_product1d(a, v, l_pad, r_pad):
506550
return out
507551

508552

509-
def correlate(a, v, mode="valid"):
553+
def _convolve_fft(a, v, l_pad, r_pad, rtype):
554+
assert a.size >= v.size
555+
assert l_pad < v.size
556+
557+
# +1 is needed to avoid circular convolution
558+
padded_size = a.size + r_pad + 1
559+
fft_size = 2 ** math.ceil(math.log2(padded_size))
560+
561+
af = dpnp.fft.fft(a, fft_size) # pylint: disable=no-member
562+
vf = dpnp.fft.fft(v, fft_size) # pylint: disable=no-member
563+
564+
r = dpnp.fft.ifft(af * vf) # pylint: disable=no-member
565+
if dpnp.issubdtype(rtype, dpnp.floating):
566+
r = r.real
567+
elif dpnp.issubdtype(rtype, dpnp.integer) or rtype == dpnp.bool:
568+
r = r.real.round()
569+
570+
start = v.size - 1 - l_pad
571+
end = padded_size - 1
572+
573+
return r[start:end]
574+
575+
576+
def correlate(a, v, mode="valid", method="auto"):
510577
r"""
511578
Cross-correlation of two 1-dimensional sequences.
512579
@@ -531,6 +598,20 @@ def correlate(a, v, mode="valid"):
531598
is ``'valid'``, unlike :obj:`dpnp.convolve`, which uses ``'full'``.
532599
533600
Default: ``'valid'``.
601+
method : {'auto', 'direct', 'fft'}, optional
602+
`'direct'`: The correlation is determined directly from sums.
603+
604+
`'fft'`: The Fourier Transform is used to perform the calculations.
605+
This method is faster for long sequences but can have accuracy issues.
606+
607+
`'auto'`: Automatically chooses direct or Fourier method based on
608+
an estimate of which is faster.
609+
610+
Note: Use of the FFT convolution on input containing NAN or INF
611+
will lead to the entire output being NAN or INF.
612+
Use method='direct' when your input contains NAN or INF values.
613+
614+
Default: ``'auto'``.
534615
535616
Notes
536617
-----
@@ -556,7 +637,6 @@ def correlate(a, v, mode="valid"):
556637
:obj:`dpnp.convolve` : Discrete, linear convolution of two
557638
one-dimensional sequences.
558639
559-
560640
Examples
561641
--------
562642
>>> import dpnp as np
@@ -598,19 +678,14 @@ def correlate(a, v, mode="valid"):
598678
f"Received shapes: a.shape={a.shape}, v.shape={v.shape}"
599679
)
600680

601-
supported_types = statistics_ext.sliding_dot_product1d_dtypes()
681+
supported_methods = ["auto", "direct", "fft"]
682+
if method not in supported_methods:
683+
raise ValueError(
684+
f"Unknown method: {method}. Supported methods: {supported_methods}"
685+
)
602686

603687
device = a.sycl_device
604688
rdtype = result_type_for_device([a.dtype, v.dtype], device)
605-
supported_dtype = to_supported_dtypes(rdtype, supported_types, device)
606-
607-
if supported_dtype is None:
608-
raise ValueError(
609-
f"function '{correlate}' does not support input types "
610-
f"({a.dtype}, {v.dtype}), "
611-
"and the inputs could not be coerced to any "
612-
f"supported types. List of supported types: {supported_types}"
613-
)
614689

615690
if dpnp.issubdtype(v.dtype, dpnp.complexfloating):
616691
v = dpnp.conj(v)
@@ -622,13 +697,15 @@ def correlate(a, v, mode="valid"):
622697

623698
l_pad, r_pad = _get_padding(a.size, v.size, mode)
624699

625-
a_casted = dpnp.asarray(a, dtype=supported_dtype, order="C")
626-
v_casted = dpnp.asarray(v, dtype=supported_dtype, order="C")
627-
628-
if v.size > a.size:
629-
a_casted, v_casted = v_casted, a_casted
700+
if method == "auto":
701+
method = _choose_conv_method(a, v, rdtype)
630702

631-
r = _run_native_sliding_dot_product1d(a_casted, v_casted, l_pad, r_pad)
703+
if method == "direct":
704+
r = _run_native_sliding_dot_product1d(a, v, l_pad, r_pad, rdtype)
705+
elif method == "fft":
706+
r = _convolve_fft(a, v[::-1], l_pad, r_pad, rdtype)
707+
else:
708+
raise ValueError(f"Unknown method: {method}")
632709

633710
if revert:
634711
r = r[::-1]

dpnp/tests/test_statistics.py

+93-8
Original file line numberDiff line numberDiff line change
@@ -629,26 +629,104 @@ def test_corrcoef_scalar(self):
629629

630630

631631
class TestCorrelate:
632+
def setup_method(self):
633+
numpy.random.seed(0)
634+
632635
@pytest.mark.parametrize(
633636
"a, v", [([1], [1, 2, 3]), ([1, 2, 3], [1]), ([1, 2, 3], [1, 2])]
634637
)
635638
@pytest.mark.parametrize("mode", [None, "full", "valid", "same"])
636639
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
637-
def test_correlate(self, a, v, mode, dtype):
640+
@pytest.mark.parametrize("method", [None, "auto", "direct", "fft"])
641+
def test_correlate(self, a, v, mode, dtype, method):
638642
an = numpy.array(a, dtype=dtype)
639643
vn = numpy.array(v, dtype=dtype)
640644
ad = dpnp.array(an)
641645
vd = dpnp.array(vn)
642646

643-
if mode is None:
644-
expected = numpy.correlate(an, vn)
645-
result = dpnp.correlate(ad, vd)
646-
else:
647-
expected = numpy.correlate(an, vn, mode=mode)
648-
result = dpnp.correlate(ad, vd, mode=mode)
647+
dpnp_kwargs = {}
648+
numpy_kwargs = {}
649+
if mode is not None:
650+
dpnp_kwargs["mode"] = mode
651+
numpy_kwargs["mode"] = mode
652+
if method is not None:
653+
dpnp_kwargs["method"] = method
654+
655+
expected = numpy.correlate(an, vn, **numpy_kwargs)
656+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
649657

650658
assert_dtype_allclose(result, expected)
651659

660+
@pytest.mark.parametrize("a_size", [1, 100, 10000])
661+
@pytest.mark.parametrize("v_size", [1, 100, 10000])
662+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
663+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
664+
@pytest.mark.parametrize("method", ["auto", "direct", "fft"])
665+
def test_correlate_random(self, a_size, v_size, mode, dtype, method):
666+
if dtype == dpnp.bool:
667+
an = numpy.random.rand(a_size) > 0.9
668+
vn = numpy.random.rand(v_size) > 0.9
669+
else:
670+
an = (100 * numpy.random.rand(a_size)).astype(dtype)
671+
vn = (100 * numpy.random.rand(v_size)).astype(dtype)
672+
673+
if dpnp.issubdtype(dtype, dpnp.complexfloating):
674+
an = an + 1j * (100 * numpy.random.rand(a_size)).astype(dtype)
675+
vn = vn + 1j * (100 * numpy.random.rand(v_size)).astype(dtype)
676+
677+
ad = dpnp.array(an)
678+
vd = dpnp.array(vn)
679+
680+
dpnp_kwargs = {}
681+
numpy_kwargs = {}
682+
if mode is not None:
683+
dpnp_kwargs["mode"] = mode
684+
numpy_kwargs["mode"] = mode
685+
if method is not None:
686+
dpnp_kwargs["method"] = method
687+
688+
result = dpnp.correlate(ad, vd, **dpnp_kwargs)
689+
expected = numpy.correlate(an, vn, **numpy_kwargs)
690+
691+
rdtype = result.dtype
692+
if dpnp.issubdtype(rdtype, dpnp.integer):
693+
rdtype = dpnp.default_float_type(ad.device)
694+
695+
if method != "fft" and (
696+
dpnp.issubdtype(dtype, dpnp.integer) or dtype == dpnp.bool
697+
):
698+
# For 'direct' and 'auto' methods, we expect exact results for integer types
699+
assert_array_equal(result, expected)
700+
else:
701+
result = result.astype(rdtype)
702+
if method == "direct":
703+
expected = numpy.correlate(an, vn, **numpy_kwargs)
704+
# For 'direct' method we can use standard validation
705+
assert_dtype_allclose(result, expected, factor=30)
706+
else:
707+
rtol = 1e-3
708+
atol = 1e-10
709+
710+
if rdtype == dpnp.float64 or rdtype == dpnp.complex128:
711+
rtol = 1e-6
712+
atol = 1e-12
713+
elif rdtype == dpnp.bool:
714+
result = result.astype(dpnp.int32)
715+
rdtype = result.dtype
716+
717+
expected = expected.astype(rdtype)
718+
719+
diff = numpy.abs(result.asnumpy() - expected)
720+
invalid = diff > atol + rtol * numpy.abs(expected)
721+
722+
# When using the 'fft' method, we might encounter outliers.
723+
# This usually happens when the resulting array contains values close to zero.
724+
# For these outliers, the relative error can be significant.
725+
# We can tolerate a few such outliers.
726+
max_outliers = 8 if expected.size > 1 else 0
727+
if invalid.sum() > max_outliers:
728+
assert_dtype_allclose(result, expected, factor=1000)
729+
652730
def test_correlate_mode_error(self):
653731
a = dpnp.arange(5)
654732
v = dpnp.arange(3)
@@ -689,7 +767,7 @@ def test_correlate_different_sizes(self, size):
689767
vd = dpnp.array(v)
690768

691769
expected = numpy.correlate(a, v)
692-
result = dpnp.correlate(ad, vd)
770+
result = dpnp.correlate(ad, vd, method="direct")
693771

694772
assert_dtype_allclose(result, expected, factor=20)
695773

@@ -700,6 +778,13 @@ def test_correlate_another_sycl_queue(self):
700778
with pytest.raises(ValueError):
701779
dpnp.correlate(a, v)
702780

781+
def test_correlate_unkown_method(self):
782+
a = dpnp.arange(5)
783+
v = dpnp.arange(3)
784+
785+
with pytest.raises(ValueError):
786+
dpnp.correlate(a, v, method="unknown")
787+
703788

704789
@pytest.mark.parametrize(
705790
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)

0 commit comments

Comments
 (0)