Skip to content

Commit

Permalink
Added check for number of elements in lower_bounds array
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Feb 20, 2025
1 parent b63f34f commit a983cfd
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;
import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES;

/**
* Abstracts normalization of scores based on min-max method
Expand All @@ -44,10 +45,10 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech
private final List<Pair<Mode, Float>> lowerBounds;

public MinMaxScoreNormalizationTechnique() {
this(Map.of(), new ScoreNormalizationUtil());
this(Map.of());
}

public MinMaxScoreNormalizationTechnique(final Map<String, Object> params, final ScoreNormalizationUtil scoreNormalizationUtil) {
public MinMaxScoreNormalizationTechnique(final Map<String, Object> params) {
lowerBounds = getLowerBounds(params);
}

Expand Down Expand Up @@ -238,6 +239,17 @@ private List<Pair<Mode, Float>> getLowerBounds(final Map<String, Object> params)
throw new IllegalArgumentException("lower_bounds must be a List");
}

if (lowerBoundsParams.size() > MAX_NUMBER_OF_SUB_QUERIES) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"lower_bounds size %d should be less than or equal to %d",
lowerBoundsParams.size(),
MAX_NUMBER_OF_SUB_QUERIES
)
);
}

for (Object boundObj : lowerBoundsParams) {
if (!(boundObj instanceof Map)) {
throw new IllegalArgumentException("each lower bound must be a map");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@ public class ScoreNormalizationFactory {

private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil();

public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(
Map.of(),
scoreNormalizationUtil
);
public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(Map.of());

private final Map<String, Function<Map<String, Object>, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of(
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
params -> new MinMaxScoreNormalizationTechnique(params, scoreNormalizationUtil),
MinMaxScoreNormalizationTechnique::new,
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
params -> new L2ScoreNormalizationTechnique(),
RRFNormalizationTechnique.TECHNIQUE_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu

private Integer paginationDepth;

static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
public static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 0;

public HybridQueryBuilder(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
*/
package org.opensearch.neuralsearch.processor.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.commons.lang3.tuple.Pair;
Expand All @@ -23,6 +26,7 @@
import org.opensearch.search.SearchShardTarget;

import static org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique.MIN_SCORE;
import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES;
import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;

/**
Expand Down Expand Up @@ -349,6 +353,42 @@ public void testMode_defaultValue() {
assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.DEFAULT);
}

public void testLowerBoundsExceedsMaxSubQueries() {
List<Map<String, Object>> lowerBounds = new ArrayList<>();

for (int i = 0; i <= 100; i++) {
Map<String, Object> bound = new HashMap<>();
if (i % 3 == 0) {
bound.put("mode", "apply");
bound.put("min_score", 0.1f);
} else if (i % 3 == 1) {
bound.put("mode", "clip");
bound.put("min_score", 0.1f);
} else {
bound.put("mode", "ignore");
}
lowerBounds.add(bound);
}

Map<String, Object> parameters = new HashMap<>();
parameters.put("lower_bounds", lowerBounds);

IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> new MinMaxScoreNormalizationTechnique(parameters)
);

assertEquals(
String.format(
Locale.ROOT,
"lower_bounds size %d should be less than or equal to %d",
lowerBounds.size(),
MAX_NUMBER_OF_SUB_QUERIES
),
exception.getMessage()
);
}

private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
assertEquals(expected.totalHits.value(), actual.totalHits.value());
assertEquals(expected.totalHits.relation(), actual.totalHits.relation());
Expand Down

0 comments on commit a983cfd

Please sign in to comment.