From 44658320ca1bf6db507a62aa4bb670cc64bcb5d8 Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Thu, 13 Feb 2025 10:47:58 -0800 Subject: [PATCH] Clean Code Signed-off-by: Vikasht34 --- .../quantizationState/OneBitScalarQuantizationState.java | 2 +- .../knn/quantization/quantizer/QuantizerHelper.java | 6 +++++- .../NativeEngines990KnnVectorsWriterFlushTests.java | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index ba79506db..b5a2638dc 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -73,7 +73,7 @@ public ScalarQuantizationParams getQuantizationParams() { */ @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(Version.CURRENT.id); // Write the versionF + out.writeVInt(Version.CURRENT.id); // Write the version quantizationParams.writeTo(out); out.writeFloatArray(meanThresholds); out.writeOptionalArray(belowThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(belowThresholdMeans) } : null); diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index b4038a9cf..ec05b6001 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -20,6 +20,10 @@ */ @UtilityClass class QuantizerHelper { + + // This value can change based on Experiments. + private static final double ROTATION_MATRIX_THRESHOLD = 0.6; + /** * Calculates the quantization state using the provided training data and sampled indices. *

@@ -48,7 +52,7 @@ static OneBitScalarQuantizationState calculateQuantizationState( double averageL2L1Ratio = meanAndL2L1.getB(); // Apply random rotation if L2/L1 ratio is greater than 0.6 float[][] rotationMatrix = null; - if (averageL2L1Ratio > 0.6) { + if (averageL2L1Ratio > ROTATION_MATRIX_THRESHOLD) { int dimensions = meanThresholds.length; rotationMatrix = RandomGaussianRotation.generateRotationMatrix(dimensions); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 0d022b88f..5811eccd2 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -302,7 +302,7 @@ public void testFlush_WithQuantization() { final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 1) + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) ); } }