Skip to content

Commit 2770a88

Browse files
Fixed #1014 Fix max_matches=None bug (#1015)
* Fix max_matches=None bug * Alternate max_matches=None fix
1 parent 3077d0d commit 2770a88

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

stumpy/aamp_motifs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def aamp_motifs(
268268
m = T.shape[-1] - P.shape[-1] + 1
269269
excl_zone = int(np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM))
270270
if max_matches is None: # pragma: no cover
271-
max_matches = np.inf
271+
max_matches = P.shape[-1]
272272
if cutoff is None: # pragma: no cover
273273
P_copy = P.copy().astype(np.float64)
274274
P_copy[np.isinf(P_copy)] = np.nan

stumpy/motifs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def motifs(
334334
m = T.shape[-1] - P.shape[-1] + 1
335335
excl_zone = int(np.ceil(m / config.STUMPY_EXCL_ZONE_DENOM))
336336
if max_matches is None: # pragma: no cover
337-
max_matches = np.inf
337+
max_matches = P.shape[-1]
338338
if cutoff is None: # pragma: no cover
339339
P_copy = P.copy().astype(np.float64)
340340
P_copy[np.isinf(P_copy)] = np.nan

tests/test_motifs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,3 +656,30 @@ def test_motifs_with_isconstant():
656656

657657
npt.assert_almost_equal(ref_distances, comp_distance)
658658
npt.assert_almost_equal(ref_indices, comp_indices)
659+
660+
661+
def test_motifs_with_max_matches_none():
662+
T = np.random.rand(16)
663+
m = 3
664+
665+
max_motifs = 1
666+
max_matches = None
667+
max_distance = np.inf
668+
cutoff = np.inf
669+
670+
# performant
671+
mp = naive.stump(T, m, row_wise=True)
672+
comp_distance, comp_indices = motifs(
673+
T,
674+
mp[:, 0].astype(np.float64),
675+
min_neighbors=1,
676+
max_distance=max_distance,
677+
cutoff=cutoff,
678+
max_matches=max_matches,
679+
max_motifs=max_motifs,
680+
)
681+
682+
ref_len = len(T) - m + 1
683+
684+
npt.assert_(ref_len >= comp_distance.shape[1])
685+
npt.assert_(ref_len >= comp_indices.shape[1])

0 commit comments

Comments
 (0)