Skip to content

Commit

Permalink
Added more validations and unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Feb 26, 2025
1 parent a983cfd commit 199798d
Show file tree
Hide file tree
Showing 9 changed files with 767 additions and 174 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -150,4 +151,78 @@ private ScoreDoc deepCopyScoreDoc(final ScoreDoc scoreDoc, final boolean isSortE
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
return new FieldDoc(fieldDoc.doc, fieldDoc.score, fieldDoc.fields, fieldDoc.shardIndex);
}

@Override
public boolean equals(Object other) {
    // Standard reflexive / type guards for the equals contract.
    if (this == other) {
        return true;
    }
    if (other == null || getClass() != other.getClass()) {
        return false;
    }
    CompoundTopDocs that = (CompoundTopDocs) other;
    // Shard-level fields first; cheap checks before walking the per-query TopDocs.
    if (!Objects.equals(totalHits, that.totalHits) || !Objects.equals(searchShard, that.searchShard)) {
        return false;
    }
    int count = this.topDocs.size();
    if (count != that.topDocs.size()) {
        return false;
    }
    for (int index = 0; index < count; index++) {
        TopDocs left = this.topDocs.get(index);
        TopDocs right = that.topDocs.get(index);
        if (left == null || right == null) {
            // A null entry only matches another null entry.
            if (left != right) {
                return false;
            }
            continue;
        }
        if (!Objects.equals(left.totalHits, right.totalHits) || !compareScoreDocs(left.scoreDocs, right.scoreDocs)) {
            return false;
        }
    }
    return true;
}

/**
 * Element-wise comparison of two score doc arrays. Arrays are equal when they have the
 * same length and every pair of entries matches per {@link #scoreDocsMatch(ScoreDoc, ScoreDoc)}.
 */
private boolean compareScoreDocs(ScoreDoc[] first, ScoreDoc[] second) {
    if (first.length != second.length) {
        return false;
    }
    for (int position = 0; position < first.length; position++) {
        if (!scoreDocsMatch(first[position], second[position])) {
            return false;
        }
    }
    return true;
}

/**
 * Compares one pair of score docs: doc id, score (bit-exact via Float.compare), and,
 * when both are {@link FieldDoc} instances, their sort field values. Null only matches null.
 */
private boolean scoreDocsMatch(ScoreDoc left, ScoreDoc right) {
    if (left == null || right == null) {
        return left == right;
    }
    if (left.doc != right.doc || Float.compare(left.score, right.score) != 0) {
        return false;
    }
    boolean leftIsFieldDoc = left instanceof FieldDoc;
    boolean rightIsFieldDoc = right instanceof FieldDoc;
    if (leftIsFieldDoc != rightIsFieldDoc) {
        // One carries sort fields and the other does not — cannot be equal.
        return false;
    }
    if (!leftIsFieldDoc) {
        return true;
    }
    return Arrays.equals(((FieldDoc) left).fields, ((FieldDoc) right).fields);
}

/**
 * Hash code consistent with {@link #equals(Object)}: combines the shard-level total hits
 * and search shard with each TopDocs entry's total hits, doc ids, scores and (for field
 * docs) sort values.
 * <p>
 * Null {@code TopDocs} entries, null {@code scoreDocs} arrays and null {@code ScoreDoc}
 * elements are tolerated here because {@code equals} explicitly tolerates them; without
 * these guards two objects that compare equal could not be hashed (NPE). The hash value
 * is unchanged from the previous implementation for all fully non-null inputs.
 */
@Override
public int hashCode() {
    int result = Objects.hash(totalHits, searchShard);
    for (TopDocs topDoc : topDocs) {
        if (topDoc == null) {
            // equals() treats a null entry as comparable; fold in a stable step instead of throwing
            result = 31 * result;
            continue;
        }
        result = 31 * result + Objects.hashCode(topDoc.totalHits);
        if (topDoc.scoreDocs == null) {
            continue;
        }
        for (ScoreDoc scoreDoc : topDoc.scoreDocs) {
            if (scoreDoc == null) {
                result = 31 * result;
                continue;
            }
            result = 31 * result + Float.floatToIntBits(scoreDoc.score);
            result = 31 * result + scoreDoc.doc;
            if (scoreDoc instanceof FieldDoc fieldDoc && fieldDoc.fields != null) {
                result = 31 * result + Arrays.deepHashCode(fieldDoc.fields);
            }
        }
    }
    return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique {
public static final String TECHNIQUE_NAME = "l2";
private static final float MIN_SCORE = 0.0f;

/**
 * Creates the technique with no parameters; delegates to the map-based constructor
 * with an empty map so validation logic lives in one place.
 */
public L2ScoreNormalizationTechnique() {
this(Map.of());
}

/**
 * Creates the L2 normalization technique from user-supplied parameters.
 * This technique accepts no parameters, so any non-empty map is rejected to make
 * typos in the request body fail fast instead of being silently ignored.
 *
 * @param params technique parameters from the search request; may be null or empty
 * @throws IllegalArgumentException if any parameters are supplied
 */
public L2ScoreNormalizationTechnique(final Map<String, Object> params) {
    if (Objects.nonNull(params) && !params.isEmpty()) {
        // name the offending keys so the caller can correct the request
        throw new IllegalArgumentException(
            "unrecognized parameters in normalization technique: " + params.keySet()
        );
    }
}

/**
* L2 normalization method.
* n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2)
Expand Down
Loading

0 comments on commit 199798d

Please sign in to comment.