From 72ce48f32f06f5c8736891275c5327379a3e927f Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Tue, 11 Feb 2025 16:57:04 -0800 Subject: [PATCH] Implementation of Random Rotation Matrix with Query Scaling matix of above and below thresholds for 1 bit Binary Quantization Signed-off-by: Vikasht34 --- CHANGELOG.md | 2 +- build.gradle | 2 +- .../NativeEngines990KnnVectorsWriter.java | 3 +- .../KNNVectorQuantizationTrainingRequest.java | 21 +- .../QuantizationService.java | 12 +- .../OneBitScalarQuantizationState.java | 101 +++++++- .../models/requests/TrainingRequest.java | 2 + .../quantizer/OneBitScalarQuantizer.java | 13 +- .../quantizer/QuantizerHelper.java | 180 ++++++++++++-- .../quantizer/RandomGaussianRotation.java | 89 +++++++ .../training/FloatTrainingDataConsumer.java | 6 + .../KNN990QuantizationStateWriterTests.java | 24 +- ...eEngines990KnnVectorsWriterFlushTests.java | 15 +- ...eEngines990KnnVectorsWriterMergeTests.java | 10 +- .../QuantizationIndexUtilsTests.java | 5 +- .../QuantizationServiceTests.java | 48 +++- .../knn/index/query/KNNWeightTests.java | 5 +- .../QuantizationStateCacheTests.java | 77 +++--- .../QuantizationStateSerializerTests.java | 5 +- .../QuantizationStateTests.java | 111 ++++++++- .../MultiBitScalarQuantizerTests.java | 5 + .../quantizer/OneBitScalarQuantizerTests.java | 230 ++++++++++++++++-- .../RandomGaussianRotationTests.java | 126 ++++++++++ 23 files changed, 949 insertions(+), 143 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotation.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotationTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index ccb5fe2fb..ed3959e1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,4 +59,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Bump Faiss commit from 1f42e81 to 0cbc2a8 to accelerate hamming distance calculation using _mm512_popcnt_epi64 intrinsic and also add avx512-fp16 instructions to boost performance [#2381](https://github.com/opensearch-project/k-NN/pull/2381) * Enabled indices.breaker.total.use_real_memory setting via build.gradle for integTest Cluster to catch heap CB in local ITs and github CI actions [#2395](https://github.com/opensearch-project/k-NN/pull/2395/) * Enabled idempotency of local builds when using `./gradlew clean` and nest `jni/release` directory under `jni/build` for easier cleanup [#2516](https://github.com/opensearch-project/k-NN/pull/2516) -### Refactoring +### Refactoring \ No newline at end of file diff --git a/build.gradle b/build.gradle index b5f715847..04cc2617d 100644 --- a/build.gradle +++ b/build.gradle @@ -595,4 +595,4 @@ task updateVersion { ant.replaceregexp(file:".github/workflows/backwards_compatibility_tests_workflow.yml", match: oldBWCVersion, replace: oldBWCVersion + '", "' + opensearch_version.tokenize('-')[0], flags:'g', byline:true) } } -} +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 3966a2c95..dc2a5ed3d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -256,8 +256,7 @@ private QuantizationState train( QuantizationState quantizationState = null; if (quantizationParams != null && totalLiveDocs > 0) { initQuantizationStateWriterIfNecessary(); - KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); + quantizationState = quantizationService.train(quantizationParams, knnVectorValuesSupplier, totalLiveDocs); quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java b/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java index f7ee12904..4c11e5e1f 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java @@ -12,6 +12,7 @@ import java.io.IOException; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import java.util.function.Supplier; /** * KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class. @@ -19,18 +20,19 @@ */ @Log4j2 final class KNNVectorQuantizationTrainingRequest extends TrainingRequest { - - private final KNNVectorValues knnVectorValues; + private final Supplier> knnVectorValuesSupplier; + private KNNVectorValues knnVectorValues; private int lastIndex; /** * Constructs a new QuantizationFloatVectorTrainingRequest. * - * @param knnVectorValues the KNNVectorValues instance containing the vectors. + * @param knnVectorValuesSupplier the KNNVectorValues instance containing the vectors. */ - KNNVectorQuantizationTrainingRequest(KNNVectorValues knnVectorValues, long liveDocs) { + KNNVectorQuantizationTrainingRequest(Supplier> knnVectorValuesSupplier, long liveDocs) { super((int) liveDocs); - this.knnVectorValues = knnVectorValues; + this.knnVectorValuesSupplier = knnVectorValuesSupplier; + resetVectorValues(); // Initialize the first instance this.lastIndex = 0; } @@ -52,4 +54,13 @@ public T getVectorAtThePosition(int position) throws IOException { // Return the vector return knnVectorValues.getVector(); } + + /** + * Resets the KNNVectorValues to enable a fresh iteration by calling the supplier again. + */ + @Override + public void resetVectorValues() { + this.knnVectorValues = knnVectorValuesSupplier.get(); + this.lastIndex = 0; + } } diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index 771848730..91690a41e 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -19,6 +19,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; +import java.util.function.Supplier; import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig; @@ -53,19 +54,22 @@ public static QuantizationService getInstance() { * {@link QuantizationState}. The quantizer is determined based on the given {@link QuantizationParams}. * * @param quantizationParams The {@link QuantizationParams} containing the parameters for quantization. - * @param knnVectorValues The {@link KNNVectorValues} representing the vector data to be used for training. + * @param knnVectorValuesSupplier The {@link KNNVectorValues} representing the vector data to be used for training. * @return The {@link QuantizationState} containing the state of the trained quantizer. * @throws IOException If an I/O error occurs during the training process. */ public QuantizationState train( final QuantizationParams quantizationParams, - final KNNVectorValues knnVectorValues, + final Supplier> knnVectorValuesSupplier, final long liveDocs ) throws IOException { Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); - // Create the training request from the vector values - KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs); + // Create the training request using the supplier + KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>( + knnVectorValuesSupplier, + liveDocs + ); // Train the quantizer and return the quantization state return quantizer.train(trainingRequest); diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 0a8c33771..6ef616b15 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -5,13 +5,12 @@ package org.opensearch.knn.quantization.models.quantizationState; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NoArgsConstructor; +import lombok.*; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import java.io.IOException; @@ -21,10 +20,12 @@ * including the mean values used for quantization. */ @Getter -@NoArgsConstructor // No-argument constructor for deserialization +@Builder @AllArgsConstructor +@NoArgsConstructor(force = true) public final class OneBitScalarQuantizationState implements QuantizationState { - private ScalarQuantizationParams quantizationParams; + @NonNull + private final ScalarQuantizationParams quantizationParams; /** * Mean thresholds used in the quantization process. * Each threshold value corresponds to a dimension of the vector being quantized. @@ -33,7 +34,27 @@ public final class OneBitScalarQuantizationState implements QuantizationState { * If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0], * The quantized vector will be [0, 1, 1]. */ - private float[] meanThresholds; + @NonNull + private final float[] meanThresholds; + + /** + * Represents the mean of all values below the threshold for each dimension. + */ + @Builder.Default + private float[] belowThresholdMeans = null; + + /** + * Represents the mean of all values above the threshold for each dimension. + */ + @Builder.Default + private float[] aboveThresholdMeans = null; + @Builder.Default + private double averageL2L1Ratio = 0.0; + /** + * Rotation matrix used when L2/L1 ratio > 0.6 + */ + @Builder.Default + private float[][] rotationMatrix = null; @Override public ScalarQuantizationParams getQuantizationParams() { @@ -48,9 +69,23 @@ public ScalarQuantizationParams getQuantizationParams() { */ @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(Version.CURRENT.id); // Write the version + out.writeVInt(Version.CURRENT.id); // Write the versionF quantizationParams.writeTo(out); out.writeFloatArray(meanThresholds); + out.writeOptionalArray(belowThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(belowThresholdMeans) } : null); + // Serialize aboveThresholdMeans using writeOptionalArray + out.writeOptionalArray(aboveThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(aboveThresholdMeans) } : null); + out.writeOptionalDouble(averageL2L1Ratio); + // Write rotation matrix + if (rotationMatrix != null) { + out.writeBoolean(true); + out.writeVInt(rotationMatrix.length); + for (float[] row : rotationMatrix) { + out.writeFloatArray(row); + } + } else { + out.writeBoolean(false); + } } /** @@ -63,6 +98,23 @@ public OneBitScalarQuantizationState(StreamInput in) throws IOException { int version = in.readVInt(); // Read the version this.quantizationParams = new ScalarQuantizationParams(in, version); this.meanThresholds = in.readFloatArray(); + if (Version.fromId(version).onOrAfter(Version.V_3_0_0)) { + // Deserialize belowThresholdMeans using readOptionalArray + FloatArrayWrapper[] wrappedBelowThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new); + this.belowThresholdMeans = wrappedBelowThresholdMeans != null ? wrappedBelowThresholdMeans[0].getArray() : null; + // Deserialize aboveThresholdMeans using readOptionalArray + FloatArrayWrapper[] wrappedAboveThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new); + this.aboveThresholdMeans = wrappedAboveThresholdMeans != null ? wrappedAboveThresholdMeans[0].getArray() : null; + this.averageL2L1Ratio = in.readOptionalDouble(); + // Read rotation matrix + if (in.readBoolean()) { + int dimensions = in.readVInt(); + this.rotationMatrix = new float[dimensions][]; + for (int i = 0; i < dimensions; i++) { + this.rotationMatrix[i] = in.readFloatArray(); + } + } + } } /** @@ -139,6 +191,41 @@ public long ramBytesUsed() { long size = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class); size += RamUsageEstimator.shallowSizeOf(quantizationParams); size += RamUsageEstimator.sizeOf(meanThresholds); + if (belowThresholdMeans != null) { + size += RamUsageEstimator.sizeOf(belowThresholdMeans); + } + if (aboveThresholdMeans != null) { + size += RamUsageEstimator.sizeOf(aboveThresholdMeans); + } + if (rotationMatrix != null) { + size += RamUsageEstimator.shallowSizeOf(rotationMatrix); + // Add size of each row array + for (float[] row : rotationMatrix) { + size += RamUsageEstimator.sizeOf(row); + } + } return size; } + + private class FloatArrayWrapper implements Writeable { + private final float[] array; + + public FloatArrayWrapper(float[] array) { + this.array = array; + } + + // Constructor that matches Writeable.Reader + public FloatArrayWrapper(StreamInput in) throws IOException { + this.array = in.readFloatArray(); + } + + public float[] getArray() { + return array; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeFloatArray(array); + } + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java index d8b0eab10..730aeed1e 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -30,4 +30,6 @@ public abstract class TrainingRequest { * @return the vector corresponding to the specified document ID. */ public abstract T getVectorAtThePosition(int position) throws IOException; + + public abstract void resetVectorValues(); } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index 3cba89c39..61056c37e 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -60,8 +60,11 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { @Override public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds); - return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); + return QuantizerHelper.calculateQuantizationState( + trainingRequest, + sampledDocIds, + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT) + ); } /** @@ -73,7 +76,7 @@ public QuantizationState train(final TrainingRequest trainingRequest) t * @param output the QuantizationOutput object to store the quantized representation of the vector. */ @Override - public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) { + public void quantize(float[] vector, final QuantizationState state, final QuantizationOutput output) { if (vector == null) { throw new IllegalArgumentException("Vector to quantize must not be null."); } @@ -84,6 +87,10 @@ public void quantize(final float[] vector, final QuantizationState state, final if (thresholds == null || thresholds.length != vectorLength) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } + float[][] rotationMatrix = binaryState.getRotationMatrix(); + if (rotationMatrix != null) { + vector = RandomGaussianRotation.applyRotation(vector, rotationMatrix); + } output.prepareQuantizedVector(vectorLength); BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector()); } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index bac2067c0..ffcf96c91 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -5,6 +5,8 @@ package org.opensearch.knn.quantization.quantizer; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import lombok.experimental.UtilityClass; import oshi.util.tuples.Pair; @@ -19,37 +21,178 @@ @UtilityClass class QuantizerHelper { /** - * Calculates the mean vector from a set of sampled vectors. + * Calculates the quantization state using the provided training data and sampled indices. + *

+ * This method combines the calculation of mean thresholds, average L2/L1 ratio, and + * below/above threshold means to construct a {@link OneBitScalarQuantizationState}. + *

+ * + * @param trainingRequest The {@link TrainingRequest} containing the dataset and access methods for vector retrieval. + * @param sampledIndices An array of indices representing the sampled vectors. + * @param quantizationParams The scalar quantization parameters. + * @return A fully constructed {@link OneBitScalarQuantizationState}. + * @throws IOException If an I/O error occurs while retrieving vector data. + */ + static OneBitScalarQuantizationState calculateQuantizationState( + TrainingRequest trainingRequest, + int[] sampledIndices, + ScalarQuantizationParams quantizationParams + ) throws IOException { + if (sampledIndices.length == 0) { + throw new IllegalArgumentException("No samples provided."); + } + + // Calculate mean thresholds and L2/L1 ratio in a single pass + Pair meanAndL2L1 = calculateMeanAndL2L1Ratio(trainingRequest, sampledIndices); + float[] meanThresholds = meanAndL2L1.getA(); + double averageL2L1Ratio = meanAndL2L1.getB(); + // Apply random rotation if L2/L1 ratio is greater than 0.6 + float[][] rotationMatrix = null; + if (averageL2L1Ratio > 0.6) { + int dimensions = meanThresholds.length; + rotationMatrix = RandomGaussianRotation.generateRotationMatrix(dimensions); + + // Apply rotation to mean thresholds + meanThresholds = RandomGaussianRotation.applyRotation(meanThresholds, rotationMatrix); + } + + // Calculate below and above threshold means + Pair belowAboveMeans = calculateBelowAboveThresholdMeans(trainingRequest, meanThresholds, sampledIndices); + float[] belowThresholdMeans = belowAboveMeans.getA(); + float[] aboveThresholdMeans = belowAboveMeans.getB(); + + // Apply the same rotation to below and above threshold means if rotation was applied + if (rotationMatrix != null) { + belowThresholdMeans = RandomGaussianRotation.applyRotation(belowThresholdMeans, rotationMatrix); + aboveThresholdMeans = RandomGaussianRotation.applyRotation(aboveThresholdMeans, rotationMatrix); + } + + // Construct and return the quantization state + return OneBitScalarQuantizationState.builder() + .quantizationParams(quantizationParams) + .meanThresholds(meanThresholds) + .belowThresholdMeans(belowThresholdMeans) + .aboveThresholdMeans(aboveThresholdMeans) + .averageL2L1Ratio(averageL2L1Ratio) + .rotationMatrix(rotationMatrix) + .build(); + } + + /** + * Calculates the mean thresholds and average L2/L1 ratio for the given sampled vectors. + *

+ * The mean thresholds are computed by averaging the values across all sampled vectors for each dimension. + * The average L2/L1 ratio is calculated as the mean of the L2/L1 ratio for each vector. + *

* - * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices. - * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. - * @return A float array representing the mean vector of the sampled vectors. - * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. - * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. + * @param trainingRequest The {@link TrainingRequest} containing the dataset and access methods for vector retrieval. + * @param sampledIndices An array of indices representing the sampled vectors. + * @return A {@link Pair} where the first element is the array of mean thresholds (float[]) and the second element + * is the average L2/L1 ratio (Double). + * @throws IOException If an I/O error occurs while retrieving vector data. + * @throws IllegalArgumentException If any vector at the sampled indices is null. */ - static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) throws IOException { + private static Pair calculateMeanAndL2L1Ratio(TrainingRequest trainingRequest, int[] sampledIndices) + throws IOException { + float[] meanThresholds = null; + double totalL2L1Ratio = 0.0; int totalSamples = sampledIndices.length; - float[] mean = null; - int lastIndex = 0; + for (int docId : sampledIndices) { - float[] vector = samplingRequest.getVectorAtThePosition(docId); + float[] vector = trainingRequest.getVectorAtThePosition(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } - if (mean == null) { - mean = new float[vector.length]; + + if (meanThresholds == null) { + meanThresholds = new float[vector.length]; } + + double l2Norm = 0.0; + double l1Norm = 0.0; + for (int j = 0; j < vector.length; j++) { - mean[j] += vector[j]; + float value = vector[j]; + + // Accumulate mean + meanThresholds[j] += value; + + // Accumulate norms + l2Norm += value * value; + l1Norm += Math.abs(value); } + + // Update L2/L1 ratio for the vector + totalL2L1Ratio += Math.sqrt(l2Norm) / l1Norm; } - if (mean == null) { - throw new IllegalStateException("Mean array should not be null after processing vectors."); + + // Finalize mean thresholds + for (int j = 0; j < meanThresholds.length; j++) { + meanThresholds[j] /= totalSamples; } - for (int j = 0; j < mean.length; j++) { - mean[j] /= totalSamples; + + // Calculate average L2/L1 ratio + double averageL2L1Ratio = totalL2L1Ratio / totalSamples; + + return new Pair<>(meanThresholds, averageL2L1Ratio); + } + + /** + * Calculates the below and above threshold means for the given sampled vectors. + *

+ * For each dimension, values are classified as either below or above the mean threshold, + * and their respective means are calculated. + *

+ * + * @param trainingRequest The {@link TrainingRequest} containing the dataset and access methods for vector retrieval. + * @param thresholds The mean thresholds for each dimension. + * @param sampledIndices An array of indices representing the sampled vectors. + * @return A {@link Pair} containing two float arrays: + * - The first array represents the below threshold means. + * - The second array represents the above threshold means. + * @throws IOException If an I/O error occurs while retrieving vector data. + */ + private static Pair calculateBelowAboveThresholdMeans( + TrainingRequest trainingRequest, + float[] thresholds, + int[] sampledIndices + ) throws IOException { + int dimension = thresholds.length; + float[] belowThresholdMeans = new float[dimension]; + float[] aboveThresholdMeans = new float[dimension]; + int[] belowThresholdCounts = new int[dimension]; + int[] aboveThresholdCounts = new int[dimension]; + + for (int docId : sampledIndices) { + float[] vector = trainingRequest.getVectorAtThePosition(docId); + if (vector == null) { + continue; + } + + for (int j = 0; j < dimension; j++) { + float value = vector[j]; + + if (value <= thresholds[j]) { + belowThresholdMeans[j] += value; + belowThresholdCounts[j]++; + } else { + aboveThresholdMeans[j] += value; + aboveThresholdCounts[j]++; + } + } } - return mean; + + // Finalize means + for (int j = 0; j < dimension; j++) { + if (belowThresholdCounts[j] > 0) { + belowThresholdMeans[j] /= belowThresholdCounts[j]; + } + if (aboveThresholdCounts[j] > 0) { + aboveThresholdMeans[j] /= aboveThresholdCounts[j]; + } + } + + return new Pair<>(belowThresholdMeans, aboveThresholdMeans); } /** @@ -71,7 +214,6 @@ static Pair calculateMeanAndStdDev(TrainingRequest tr float[] meanArray = null; float[] stdDevArray = null; int totalSamples = sampledIndices.length; - int lastIndex = 0; for (int docId : sampledIndices) { float[] vector = trainingRequest.getVectorAtThePosition(docId); if (vector == null) { diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotation.java b/src/main/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotation.java new file mode 100644 index 000000000..1c41870f0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotation.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import lombok.experimental.UtilityClass; + +import java.util.Random; + +@UtilityClass +public class RandomGaussianRotation { + + /** + * Generates a random rotation matrix using Gaussian distribution and orthogonalization. + * + * @param dimensions The number of dimensions for the rotation matrix. + * @return A 2D float array representing the rotation matrix. + */ + public float[][] generateRotationMatrix(int dimensions) { + Random random = new Random(); + float[][] rotationMatrix = new float[dimensions][dimensions]; + + // Step 1: Generate random Gaussian values + for (int i = 0; i < dimensions; i++) { + for (int j = 0; j < dimensions; j++) { + rotationMatrix[i][j] = (float) random.nextGaussian(); + } + } + + // Step 2: Orthogonalize the matrix using the Gram-Schmidt process + for (int i = 0; i < dimensions; i++) { + // Normalize the current vector + float norm = 0f; + for (int j = 0; j < dimensions; j++) { + norm += rotationMatrix[i][j] * rotationMatrix[i][j]; + } + norm = (float) Math.sqrt(norm); + for (int j = 0; j < dimensions; j++) { + rotationMatrix[i][j] /= norm; + } + + // Subtract projections of the current vector onto all previous vectors + for (int k = 0; k < i; k++) { + float dotProduct = 0f; + for (int j = 0; j < dimensions; j++) { + dotProduct += rotationMatrix[i][j] * rotationMatrix[k][j]; + } + for (int j = 0; j < dimensions; j++) { + rotationMatrix[i][j] -= dotProduct * rotationMatrix[k][j]; + } + } + + // Re-normalize after orthogonalization + norm = 0f; + for (int j = 0; j < dimensions; j++) { + norm += rotationMatrix[i][j] * rotationMatrix[i][j]; + } + norm = (float) Math.sqrt(norm); + for (int j = 0; j < dimensions; j++) { + rotationMatrix[i][j] /= norm; + } + } + + return rotationMatrix; + } + + /** + * Applies a rotation to a vector using the provided rotation matrix. + * + * @param vector The input vector to be rotated. + * @param rotationMatrix The rotation matrix. + * @return The rotated vector. + */ + public float[] applyRotation(float[] vector, float[][] rotationMatrix) { + int dimensions = vector.length; + float[] rotatedVector = new float[dimensions]; + + for (int i = 0; i < dimensions; i++) { + rotatedVector[i] = 0f; + for (int j = 0; j < dimensions; j++) { + rotatedVector[i] += rotationMatrix[i][j] * vector[j]; + } + } + + return rotatedVector; + } +} diff --git a/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java b/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java index 292752945..b41d5a0d4 100644 --- a/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java +++ b/src/main/java/org/opensearch/knn/training/FloatTrainingDataConsumer.java @@ -98,6 +98,12 @@ private List quantizeVectors(List vectors) throws IOException { public float[] getVectorAtThePosition(int position) { return ArrayUtils.toPrimitive((Float[]) vectors.get(position)); } + + @Override + public void resetVectorValues() { + // No-op + } + }; QuantizationState quantizationState = quantizer.train(trainingRequest); BinaryQuantizationOutput binaryQuantizationOutput = new BinaryQuantizationOutput(quantizationConfig.getQuantizationType().getId()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java index 2423a6827..319a71df0 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateWriterTests.java @@ -120,10 +120,10 @@ public void testWriteState() { KNN990QuantizationStateWriter quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); int fieldNumber = 0; - QuantizationState quantizationState = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f, 4.5f } - ); + QuantizationState quantizationState = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f, 4.5f }) + .build(); quantizationStateWriter.writeState(fieldNumber, quantizationState); byte[] stateBytes = quantizationState.toByteArray(); Mockito.verify(output, times(1)).writeBytes(stateBytes, stateBytes.length); @@ -164,14 +164,14 @@ public void testWriteFooter() { int fieldNumber1 = 1; int fieldNumber2 = 2; - QuantizationState quantizationState1 = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f, 4.5f } - ); - QuantizationState quantizationState2 = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 2.3f, 3.4f, 4.5f, 5.6f } - ); + QuantizationState quantizationState1 = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f, 4.5f }) + .build(); + QuantizationState quantizationState2 = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 2.3f, 3.4f, 4.5f, 5.6f }) + .build(); quantizationStateWriter.writeState(fieldNumber1, quantizationState1); quantizationStateWriter.writeState(fieldNumber2, quantizationState2); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 6685e2b22..e559ad04a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -40,6 +40,7 @@ import java.util.List; import java.util.Map; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -47,6 +48,7 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; @@ -257,7 +259,8 @@ public void testFlush_WithQuantization() { when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + // Fix mock to use the supplier + when(quantizationService.train(eq(quantizationParams), any(Supplier.class), eq((long) vectorsPerField.get(i).size()))) .thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); @@ -299,7 +302,7 @@ public void testFlush_WithQuantization() { final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 2) + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled) * 1) ); } } @@ -690,7 +693,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + when(quantizationService.train(eq(quantizationParams), any(Supplier.class), eq((long) vectorsPerField.get(i).size()))) .thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); @@ -729,7 +732,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + times(0) ); } } @@ -793,7 +796,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + when(quantizationService.train(eq(quantizationParams), any(Supplier.class), eq((long) vectorsPerField.get(i).size()))) .thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); @@ -832,7 +835,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), - times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + times(0) ); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index cdc372bda..a865fd35b 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -33,6 +33,7 @@ import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; +import java.util.function.Supplier; import java.io.IOException; import java.util.ArrayList; @@ -44,6 +45,7 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; @@ -325,7 +327,10 @@ public void testMerge_WithQuantization() { when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); + // Fix mock to use the supplier + when(quantizationService.train(eq(quantizationParams), any(Supplier.class), eq((long) mergedVectors.size()))).thenReturn( + quantizationState + ); } catch (Exception e) { throw new RuntimeException(e); } @@ -349,13 +354,12 @@ public void testMerge_WithQuantization() { assertTrue(KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.getValue() > 0L); knnVectorValuesFactoryMockedStatic.verify( () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues), - times(3) + times(2) ); } else { assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); verifyNoInteractions(nativeIndexWriter); } - } } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java index 61d3d7589..d94cce3a6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java @@ -97,7 +97,10 @@ public void testProcessAndReturnVector_withQuantization_success() throws IOExcep ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; knnVectorValues.nextDoc(); - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .build(); QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); when(buildIndexParams.getQuantizationState()).thenReturn(state); IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java index 690391dbd..c764875b5 100644 --- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java @@ -20,10 +20,11 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.util.List; +import java.util.function.Supplier; public class QuantizationServiceTests extends KNNTestCase { private QuantizationService quantizationService; - private KNNVectorValues knnVectorValues; + private Supplier> knnVectorValues; @Before public void setUp() throws Exception { @@ -38,7 +39,8 @@ public void setUp() throws Exception { ); // Use the predefined vectors to create KNNVectorValues - knnVectorValues = KNNVectorValuesFactory.getVectorValues( + // Use the predefined vectors to create KNNVectorValues + knnVectorValues = () -> KNNVectorValuesFactory.getVectorValues( VectorDataType.FLOAT, new TestVectorValues.PreDefinedFloatVectorValues(floatVectors) ); @@ -46,7 +48,11 @@ public void setUp() throws Exception { public void testTrain_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.get().totalLiveDocs() + ); assertTrue(quantizationState instanceof OneBitScalarQuantizationState); OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; @@ -62,7 +68,11 @@ public void testTrain_oneBitQuantizer_success() throws IOException { public void testTrain_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + twoBitParams, + knnVectorValues, + knnVectorValues.get().totalLiveDocs() + ); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -85,7 +95,11 @@ public void testTrain_twoBitQuantizer_success() throws IOException { public void testTrain_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + fourBitParams, + knnVectorValues, + knnVectorValues.get().totalLiveDocs() + ); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -110,7 +124,11 @@ public void testTrain_fourBitQuantizer_success() throws IOException { public void testQuantize_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.get().totalLiveDocs() + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); @@ -125,7 +143,11 @@ public void testQuantize_oneBitQuantizer_success() throws IOException { public void testQuantize_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + twoBitParams, + knnVectorValues, + knnVectorValues.get().totalLiveDocs() + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); @@ -138,7 +160,11 @@ public void testQuantize_twoBitQuantizer_success() throws IOException { public void testQuantize_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + fourBitParams, + knnVectorValues, + knnVectorValues.get().totalLiveDocs() + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); @@ -152,7 +178,11 @@ public void testQuantize_fourBitQuantizer_success() throws IOException { public void testQuantize_whenInvalidInput_thenThrows() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.get().totalLiveDocs() + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7a1da8781..55f8b3924 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -1712,7 +1712,10 @@ public void testANNWithQuantizationParams_thenSuccess() { quantizationServiceMockedStatic.when(QuantizationService::getInstance).thenReturn(quantizationService); float[] meanThresholds = new float[] { 1.2f, 2.3f, 3.4f, 4.5f }; - QuantizationState quantizationState = new OneBitScalarQuantizationState(quantizationParams, meanThresholds); + QuantizationState quantizationState = OneBitScalarQuantizationState.builder() + .quantizationParams(quantizationParams) + .meanThresholds(meanThresholds) + .build(); try ( MockedConstruction quantizationCollectorMockedConstruction = Mockito.mockConstruction( diff --git a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java index 87cb57cdc..2a62c32f3 100644 --- a/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java +++ b/src/test/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateCacheTests.java @@ -16,6 +16,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; @@ -29,7 +30,6 @@ import static org.mockito.Mockito.when; import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING; import static org.opensearch.knn.index.KNNSettings.QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING; -import static org.opensearch.knn.quantization.enums.ScalarQuantizationType.ONE_BIT; public class QuantizationStateCacheTests extends KNNTestCase { @@ -49,10 +49,10 @@ public void terminateThreadPool() { @SneakyThrows public void testSingleThreadedAddAndRetrieve() { String fieldName = "singleThreadField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); @@ -86,10 +86,11 @@ public void testMultiThreadedAddAndRetrieve() { ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); String fieldName = "multiThreadField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); + String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); @@ -134,10 +135,10 @@ public void testMultiThreadedEvict() { ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); String fieldName = "multiThreadEvictField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); @@ -184,10 +185,10 @@ public void testConcurrentAddAndEvict() { ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); String fieldName = "concurrentAddEvictField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); @@ -243,10 +244,10 @@ public void testMultipleThreadedCacheClear() { ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); String fieldName = "multiThreadField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); @@ -291,10 +292,10 @@ public void testRebuild() { ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); String fieldName = "rebuildField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); String cacheSize = "10%"; TimeValue expiry = TimeValue.timeValueMinutes(30); @@ -338,10 +339,10 @@ public void testRebuildOnCacheSizeSettingsChange() { ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); String fieldName = "rebuildField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); Settings settings = Settings.builder().build(); ClusterSettings clusterSettings = new ClusterSettings( @@ -391,10 +392,10 @@ public void testRebuildOnTimeExpirySettingsChange() { ExecutorService executorService = Executors.newFixedThreadPool(threadCount); CountDownLatch latch = new CountDownLatch(threadCount); String fieldName = "rebuildField"; - QuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ONE_BIT), - new float[] { 1.2f, 2.3f, 3.4f } - ); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 1.2f, 2.3f, 3.4f }) + .build(); Settings settings = Settings.builder().build(); ClusterSettings clusterSettings = new ClusterSettings( @@ -445,8 +446,14 @@ public void testCacheEvictionDueToSize() throws IOException { arr[i] = i; arr[i] = i + 1; } - QuantizationState state = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), arr); - QuantizationState state2 = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ONE_BIT), arr2); + QuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(arr) + .build(); + QuantizationState state2 = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(arr2) + .build(); long cacheSize = 1; Settings settings = Settings.builder().build(); ClusterSettings clusterSettings = new ClusterSettings( diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java index fa25e8e80..889b82adf 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -18,7 +18,10 @@ public class QuantizationStateSerializerTests extends KNNTestCase { public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = new float[] { 0.1f, 0.2f, 0.3f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .build(); // Serialize byte[] serialized = state.toByteArray(); diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 4fd4f40a6..e09707cab 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -21,7 +21,10 @@ public void testOneBitScalarQuantizationStateSerialization() throws IOException ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .build(); // Serialize byte[] serializedState = state.toByteArray(); @@ -33,6 +36,75 @@ public void testOneBitScalarQuantizationStateSerialization() throws IOException float delta = 0.0001f; assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); + assertNull(deserializedState.getBelowThresholdMeans()); + assertNull(deserializedState.getAboveThresholdMeans()); + } + + // Test serialization and deserialization with optional fields + public void testOneBitScalarQuantizationState_WithOptionalFields() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + float[] belowThresholdMeans = { 0.5f, 1.5f, 2.5f }; + float[] aboveThresholdMeans = { 1.5f, 2.5f, 3.5f }; + + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .aboveThresholdMeans(aboveThresholdMeans) + .belowThresholdMeans(belowThresholdMeans) + .build(); + + // Serialize + byte[] serializedState = state.toByteArray(); + + // Deserialize + StreamInput in = StreamInput.wrap(serializedState); + OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); + + // Assertions + assertArrayEquals(mean, deserializedState.getMeanThresholds(), 0.0f); + assertArrayEquals(belowThresholdMeans, deserializedState.getBelowThresholdMeans(), 0.0f); + assertArrayEquals(aboveThresholdMeans, deserializedState.getAboveThresholdMeans(), 0.0f); + assertEquals(params, deserializedState.getQuantizationParams()); + } + + // Test handling of null arrays in RAM usage + public void testOneBitScalarQuantizationState_RamBytesUsedWithNulls() { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .build(); + + long actualRamBytesUsed = state.ramBytesUsed(); + long expectedRamBytesUsed = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class) + RamUsageEstimator + .shallowSizeOf(params) + RamUsageEstimator.sizeOf(mean); + + assertEquals(expectedRamBytesUsed, actualRamBytesUsed); + } + + // Test handling of all fields in RAM usage + public void testOneBitScalarQuantizationState_RamBytesUsedWithAllFields() { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + float[] belowThresholdMeans = { 0.5f, 1.5f, 2.5f }; + float[] aboveThresholdMeans = { 1.5f, 2.5f, 3.5f }; + + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .aboveThresholdMeans(aboveThresholdMeans) + .belowThresholdMeans(belowThresholdMeans) + .build(); + + long actualRamBytesUsed = state.ramBytesUsed(); + long expectedRamBytesUsed = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class) + RamUsageEstimator + .shallowSizeOf(params) + RamUsageEstimator.sizeOf(mean) + RamUsageEstimator.sizeOf(belowThresholdMeans) + RamUsageEstimator + .sizeOf(aboveThresholdMeans); + + assertEquals(expectedRamBytesUsed, actualRamBytesUsed); } public void testMultiBitScalarQuantizationStateSerialization() throws IOException { @@ -57,7 +129,10 @@ public void testSerializationWithDifferentVersions() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .build(); byte[] serializedState = state.toByteArray(); StreamInput in = StreamInput.wrap(serializedState); OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); @@ -71,7 +146,10 @@ public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(mean) + .build(); // 1. Manual Calculation of RAM Usage long manualEstimatedRamBytesUsed = 0L; @@ -84,6 +162,8 @@ public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException { // Mean array overhead (array header + size of elements) manualEstimatedRamBytesUsed += alignSize(16L + 4L * mean.length); + manualEstimatedRamBytesUsed += alignSize(4L); // belowThresholdMeans, even though it's null but it takes Object Header + manualEstimatedRamBytesUsed += alignSize(4L); // aboveThresholdMeans, even though it's null but it takes Object Header // 3. RAM Usage from RamUsageEstimator long expectedRamBytesUsed = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class) + RamUsageEstimator @@ -143,31 +223,46 @@ public void testOneBitScalarQuantizationStateGetDimensions_withDimensionNotMulti // Case 1: 5 dimensions (should align to 8) float[] thresholds1 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f }; - OneBitScalarQuantizationState state1 = new OneBitScalarQuantizationState(params, thresholds1); + OneBitScalarQuantizationState state1 = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(thresholds1) + .build(); int expectedDimensions1 = 8; // The next multiple of 8 assertEquals(expectedDimensions1, state1.getDimensions()); // Case 2: 7 dimensions (should align to 8) float[] thresholds2 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f }; - OneBitScalarQuantizationState state2 = new OneBitScalarQuantizationState(params, thresholds2); + OneBitScalarQuantizationState state2 = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(thresholds2) + .build(); int expectedDimensions2 = 8; // The next multiple of 8 assertEquals(expectedDimensions2, state2.getDimensions()); // Case 3: 8 dimensions (already aligned to 8) float[] thresholds3 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }; - OneBitScalarQuantizationState state3 = new OneBitScalarQuantizationState(params, thresholds3); + OneBitScalarQuantizationState state3 = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(thresholds3) + .build(); int expectedDimensions3 = 8; // Already aligned to 8 assertEquals(expectedDimensions3, state3.getDimensions()); // Case 4: 10 dimensions (should align to 16) float[] thresholds4 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f }; - OneBitScalarQuantizationState state4 = new OneBitScalarQuantizationState(params, thresholds4); + OneBitScalarQuantizationState state4 = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(thresholds4) + .build(); int expectedDimensions4 = 16; // The next multiple of 8 assertEquals(expectedDimensions4, state4.getDimensions()); // Case 5: 16 dimensions (already aligned to 16) float[] thresholds5 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f, 12.5f, 13.5f, 14.5f, 15.5f }; - OneBitScalarQuantizationState state5 = new OneBitScalarQuantizationState(params, thresholds5); + OneBitScalarQuantizationState state5 = OneBitScalarQuantizationState.builder() + .quantizationParams(params) + .meanThresholds(thresholds5) + .build(); int expectedDimensions5 = 16; // Already aligned to 16 assertEquals(expectedDimensions5, state5.getDimensions()); } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java index de815d8ad..423341ae6 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -223,5 +223,10 @@ public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) { public float[] getVectorAtThePosition(int position) { return vectors[position]; } + + @Override + public void resetVectorValues() { + // No-op + } } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java index a6b907ccb..25296b386 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -31,6 +31,11 @@ public void testTrain_withTrainingRequired() throws IOException { public float[] getVectorAtThePosition(int position) { return vectors[position]; } + + @Override + public void resetVectorValues() { + // No-op + } }; OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); QuantizationState state = quantizer.train(originalRequest); @@ -40,13 +45,42 @@ public float[] getVectorAtThePosition(int position) { assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); } + public void testTrain_withBelowAboveThresholdMeans() throws IOException { + float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; + TrainingRequest trainingRequest = new TrainingRequest<>(vectors.length) { + @Override + public float[] getVectorAtThePosition(int position) { + return vectors[position]; + } + + @Override + public void resetVectorValues() { + // No-op + } + }; + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationState state = quantizer.train(trainingRequest); + + assertTrue(state instanceof OneBitScalarQuantizationState); + OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) state; + + float[] expectedMeanThresholds = { 4.0f, 5.0f, 6.0f }; + assertArrayEquals(expectedMeanThresholds, oneBitState.getMeanThresholds(), 0.001f); + + // Validate below and above thresholds + float[] expectedBelowThresholdMeans = { 2.5f, 3.5f, 4.5f }; + float[] expectedAboveThresholdMeans = { 7.0f, 8.0f, 9.0f }; + assertArrayEquals(expectedBelowThresholdMeans, oneBitState.getBelowThresholdMeans(), 0.001f); + assertArrayEquals(expectedAboveThresholdMeans, oneBitState.getAboveThresholdMeans(), 0.001f); + } + public void testQuantize_withState() throws IOException { float[] vector = { 3.0f, 6.0f, 9.0f }; float[] thresholds = { 4.0f, 5.0f, 6.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - thresholds - ); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .build(); OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); BinaryQuantizationOutput output = new BinaryQuantizationOutput(1); @@ -60,10 +94,10 @@ public void testQuantize_withState() throws IOException { public void testQuantize_withNullVector() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - new float[] { 0.0f } - ); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(new float[] { 0.0f }) + .build(); BinaryQuantizationOutput output = new BinaryQuantizationOutput(1); expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state, output)); } @@ -110,14 +144,136 @@ public void testQuantize_withMismatchedDimensions() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); float[] vector = { 1.0f, 2.0f, 3.0f }; float[] thresholds = { 4.0f, 5.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - thresholds - ); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .build(); QuantizationOutput output = new BinaryQuantizationOutput(1); expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output)); } + public void testTrain_withRotationApplied() throws IOException { + float[][] vectors = { { 10.0f, 200.0f, 3000.0f }, { 4000.0f, 5000.0f, 6000.0f }, { 7000.0f, 8000.0f, 9000.0f } }; + + TrainingRequest trainingRequest = new TrainingRequest<>(vectors.length) { + @Override + public float[] getVectorAtThePosition(int position) { + return vectors[position]; + } + + @Override + public void resetVectorValues() { + // No-op + } + }; + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + OneBitScalarQuantizationState state = (OneBitScalarQuantizationState) quantizer.train(trainingRequest); + + assertNotNull(state); + assertNotNull(state.getRotationMatrix()); + assertTrue(state.getRotationMatrix().length > 0); + } + + public void testTrain_withoutRotationMatrix() throws IOException { + float[][] vectors = { { 1.0f, 1.0f, 1.0f }, { 1.1f, 1.1f, 1.1f }, { 0.9f, 0.9f, 0.9f } }; + + TrainingRequest trainingRequest = new TrainingRequest<>(vectors.length) { + @Override + public float[] getVectorAtThePosition(int position) { + return vectors[position]; + } + + @Override + public void resetVectorValues() { + // No-op + } + }; + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + OneBitScalarQuantizationState state = (OneBitScalarQuantizationState) quantizer.train(trainingRequest); + + assertNotNull(state); + assertNull(state.getRotationMatrix()); + } + + public void testQuantize_withRotationMatrix() { + float[] vector = { 3.0f, 6.0f, 9.0f }; + float[] thresholds = { 4.0f, 5.0f, 6.0f }; + + // Generate a rotation matrix + float[][] rotationMatrix = RandomGaussianRotation.generateRotationMatrix(3); + + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .rotationMatrix(rotationMatrix) + .build(); + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(1); + + quantizer.quantize(vector, state, output); + + assertNotNull(output); + assertNotNull(output.getQuantizedVector()); + } + + public void testQuantize_withDifferentRotationMatrices() { + float[] vector = { 3.0f, 6.0f, 9.0f }; + float[] thresholds = { 4.0f, 5.0f, 6.0f }; + + // Generate two different rotation matrices + float[][] rotationMatrix1 = RandomGaussianRotation.generateRotationMatrix(3); + float[][] rotationMatrix2 = RandomGaussianRotation.generateRotationMatrix(3); + + OneBitScalarQuantizationState state1 = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .rotationMatrix(rotationMatrix1) + .build(); + + OneBitScalarQuantizationState state2 = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .rotationMatrix(rotationMatrix2) + .build(); + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + BinaryQuantizationOutput output1 = new BinaryQuantizationOutput(1); + BinaryQuantizationOutput output2 = new BinaryQuantizationOutput(1); + + quantizer.quantize(vector, state1, output1); + quantizer.quantize(vector, state2, output2); + + assertNotNull(output1.getQuantizedVector()); + assertNotNull(output2.getQuantizedVector()); + assertFalse(output1.getQuantizedVector().equals(output2.getQuantizedVector())); + } + + public void testRotationConsistency() { + float[] vector = { 5.0f, 10.0f, 15.0f }; + float[] thresholds = { 6.0f, 11.0f, 16.0f }; + + // Generate a fixed rotation matrix + float[][] rotationMatrix = RandomGaussianRotation.generateRotationMatrix(3); + + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .rotationMatrix(rotationMatrix) + .build(); + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + BinaryQuantizationOutput output1 = new BinaryQuantizationOutput(1); + BinaryQuantizationOutput output2 = new BinaryQuantizationOutput(1); + + quantizer.quantize(vector, state, output1); + quantizer.quantize(vector, state, output2); + + assertArrayEquals(output1.getQuantizedVector(), output2.getQuantizedVector()); + } + public void testCalculateMean() throws IOException { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; @@ -127,12 +283,21 @@ public void testCalculateMean() throws IOException { public float[] getVectorAtThePosition(int position) { return vectors[position]; } + + @Override + public void resetVectorValues() { + // No-op + } }; Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); int[] sampledIndices = sampler.sample(vectors.length, 3); - float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices); - assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); + OneBitScalarQuantizationState oneBitScalarQuantizationState = QuantizerHelper.calculateQuantizationState( + samplingRequest, + sampledIndices, + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT) + ); + assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, oneBitScalarQuantizationState.getMeanThresholds(), 0.001f); } public void testCalculateMean_withNullVector() { @@ -144,20 +309,29 @@ public void testCalculateMean_withNullVector() { public float[] getVectorAtThePosition(int position) { return vectors[position]; } + + @Override + public void resetVectorValues() { + // No-op + } }; Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); int[] sampledIndices = sampler.sample(vectors.length, 3); - expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices)); + ScalarQuantizationParams quantizationParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + expectThrows( + IllegalArgumentException.class, + () -> QuantizerHelper.calculateQuantizationState(samplingRequest, sampledIndices, quantizationParams) + ); } public void testQuantize_withState_multiple_times() throws IOException { float[] vector = { 3.0f, 6.0f, 9.0f }; float[] thresholds = { 4.0f, 5.0f, 6.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - thresholds - ); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .build(); OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); BinaryQuantizationOutput output = new BinaryQuantizationOutput(1); @@ -174,7 +348,10 @@ public void testQuantize_withState_multiple_times() throws IOException { // Modify vector and thresholds for a second quantization call vector = new float[] { 7.0f, 8.0f, 9.0f }; thresholds = new float[] { 6.0f, 7.0f, 8.0f }; - state = new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), thresholds); + state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .build(); // Second quantization output.prepareQuantizedVector(vector.length); // Ensure it is prepared for the new vector @@ -191,10 +368,10 @@ public void testQuantize_withState_multiple_times() throws IOException { public void testQuantize_ReuseByteArray() throws IOException { float[] vector = { 3.0f, 6.0f, 9.0f }; float[] thresholds = { 4.0f, 5.0f, 6.0f }; - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( - new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), - thresholds - ); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .build(); OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); BinaryQuantizationOutput output = new BinaryQuantizationOutput(1); @@ -226,7 +403,10 @@ public void testQuantize_withMultipleVectors_inLoop() throws IOException { float[] thresholds = { 1.5f, 2.5f, 3.5f, 4.5f }; ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, thresholds); + OneBitScalarQuantizationState state = OneBitScalarQuantizationState.builder() + .quantizationParams(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)) + .meanThresholds(thresholds) + .build(); BinaryQuantizationOutput output = new BinaryQuantizationOutput(1); diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotationTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotationTests.java new file mode 100644 index 000000000..4b9e6918d --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/RandomGaussianRotationTests.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.KNNTestCase; + +import java.util.HashSet; +import java.util.Set; + +public class RandomGaussianRotationTests extends KNNTestCase { + + public void testGenerateRotationMatrix_Orthogonality() { + int dimensions = 5; + + // Generate the rotation matrix + float[][] rotationMatrix = RandomGaussianRotation.generateRotationMatrix(dimensions); + + // Validate dimensions + assertEquals(dimensions, rotationMatrix.length); + for (float[] row : rotationMatrix) { + assertEquals(dimensions, row.length); + } + + // Validate orthogonality: dot product of distinct rows should be close to zero + float delta = 0.0001f; + for (int i = 0; i < dimensions; i++) { + for (int j = i + 1; j < dimensions; j++) { + float dotProduct = 0f; + for (int k = 0; k < dimensions; k++) { + dotProduct += rotationMatrix[i][k] * rotationMatrix[j][k]; + } + assertEquals("Dot product of row " + i + " and row " + j + " is not zero", 0.0f, dotProduct, delta); + } + } + + // Validate normalization: length of each row vector should be close to 1 + for (int i = 0; i < dimensions; i++) { + float norm = 0f; + for (int j = 0; j < dimensions; j++) { + norm += rotationMatrix[i][j] * rotationMatrix[i][j]; + } + assertEquals("Row " + i + " is not normalized", 1.0f, (float) Math.sqrt(norm), delta); + } + } + + public void testApplyRotation() { + float[] vector = { 1.0f, 0.0f, 0.0f }; + int dimensions = vector.length; + + // Generate a rotation matrix + float[][] rotationMatrix = RandomGaussianRotation.generateRotationMatrix(dimensions); + + // Apply the rotation + float[] rotatedVector = RandomGaussianRotation.applyRotation(vector, rotationMatrix); + + // Validate dimensions + assertEquals(dimensions, rotatedVector.length); + + // Validate that the rotated vector is non-zero + float norm = 0f; + for (float value : rotatedVector) { + norm += value * value; + } + assertTrue("Rotated vector should not be zero", norm > 0); + + // Validate that the rotated vector lies in the same dimensional space + Set nonZeroIndices = new HashSet<>(); + for (int i = 0; i < dimensions; i++) { + if (Math.abs(rotatedVector[i]) > 0.0001f) { + nonZeroIndices.add(i); + } + } + assertFalse("Rotated vector contains invalid values", nonZeroIndices.isEmpty()); + } + + public void testOrthogonalityOfGeneratedMatrixWithLargerDimensions() { + int dimensions = 10; + + // Generate the rotation matrix + float[][] rotationMatrix = RandomGaussianRotation.generateRotationMatrix(dimensions); + + // Validate orthogonality: dot product of distinct rows should be close to zero + float delta = 0.0001f; + for (int i = 0; i < dimensions; i++) { + for (int j = i + 1; j < dimensions; j++) { + float dotProduct = 0f; + for (int k = 0; k < dimensions; k++) { + dotProduct += rotationMatrix[i][k] * rotationMatrix[j][k]; + } + assertEquals("Dot product of row " + i + " and row " + j + " is not zero", 0.0f, dotProduct, delta); + } + } + } + + public void testRotationMatrixCorrectness() { + float[] vector = { 3.0f, 4.0f }; + int dimensions = vector.length; + + // Generate the rotation matrix + float[][] rotationMatrix = RandomGaussianRotation.generateRotationMatrix(dimensions); + + // Apply the rotation + float[] rotatedVector = RandomGaussianRotation.applyRotation(vector, rotationMatrix); + + // Ensure rotated vector length matches original vector length + float originalNorm = 0f; + for (float value : vector) { + originalNorm += value * value; + } + + float rotatedNorm = 0f; + for (float value : rotatedVector) { + rotatedNorm += value * value; + } + + assertEquals( + "Rotated vector norm does not match original vector norm", + (float) Math.sqrt(originalNorm), + (float) Math.sqrt(rotatedNorm), + 0.0001f + ); + } +}