From 199798db209dd001d58a349696d6b4f70ac882a0 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 21 Feb 2025 17:42:59 -0800 Subject: [PATCH] Added more validations and unit tests Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../processor/CompoundTopDocs.java | 75 ++++++ .../L2ScoreNormalizationTechnique.java | 10 + .../MinMaxScoreNormalizationTechnique.java | 201 ++++++++------ .../ScoreNormalizationFactory.java | 2 +- .../processor/CompoundTopDocsTests.java | 150 +++++++++++ ...inMaxScoreNormalizationTechniqueTests.java | 254 ++++++++++++++++-- .../ScoreNormalizationFactoryTests.java | 34 +++ .../query/HybridQueryExplainIT.java | 214 ++++++++++----- 9 files changed, 767 insertions(+), 174 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fc16369c..180989f05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features +- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195)) ### Enhancements - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838)) ### Bug Fixes diff --git a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java index 11a0c7ee0..986a8f261 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java @@ -13,6 +13,7 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -150,4 +151,78 @@ private ScoreDoc deepCopyScoreDoc(final ScoreDoc scoreDoc, final boolean isSortE FieldDoc fieldDoc = (FieldDoc) scoreDoc; return new FieldDoc(fieldDoc.doc, fieldDoc.score, fieldDoc.fields, fieldDoc.shardIndex); } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + CompoundTopDocs that = (CompoundTopDocs) other; + + if (this.topDocs.size() != that.topDocs.size()) { + return false; + } + for (int i = 0; i < topDocs.size(); i++) { + TopDocs thisTopDoc = this.topDocs.get(i); + TopDocs thatTopDoc = that.topDocs.get(i); + if ((thisTopDoc == null) != (thatTopDoc == null)) { + return false; + } + if (thisTopDoc == null) { + continue; + } + if (!Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits)) { + return false; + } + if (!compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs)) { + return false; + } + } + return Objects.equals(totalHits, that.totalHits) && Objects.equals(searchShard, that.searchShard); + } + + private boolean compareScoreDocs(ScoreDoc[] first, ScoreDoc[] second) { + if (first.length != second.length) { + return false; + } + + for (int i = 0; i < first.length; i++) { + ScoreDoc firstDoc = first[i]; + ScoreDoc secondDoc = second[i]; + if ((firstDoc == null) != (secondDoc == null)) { + return false; + } + if (firstDoc == null) { + continue; + } + if (firstDoc.doc != secondDoc.doc || Float.compare(firstDoc.score, secondDoc.score) != 0) { + return false; + } + if (firstDoc instanceof FieldDoc != secondDoc instanceof FieldDoc) { + return false; + 
} + if (firstDoc instanceof FieldDoc firstFieldDoc) { + FieldDoc secondFieldDoc = (FieldDoc) secondDoc; + if (!Arrays.equals(firstFieldDoc.fields, secondFieldDoc.fields)) { + return false; + } + } + } + return true; + } + + @Override + public int hashCode() { + int result = Objects.hash(totalHits, searchShard); + for (TopDocs topDoc : topDocs) { + result = 31 * result + topDoc.totalHits.hashCode(); + for (ScoreDoc scoreDoc : topDoc.scoreDocs) { + result = 31 * result + Float.floatToIntBits(scoreDoc.score); + result = 31 * result + scoreDoc.doc; + if (scoreDoc instanceof FieldDoc fieldDoc && fieldDoc.fields != null) { + result = 31 * result + Arrays.deepHashCode(fieldDoc.fields); + } + } + } + return result; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 1208ffe77..bc81f91b5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -32,6 +32,16 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu public static final String TECHNIQUE_NAME = "l2"; private static final float MIN_SCORE = 0.0f; + public L2ScoreNormalizationTechnique() { + this(Map.of()); + } + + public L2ScoreNormalizationTechnique(final Map<String, Object> params) { + if (Objects.nonNull(params) && !params.isEmpty()) { + throw new IllegalArgumentException("unrecognized parameters in normalization technique"); + } + } + /** * L2 normalization method. * n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index d9a009a87..edff967ff 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -11,6 +11,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import lombok.AllArgsConstructor; @@ -42,14 +43,18 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech public static final String TECHNIQUE_NAME = "min_max"; protected static final float MIN_SCORE = 0.001f; private static final float SINGLE_RESULT_SCORE = 1.0f; - private final List<Pair<Mode, Float>> lowerBounds; + private static final String PARAM_NAME_LOWER_BOUNDS = "lower_bounds"; + private static final String PARAM_NAME_LOWER_BOUND_MODE = "mode"; + private static final String PARAM_NAME_LOWER_BOUND_MIN_SCORE = "min_score"; + + private final Optional<List<Pair<LowerBound.Mode, Float>>> lowerBoundsOptional; public MinMaxScoreNormalizationTechnique() { this(Map.of()); } public MinMaxScoreNormalizationTechnique(final Map<String, Object> params) { - lowerBounds = getLowerBounds(params); + lowerBoundsOptional = getLowerBounds(params); } /** @@ -69,8 +74,14 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { continue; } List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); - if (Objects.nonNull(lowerBounds) && !lowerBounds.isEmpty() && lowerBounds.size() != topDocsPerSubQuery.size()) { - throw new IllegalArgumentException("lower bounds size should be
same as number of sub queries"); + if (isLowerBoundsAndSubQueriesCountMismatched(topDocsPerSubQuery)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "expected lower bounds array to contain %d elements matching the number of sub-queries, but found a mismatch", + topDocsPerSubQuery.size() + ) + ); } for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); @@ -87,14 +98,16 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { } } - private LowerBound getLowerBound(int j) { - LowerBound lowerBound; - if (Objects.isNull(lowerBounds) || lowerBounds.isEmpty()) { - lowerBound = new LowerBound(); - } else { - lowerBound = new LowerBound(true, lowerBounds.get(j).getLeft(), lowerBounds.get(j).getRight()); - } - return lowerBound; + private boolean isLowerBoundsAndSubQueriesCountMismatched(List topDocsPerSubQuery) { + return lowerBoundsOptional.isPresent() + && !topDocsPerSubQuery.isEmpty() + && lowerBoundsOptional.get().size() != topDocsPerSubQuery.size(); + } + + private LowerBound getLowerBound(int subQueryIndex) { + return lowerBoundsOptional.map( + pairs -> new LowerBound(true, pairs.get(subQueryIndex).getLeft(), pairs.get(subQueryIndex).getRight()) + ).orElseGet(LowerBound::new); } private MinMaxScores getMinMaxScoresResult(final List queryTopDocs) { @@ -108,7 +121,12 @@ private MinMaxScores getMinMaxScoresResult(final List queryTopD @Override public String describe() { - return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME); + return lowerBoundsOptional.map(lb -> { + String lowerBounds = lb.stream() + .map(pair -> String.format(Locale.ROOT, "(%s, %s)", pair.getLeft(), pair.getRight())) + .collect(Collectors.joining(", ", "[", "]")); + return String.format(Locale.ROOT, "%s, lower bounds %s", TECHNIQUE_NAME, lowerBounds); + }).orElse(String.format(Locale.ROOT, "%s", TECHNIQUE_NAME)); } @Override @@ -187,10 +205,6 @@ private float[] getMinScores(final List queryTopDocs, final int } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); for (int j = 0; j < topDocsPerSubQuery.size(); j++) { - // LowerBound lowerBound = getLowerBound(j); - // we need to compute actual min score for everything except clipping. For clipping we have to use - // lower bound min_score, it's passed as parameter. If we skip for clipping we can save some CPU cycles. 
- // if (!lowerBound.isEnabled() || lowerBound.getMode() != Mode.CLIP) { minScores[j] = Math.min( minScores[j], Arrays.stream(topDocsPerSubQuery.get(j).scoreDocs) .map(scoreDoc -> scoreDoc.score) .min(Float::compare) .orElse(Float.MAX_VALUE) ); - // } } } return minScores; } - private float normalizeSingleScore(final float score, final float minScore, final float maxScore, LowerBound lowerBound) { + private float normalizeSingleScore(final float score, final float minScore, final float maxScore, final LowerBound lowerBound) { // edge case when there is only one score and min and max scores are same if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) { return SINGLE_RESULT_SCORE; } if (!lowerBound.isEnabled()) { - return Mode.IGNORE.normalize(score, minScore, maxScore, lowerBound.getMinScore()); + return LowerBound.Mode.IGNORE.normalize(score, minScore, maxScore, lowerBound.getMinScore()); } - return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore()); } @@ -226,15 +238,14 @@ private class MinMaxScores { float[] maxScoresPerSubquery; } - private List<Pair<Mode, Float>> getLowerBounds(final Map<String, Object> params) { - List<Pair<Mode, Float>> lowerBounds = new ArrayList<>(); - - // Early return if params is null or doesn't contain lower_bounds - if (Objects.isNull(params) || !params.containsKey("lower_bounds")) { - return lowerBounds; + private Optional<List<Pair<LowerBound.Mode, Float>>> getLowerBounds(final Map<String, Object> params) { + if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_LOWER_BOUNDS)) { + return Optional.empty(); } - Object lowerBoundsObj = params.get("lower_bounds"); + List<Pair<LowerBound.Mode, Float>> lowerBounds = new ArrayList<>(); + + Object lowerBoundsObj = params.get(PARAM_NAME_LOWER_BOUNDS); if (!(lowerBoundsObj instanceof List lowerBoundsParams)) { throw new IllegalArgumentException("lower_bounds must be a List"); } @@ -259,8 +270,17 @@ private List<Pair<Mode, Float>> getLowerBounds(final Map<String, Object> params) Map lowerBound = (Map) boundObj; try { - Mode mode = Mode.fromString(lowerBound.get("mode").toString()); - float minScore = Float.parseFloat(String.valueOf(lowerBound.get("min_score"))); + LowerBound.Mode mode = LowerBound.Mode.fromString( + Objects.isNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE)) + ?
"" + : lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE).toString() + ); + float minScore; + if (Objects.isNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))) { + minScore = LowerBound.DEFAULT_LOWER_BOUND_SCORE; + } else { + minScore = Float.parseFloat(String.valueOf(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))); + } Validate.isTrue( minScore >= LowerBound.MIN_LOWER_BOUND_SCORE && minScore <= LowerBound.MAX_LOWER_BOUND_SCORE, @@ -271,25 +291,25 @@ private List> getLowerBounds(final Map params) lowerBounds.add(ImmutablePair.of(mode, minScore)); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Invalid format for min_score: must be a valid float value", e); + throw new IllegalArgumentException("invalid format for min_score: must be a valid float value", e); } } - return lowerBounds; + return Optional.of(lowerBounds); } /** * Result class to hold lower bound for each sub query */ @Getter - private static class LowerBound { + public static class LowerBound { static final float MIN_LOWER_BOUND_SCORE = -10_000f; static final float MAX_LOWER_BOUND_SCORE = 10_000f; static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f; - boolean enabled; - Mode mode; - float minScore; + private final boolean enabled; + private final Mode mode; + private final float minScore; LowerBound() { this(false, Mode.DEFAULT, DEFAULT_LOWER_BOUND_SCORE); @@ -300,56 +320,79 @@ private static class LowerBound { this.mode = mode; this.minScore = minScore; } - } - protected enum Mode { - APPLY { - @Override - public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { - if (maxScore < lowerBoundScore) { - return (score - minScore) / (maxScore - minScore); - } else if (score < lowerBoundScore) { - return score / (maxScore - score); + /** + * Enum for normalization mode + */ + protected enum Mode { + APPLY { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + // if we apply the lower bound this mean we use actual score in case it's less then the lower bound min score + // same applied to case when actual max_score is less than lower bound min score + if (maxScore < lowerBoundScore || score < lowerBoundScore) { + return (score - minScore) / (maxScore - minScore); + } + return (score - lowerBoundScore) / (maxScore - lowerBoundScore); } - return (score - lowerBoundScore) / (maxScore - lowerBoundScore); - } - }, - CLIP { - @Override - public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { - if (score < minScore) { - return lowerBoundScore / (maxScore - lowerBoundScore); + }, + CLIP { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + // apply clipping, return lower bound min score if score is less than min score. This effectively means 0 after + // normalization + if (score < minScore) { + return 0.0f; + } + if (maxScore < lowerBoundScore) { + return (score - minScore) / (maxScore - minScore); + } + return (score - lowerBoundScore) / (maxScore - lowerBoundScore); + } + }, + IGNORE { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + // ignore lower bound logic and do raw min-max normalization using actual scores + float normalizedScore = (score - minScore) / (maxScore - minScore); + return normalizedScore == 0.0f ? 
MIN_SCORE : normalizedScore; + } + }; + + public static final Mode DEFAULT = APPLY; + // set of all valid values for mode + public static final String VALID_VALUES = Arrays.stream(values()) + .map(mode -> mode.name().toLowerCase(Locale.ROOT)) + .collect(Collectors.joining(", ")); + + /** + * Get mode from string value + * @param value string value of mode; a blank value falls back to the default mode + * @return mode + * @throws IllegalArgumentException if mode is not valid + */ + public static Mode fromString(String value) { + if (Objects.isNull(value)) { + throw new IllegalArgumentException("mode value cannot be null or empty"); + } + if (value.trim().isEmpty()) { + return DEFAULT; + } + try { + return valueOf(value.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "invalid mode: %s, valid values are: %s", value, VALID_VALUES) + ); } - return (score - lowerBoundScore) / (maxScore - lowerBoundScore); - } - }, - IGNORE { - @Override - public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { - float normalizedScore = (score - minScore) / (maxScore - minScore); - return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore; - } - }; - public static final Mode DEFAULT = APPLY; - public static final String VALID_VALUES = Arrays.stream(values()) - .map(mode -> mode.name().toLowerCase(Locale.ROOT)) - .collect(Collectors.joining(", ")); + public abstract float normalize(float score, float minScore, float maxScore, float lowerBoundScore); - public static Mode fromString(String value) { - if (value == null || value.trim().isEmpty()) { - throw new IllegalArgumentException("mode value cannot be null or empty"); - } - - try { - return valueOf(value.toUpperCase(Locale.ROOT)); - } catch (IllegalArgumentException e) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "invalid mode: %s, valid values are: %s", value, VALID_VALUES) - ); + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); } } - - public abstract float normalize(float score, float minScore, float maxScore, float lowerBoundScore); } }
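For quick reference, the three lower-bound modes above reduce to simple arithmetic. The following standalone sketch mirrors the normalize implementations in this hunk; the class name and sample values are illustrative only (they are not part of the patch), and MIN_SCORE stands in for the technique's 0.001f floor:

public class LowerBoundModeSketch {
    static final float MIN_SCORE = 0.001f;

    // APPLY: rescale against the lower bound, unless the score (or the max score)
    // falls below it, in which case plain min-max is used
    static float apply(float score, float min, float max, float bound) {
        if (max < bound || score < bound) {
            return (score - min) / (max - min);
        }
        return (score - bound) / (max - bound);
    }

    // CLIP: scores below the observed min collapse to 0; otherwise rescale against the lower bound
    static float clip(float score, float min, float max, float bound) {
        if (score < min) {
            return 0.0f;
        }
        if (max < bound) {
            return (score - min) / (max - min);
        }
        return (score - bound) / (max - bound);
    }

    // IGNORE: plain min-max normalization, substituting MIN_SCORE for an exact zero
    static float ignore(float score, float min, float max) {
        float normalized = (score - min) / (max - min);
        return normalized == 0.0f ? MIN_SCORE : normalized;
    }

    public static void main(String[] args) {
        System.out.println(apply(0.5f, 0.2f, 0.8f, 0.3f)); // (0.5 - 0.3) / (0.8 - 0.3) = 0.4
        System.out.println(clip(0.1f, 0.2f, 0.8f, 0.3f));  // 0.1 < min of 0.2, clipped to 0.0
        System.out.println(ignore(0.2f, 0.2f, 0.8f));      // exact zero becomes 0.001
    }
}

These are the same values the unit tests below assert against.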
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index 9ad64da15..d1c2414a0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -21,7 +21,7 @@ public class ScoreNormalizationFactory { MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, MinMaxScoreNormalizationTechnique::new, L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - params -> new L2ScoreNormalizationTechnique(), + L2ScoreNormalizationTechnique::new, RRFNormalizationTechnique.TECHNIQUE_NAME, params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil) ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java index eabc69894..f9e93415e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java @@ -5,9 +5,11 @@ package org.opensearch.neuralsearch.processor; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.apache.commons.lang3.RandomUtils; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; @@ -87,4 +89,152 @@ public void testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() { assertNotNull(compoundTopDocsWithNullArray.getScoreDocs()); assertEquals(0, compoundTopDocsWithNullArray.getScoreDocs().size()); } + + public void testEqualsWithIdenticalCompoundTopDocs() { + TopDocs topDocs1 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + TopDocs topDocs2 = new TopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 2.0f) }); + List<TopDocs> topDocsList = Arrays.asList(topDocs1, topDocs2); + + CompoundTopDocs first = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocsList, false, SEARCH_SHARD); + CompoundTopDocs second = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocsList, false, SEARCH_SHARD); + + assertTrue(first.equals(second)); + assertTrue(second.equals(first)); + assertEquals(first.hashCode(), second.hashCode()); + } + + public void testEqualsWithDifferentScoreDocs() { + TopDocs topDocs1 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + TopDocs topDocs2 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs1), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs2), + false, + SEARCH_SHARD + ); + + assertFalse(first.equals(second)); + assertFalse(second.equals(first)); + } + + public void testEqualsWithDifferentTotalHits() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + + assertFalse(first.equals(second)); + assertFalse(second.equals(first)); + } + + public void testEqualsWithDifferentSortEnabled() { + Object[] fields = new Object[] { "value1" }; + ScoreDoc scoreDoc = new FieldDoc(1, 1.0f, fields); // use FieldDoc when sort is enabled + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { scoreDoc }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + true, + SEARCH_SHARD + ); + + // non-sorted case, use regular ScoreDoc + TopDocs topDocs2 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs2), + false, + SEARCH_SHARD + ); + + assertNotEquals(first, second); + assertNotEquals(second, first); + } + + public void testEqualsWithDifferentSearchShards() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1,
TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + new SearchShard("my_index", 1, "23456789") + ); + + assertNotEquals(first, second); + assertNotEquals(second, first); + } + + public void testEqualsWithFieldDocs() { + Object[] fields1 = new Object[] { "value1" }; + Object[] fields2 = new Object[] { "value1" }; + TopDocs topDocs1 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new FieldDoc[] { new FieldDoc(1, 1.0f, fields1) }); + TopDocs topDocs2 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new FieldDoc[] { new FieldDoc(1, 1.0f, fields2) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs1), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs2), + false, + SEARCH_SHARD + ); + + assertEquals(first, second); + assertEquals(second, first); + assertEquals(first.hashCode(), second.hashCode()); + } + + public void testEqualsWithNull() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + + assertNotEquals(null, compoundTopDocs); + } + + public void testEqualsWithDifferentClass() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + + assertNotEquals("not a CompoundTopDocs", compoundTopDocs); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index 840c19394..b00599cc2 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -28,6 +28,7 @@ import static org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique.MIN_SCORE; import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS; /** * Abstracts normalization of scores based on min-max method @@ -274,86 +275,124 @@ public void testNormalizedScoresAreSetAtCorrectIndices() { assertEquals(1.0f, topDocs3.scoreDocs[0].score, DELTA_FOR_SCORE_ASSERTION); // doc1 in third subquery } - public void testMode_fromString_validValues() { - assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.fromString("apply")); - assertEquals(MinMaxScoreNormalizationTechnique.Mode.CLIP, MinMaxScoreNormalizationTechnique.Mode.fromString("clip")); - assertEquals(MinMaxScoreNormalizationTechnique.Mode.IGNORE, 
MinMaxScoreNormalizationTechnique.Mode.fromString("ignore")); + public void testLowerBoundsModeFromString_whenValidValues_thenSuccessful() { + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("apply") + ); + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.CLIP, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("clip") + ); + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.IGNORE, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("ignore") + ); // Case insensitive check - assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.fromString("APPLY")); + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("APPLY") + ); } public void testMode_fromString_invalidValues() { IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, - () -> MinMaxScoreNormalizationTechnique.Mode.fromString("invalid") + () -> MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("invalid") ); assertEquals("invalid mode: invalid, valid values are: apply, clip, ignore", exception.getMessage()); } - public void testMode_fromString_nullOrEmpty() { + public void testLowerBoundsModeFromString_whenNull_thenFail() { IllegalArgumentException nullException = expectThrows( IllegalArgumentException.class, - () -> MinMaxScoreNormalizationTechnique.Mode.fromString(null) + () -> MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString(null) ); assertEquals("mode value cannot be null or empty", nullException.getMessage()); - - IllegalArgumentException emptyException = expectThrows( - IllegalArgumentException.class, - () -> MinMaxScoreNormalizationTechnique.Mode.fromString("") - ); - assertEquals("mode value cannot be null or empty", emptyException.getMessage()); } - public void testMode_normalize_apply() { + public void testLowerBounds_whenModeIsApply_thenSuccessful() { float score = 0.5f; - float minScore = 0.2f; + float minScore = 0.1f; float maxScore = 0.8f; float lowerBoundScore = 0.3f; - float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.APPLY.normalize(score, minScore, maxScore, lowerBoundScore); + float normalizedScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY.normalize( + score, + minScore, + maxScore, + lowerBoundScore + ); + // we expect score as (0.5 - 0.3) / (0.8 - 0.3) = 0.2 / 0.5 = 0.4 assertEquals(0.4f, normalizedScore, DELTA_FOR_SCORE_ASSERTION); // Test when score is below lower bound - float lowScore = 0.1f; - float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.APPLY.normalize(lowScore, minScore, maxScore, lowerBoundScore); + float lowScore = 0.2f; + float normalizedLowScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY.normalize( + lowScore, + minScore, + maxScore, + lowerBoundScore + ); + // we expect score as (0.2 - 0.1) / (0.8 - 0.1) = 0.1 / 0.7 ≈ 0.143 assertEquals(0.143f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); } - public void testMode_normalize_clip() { + public void testLowerBounds_whenModeIsClip_thenSuccessful() { float score = 0.5f; float minScore = 0.2f; float maxScore = 0.8f; float lowerBoundScore = 0.3f; - float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.CLIP.normalize(score, minScore, maxScore, lowerBoundScore); + float normalizedScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.CLIP.normalize( + score, + minScore,
maxScore, + lowerBoundScore + ); assertEquals(0.4f, normalizedScore, DELTA_FOR_SCORE_ASSERTION); // Test when score is below min score float lowScore = 0.1f; - float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.CLIP.normalize(lowScore, minScore, maxScore, lowerBoundScore); - assertEquals(0.6f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); + float normalizedLowScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.CLIP.normalize( + lowScore, + minScore, + maxScore, + lowerBoundScore + ); + assertEquals(0.0f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); } - public void testMode_normalize_ignore() { + public void testLowerBounds_whenModeIsIgnore_thenSuccessful() { float score = 0.5f; float minScore = 0.2f; float maxScore = 0.8f; float lowerBoundScore = 0.3f; - float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.IGNORE.normalize(score, minScore, maxScore, lowerBoundScore); + float normalizedScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.IGNORE.normalize( + score, + minScore, + maxScore, + lowerBoundScore + ); assertEquals(0.5f, normalizedScore, DELTA_FOR_SCORE_ASSERTION); // Test when normalized score would be 0 float lowScore = 0.2f; - float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.IGNORE.normalize(lowScore, minScore, maxScore, lowerBoundScore); + float normalizedLowScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.IGNORE.normalize( + lowScore, + minScore, + maxScore, + lowerBoundScore + ); assertEquals(MIN_SCORE, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); } - public void testMode_defaultValue() { - assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.DEFAULT); + public void testLowerBoundsMode_whenDefaultValue_thenSuccessful() { + assertEquals(MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY, MinMaxScoreNormalizationTechnique.LowerBound.Mode.DEFAULT); } - public void testLowerBoundsExceedsMaxSubQueries() { + public void testLowerBounds_whenExceedsMaxSubQueries_thenFail() { List<Map<String, Object>> lowerBounds = new ArrayList<>(); for (int i = 0; i <= 100; i++) { @@ -389,6 +428,163 @@ public void testLowerBoundsExceedsMaxSubQueries() { ); } + public void testDescribe_whenLowerBoundsArePresent_thenSuccessful() { + Map<String, Object> parameters = new HashMap<>(); + List<Map<String, Object>> lowerBounds = Arrays.asList( + Map.of("mode", "apply", "min_score", 0.2), + Map.of("mode", "clip", "min_score", 0.1) + ); + parameters.put("lower_bounds", lowerBounds); + MinMaxScoreNormalizationTechnique techniqueWithBounds = new MinMaxScoreNormalizationTechnique(parameters); + assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueWithBounds.describe()); + + // without lower bounds + Map<String, Object> emptyParameters = new HashMap<>(); + MinMaxScoreNormalizationTechnique techniqueWithoutBounds = new MinMaxScoreNormalizationTechnique(emptyParameters); + assertEquals("min_max", techniqueWithoutBounds.describe()); + + Map<String, Object> parametersMissingMode = new HashMap<>(); + List<Map<String, Object>> lowerBoundsMissingMode = Arrays.asList( + Map.of("min_score", 0.2), + Map.of("mode", "clip", "min_score", 0.1) + ); + parametersMissingMode.put("lower_bounds", lowerBoundsMissingMode); + MinMaxScoreNormalizationTechnique techniqueMissingMode = new MinMaxScoreNormalizationTechnique(parametersMissingMode); + assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueMissingMode.describe()); + + Map<String, Object> parametersMissingScore = new HashMap<>(); + List<Map<String, Object>> lowerBoundsMissingScore = Arrays.asList( + Map.of("mode", "apply"), + Map.of("mode",
"clip", "min_score", 0.1) + ); + parametersMissingScore.put("lower_bounds", lowerBoundsMissingScore); + MinMaxScoreNormalizationTechnique techniqueMissingScore = new MinMaxScoreNormalizationTechnique(parametersMissingScore); + assertEquals("min_max, lower bounds [(apply, 0.0), (clip, 0.1)]", techniqueMissingScore.describe()); + } + + public void testLowerBounds_whenInvalidInput_thenFail() { + // Test case 1: Invalid mode value + Map parametersInvalidMode = new HashMap<>(); + List> lowerBoundsInvalidMode = Arrays.asList( + Map.of("mode", "invalid_mode", "min_score", 0.2), + Map.of("mode", "clip", "min_score", 0.1) + ); + parametersInvalidMode.put("lower_bounds", lowerBoundsInvalidMode); + IllegalArgumentException invalidModeException = expectThrows( + IllegalArgumentException.class, + () -> new MinMaxScoreNormalizationTechnique(parametersInvalidMode) + ); + assertEquals("invalid mode: invalid_mode, valid values are: apply, clip, ignore", invalidModeException.getMessage()); + + // Test case 4: Invalid min_score type + Map parametersInvalidScore = new HashMap<>(); + List> lowerBoundsInvalidScore = Arrays.asList( + Map.of("mode", "apply", "min_score", "not_a_number"), + Map.of("mode", "clip", "min_score", 0.1) + ); + parametersInvalidScore.put("lower_bounds", lowerBoundsInvalidScore); + IllegalArgumentException invalidScoreException = expectThrows( + IllegalArgumentException.class, + () -> new MinMaxScoreNormalizationTechnique(parametersInvalidScore) + ); + assertEquals("invalid format for min_score: must be a valid float value", invalidScoreException.getMessage()); + } + + public void testLowerBoundsValidation_whenLowerBoundsAndSubQueriesCountMismatch_thenFail() { + Map parameters = new HashMap<>(); + List> lowerBounds = Arrays.asList(Map.of("mode", "clip", "min_score", 0.1)); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ), + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.1f) }) + ), + false, + SEARCH_SHARD + ) + ); + ScoreNormalizationTechnique minMaxTechnique = new MinMaxScoreNormalizationTechnique(parameters); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(minMaxTechnique) + .build(); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> minMaxTechnique.normalize(normalizeScoresDTO) + ); + + assertEquals( + "expected lower bounds array to contain 2 elements matching the number of sub-queries, but found a mismatch", + exception.getMessage() + ); + } + + public void testLowerBoundsValidation_whenTopDocsIsEmpty_thenSuccessful() { + Map parameters = new HashMap<>(); + List> lowerBounds = Arrays.asList( + Map.of("mode", "clip", "min_score", 0.1), + Map.of("mode", "apply", "min_score", 0.0) + ); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + List compoundTopDocs = List.of( + new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), List.of(), false, SEARCH_SHARD), + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ), + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new 
ScoreDoc[] { new ScoreDoc(3, 0.1f) }) ), false, SEARCH_SHARD ) ); + ScoreNormalizationTechnique minMaxTechnique = new MinMaxScoreNormalizationTechnique(parameters); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(minMaxTechnique) + .build(); + + minMaxTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocsZero = new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of(), + false, + SEARCH_SHARD + ); + CompoundTopDocs expectedCompoundDocsOne = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.25f) } + ), + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 1.0f) }) + ), + false, + SEARCH_SHARD + ); + expectedCompoundDocsOne.setScoreDocs(List.of(new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f))); + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + CompoundTopDocs compoundTopDocsZero = compoundTopDocs.get(0); + assertEquals(expectedCompoundDocsZero, compoundTopDocsZero); + CompoundTopDocs compoundTopDocsOne = compoundTopDocs.get(1); + assertEquals(expectedCompoundDocsOne, compoundTopDocsOne); + } + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { assertEquals(expected.totalHits.value(), actual.totalHits.value()); assertEquals(expected.totalHits.relation(), actual.totalHits.relation()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java index cecdf8779..1adaa89d5 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java @@ -5,9 +5,15 @@ package org.opensearch.neuralsearch.processor.normalization; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + public class ScoreNormalizationFactoryTests extends OpenSearchQueryTestCase { public void testMinMaxNorm_whenCreatingByName_thenReturnCorrectInstance() { @@ -42,4 +48,32 @@ public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ); assertThat(illegalArgumentException.getMessage(), containsString("provided normalization technique is not supported")); } + + public void testCreateMinMaxNormalizationWithParameters() { + Map<String, Object> parameters = new HashMap<>(); + + List<Map<String, Object>> lowerBounds = Arrays.asList(Map.of("mode", "clip", "min_score", 0.1)); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization("min_max", parameters); + + assertNotNull(normalizationTechnique); + assertTrue(normalizationTechnique instanceof MinMaxScoreNormalizationTechnique); + } + + public void testThrowsExceptionForInvalidTechniqueWithParameters() { + Map<String, Object> parameters = new HashMap<>(); + + List<Map<String, Object>> lowerBounds = Arrays.asList(Map.of("mode",
"clip", "min_score", 0.1)); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> scoreNormalizationFactory.createNormalization(L2ScoreNormalizationTechnique.TECHNIQUE_NAME, parameters) + ); + assertEquals("unrecognized parameters in normalization technique", exception.getMessage()); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index 3fe39554e..25b96838b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -129,71 +129,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { // explain Map searchHit1 = hitsNestedList.get(0); Map topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation"); - assertNotNull(topLevelExplanationsHit1); - assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - String expectedTopLevelDescription = "arithmetic_mean combination of:"; - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); - List> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); - assertEquals(1, normalizationExplanationHit1.size()); - Map hit1DetailsForHit1 = normalizationExplanationHit1.get(0); - assertEquals(1.0, hit1DetailsForHit1.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); - assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); - - Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); - assertEquals("sum of:", explanationsHit1.get("description")); - assertEquals(0.343f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals(1, ((List) explanationsHit1.get("details")).size()); - - // search hit 2 - Map searchHit2 = hitsNestedList.get(1); - Map topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation"); - assertNotNull(topLevelExplanationsHit2); - assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); - List> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); - assertEquals(1, normalizationExplanationHit2.size()); - - Map hit1DetailsForHit2 = normalizationExplanationHit2.get(0); - assertEquals(1.0, hit1DetailsForHit2.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit2.get("description")); - assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); - - Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); - assertEquals(0.13f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit2.get("description")); - assertEquals(1, getListOfValues(explanationsHit2, "details").size()); - - Map explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); - assertEquals(0.13f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("score(freq=1.0), computed as boost * 
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index 3fe39554e..25b96838b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -129,71 +129,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { // explain Map<String, Object> searchHit1 = hitsNestedList.get(0); Map<String, Object> topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation"); - assertNotNull(topLevelExplanationsHit1); - assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - String expectedTopLevelDescription = "arithmetic_mean combination of:"; - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); - List<Map<String, Object>> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); - assertEquals(1, normalizationExplanationHit1.size()); - Map<String, Object> hit1DetailsForHit1 = normalizationExplanationHit1.get(0); - assertEquals(1.0, hit1DetailsForHit1.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); - assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); - - Map<String, Object> explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); - assertEquals("sum of:", explanationsHit1.get("description")); - assertEquals(0.343f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals(1, ((List) explanationsHit1.get("details")).size()); - - // search hit 2 - Map<String, Object> searchHit2 = hitsNestedList.get(1); - Map<String, Object> topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation"); - assertNotNull(topLevelExplanationsHit2); - assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); - List<Map<String, Object>> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); - assertEquals(1, normalizationExplanationHit2.size()); - - Map<String, Object> hit1DetailsForHit2 = normalizationExplanationHit2.get(0); - assertEquals(1.0, hit1DetailsForHit2.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit2.get("description")); - assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); - - Map<String, Object> explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); - assertEquals(0.13f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit2.get("description")); - assertEquals(1, getListOfValues(explanationsHit2, "details").size()); - - Map<String, Object> explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); - assertEquals(0.13f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2Details.get("description")); - assertEquals(2, getListOfValues(explanationsHit2Details, "details").size()); - - // search hit 3 - Map<String, Object> searchHit3 = hitsNestedList.get(1); - Map<String, Object> topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation"); - assertNotNull(topLevelExplanationsHit3); - assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); - List<Map<String, Object>> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); - assertEquals(1, normalizationExplanationHit3.size()); - - Map<String, Object> hit1DetailsForHit3 = normalizationExplanationHit3.get(0); - assertEquals(1.0, hit1DetailsForHit3.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit3.get("description")); - assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); - - Map<String, Object> explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); - assertEquals(0.13f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit3.get("description")); - assertEquals(1, getListOfValues(explanationsHit3, "details").size()); - - Map<String, Object> explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); - assertEquals(0.13f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); - assertEquals(2, getListOfValues(explanationsHit3Details, "details").size()); + assertExplanation(topLevelExplanationsHit1, searchHit1, hitsNestedList, false); } @SneakyThrows @@ -732,6 +668,154 @@ public void testExplain_whenRRFProcessor_thenSuccessful() { assertTrue((double) explanationsHit4.get("value") > 0.0f); } + @SneakyThrows + public void testExplain_whenMinMaxNormalizationWithLowerBounds_thenSuccessful() { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + // create search pipeline with both normalization processor and explain response processor + createSearchPipeline( + NORMALIZATION_SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + Map.of( + "lower_bounds", + List.of( + Map.of("mode", "apply", "min_score", Float.toString(0.01f)), + Map.of("mode", "clip", "min_score", Float.toString(0.0f)) + ) + ), + DEFAULT_COMBINATION_METHOD, + Map.of(), + true + ); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map<String, Object> searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", "true") + ); + // Assert + // search hits + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap1);
+ List<String> ids = new ArrayList<>(); + List<Double> scores = new ArrayList<>(); + for (Map<String, Object> oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map<String, Object> total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + Map<String, Object> searchHit1 = hitsNestedList.get(0); + Map<String, Object> topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation"); + assertExplanation(topLevelExplanationsHit1, searchHit1, hitsNestedList, true); + } + + private void assertExplanation( + Map<String, Object> topLevelExplanationsHit1, + Map<String, Object> searchHit1, + List<Map<String, Object>> hitsNestedList, + boolean withLowerBounds + ) { + assertNotNull(topLevelExplanationsHit1); + assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "arithmetic_mean combination of:"; + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); + List<Map<String, Object>> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); + assertEquals(1, normalizationExplanationHit1.size()); + Map<String, Object> hit1DetailsForHit1 = normalizationExplanationHit1.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + if (withLowerBounds) { + assertEquals("min_max, lower bounds [(apply, 0.01), (clip, 0.0)] normalization of:", hit1DetailsForHit1.get("description")); + } else { + assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); + } + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map<String, Object> explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("sum of:", explanationsHit1.get("description")); + assertEquals(0.343f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1, ((List) explanationsHit1.get("details")).size()); + + // search hit 2 + Map<String, Object> searchHit2 = hitsNestedList.get(1); + Map<String, Object> topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(topLevelExplanationsHit2); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); + List<Map<String, Object>> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); + assertEquals(1, normalizationExplanationHit2.size()); + + Map<String, Object> hit1DetailsForHit2 = normalizationExplanationHit2.get(0); + assertEquals(1.0, hit1DetailsForHit2.get("value")); + if (withLowerBounds) { + assertEquals("min_max, lower bounds [(apply, 0.01), (clip, 0.0)] normalization of:", hit1DetailsForHit2.get("description")); + } else { + assertEquals("min_max normalization of:", hit1DetailsForHit2.get("description")); + } + assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); + + Map<String, Object> explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals(0.13f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit2.get("description")); + assertEquals(1, getListOfValues(explanationsHit2, "details").size());
+ Map<String, Object> explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); + assertEquals(0.13f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2Details.get("description")); + assertEquals(2, getListOfValues(explanationsHit2Details, "details").size()); + + // search hit 3 + Map<String, Object> searchHit3 = hitsNestedList.get(1); + Map<String, Object> topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(topLevelExplanationsHit3); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); + List<Map<String, Object>> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); + assertEquals(1, normalizationExplanationHit3.size()); + + Map<String, Object> hit1DetailsForHit3 = normalizationExplanationHit3.get(0); + assertEquals(1.0, hit1DetailsForHit3.get("value")); + if (withLowerBounds) { + assertEquals("min_max, lower bounds [(apply, 0.01), (clip, 0.0)] normalization of:", hit1DetailsForHit3.get("description")); + } else { + assertEquals("min_max normalization of:", hit1DetailsForHit3.get("description")); + } + assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); + + Map<String, Object> explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); + assertEquals(0.13f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit3.get("description")); + assertEquals(1, getListOfValues(explanationsHit3, "details").size()); + + Map<String, Object> explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); + assertEquals(0.13f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); + assertEquals(2, getListOfValues(explanationsHit3Details, "details").size()); + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) {
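Finally, a minimal sketch of the programmatic entry point, mirroring what ScoreNormalizationFactoryTests exercises above; the class name, variable names, and parameter values are illustrative, not part of the patch:

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;

public class LowerBoundsFactoryUsageSketch {
    public static void main(String[] args) {
        // one lower-bounds entry per sub-query of the hybrid query
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("lower_bounds", List.of(
            Map.of("mode", "apply", "min_score", 0.01),
            Map.of("mode", "clip", "min_score", 0.0)
        ));
        MinMaxScoreNormalizationTechnique technique = (MinMaxScoreNormalizationTechnique) new ScoreNormalizationFactory()
            .createNormalization("min_max", parameters);
        // describe() is what the explain processor surfaces, e.g.
        // "min_max, lower bounds [(apply, 0.01), (clip, 0.0)]"
        System.out.println(technique.describe());
    }
}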