Skip to content

Commit

Permalink
Implementation of Random Rotation Matrix with Query Scaling matrix of …
Browse files Browse the repository at this point in the history
…above and below thresholds for 1 bit Binary Quantization

Signed-off-by: Vikasht34 <[email protected]>
  • Loading branch information
Vikasht34 committed Feb 12, 2025
1 parent 359a37b commit 72ce48f
Show file tree
Hide file tree
Showing 23 changed files with 949 additions and 143 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Bump Faiss commit from 1f42e81 to 0cbc2a8 to accelerate hamming distance calculation using _mm512_popcnt_epi64 intrinsic and also add avx512-fp16 instructions to boost performance [#2381](https://github.com/opensearch-project/k-NN/pull/2381)
* Enabled indices.breaker.total.use_real_memory setting via build.gradle for integTest Cluster to catch heap CB in local ITs and github CI actions [#2395](https://github.com/opensearch-project/k-NN/pull/2395/)
* Enabled idempotency of local builds when using `./gradlew clean` and nest `jni/release` directory under `jni/build` for easier cleanup [#2516](https://github.com/opensearch-project/k-NN/pull/2516)
### Refactoring
### Refactoring
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -595,4 +595,4 @@ task updateVersion {
ant.replaceregexp(file:".github/workflows/backwards_compatibility_tests_workflow.yml", match: oldBWCVersion, replace: oldBWCVersion + '", "' + opensearch_version.tokenize('-')[0], flags:'g', byline:true)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,7 @@ private QuantizationState train(
QuantizationState quantizationState = null;
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationState = quantizationService.train(quantizationParams, knnVectorValuesSupplier, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@
import java.io.IOException;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.function.Supplier;

/**
* KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class.
* It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID.
*/
@Log4j2
final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {

private final KNNVectorValues<T> knnVectorValues;
private final Supplier<KNNVectorValues<T>> knnVectorValuesSupplier;
private KNNVectorValues<T> knnVectorValues;
private int lastIndex;

/**
* Constructs a new QuantizationFloatVectorTrainingRequest.
*
* @param knnVectorValues the KNNVectorValues instance containing the vectors.
* @param knnVectorValuesSupplier supplier that returns a fresh KNNVectorValues instance containing the vectors on each call.
*/
KNNVectorQuantizationTrainingRequest(KNNVectorValues<T> knnVectorValues, long liveDocs) {
KNNVectorQuantizationTrainingRequest(Supplier<KNNVectorValues<T>> knnVectorValuesSupplier, long liveDocs) {
super((int) liveDocs);
this.knnVectorValues = knnVectorValues;
this.knnVectorValuesSupplier = knnVectorValuesSupplier;
resetVectorValues(); // Initialize the first instance
this.lastIndex = 0;
}

Expand All @@ -52,4 +54,13 @@ public T getVectorAtThePosition(int position) throws IOException {
// Return the vector
return knnVectorValues.getVector();
}

/**
* Resets the KNNVectorValues to enable a fresh iteration by calling the supplier again.
*/
@Override
public void resetVectorValues() {
this.knnVectorValues = knnVectorValuesSupplier.get();
this.lastIndex = 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;
import java.util.function.Supplier;

import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;

Expand Down Expand Up @@ -53,19 +54,22 @@ public static <T, R> QuantizationService<T, R> getInstance() {
* {@link QuantizationState}. The quantizer is determined based on the given {@link QuantizationParams}.
*
* @param quantizationParams The {@link QuantizationParams} containing the parameters for quantization.
* @param knnVectorValues The {@link KNNVectorValues} representing the vector data to be used for training.
* @param knnVectorValuesSupplier Supplier of {@link KNNVectorValues} representing the vector data to be used for training.
* @return The {@link QuantizationState} containing the state of the trained quantizer.
* @throws IOException If an I/O error occurs during the training process.
*/
public QuantizationState train(
final QuantizationParams quantizationParams,
final KNNVectorValues<T> knnVectorValues,
final Supplier<KNNVectorValues<T>> knnVectorValuesSupplier,
final long liveDocs
) throws IOException {
Quantizer<T, R> quantizer = QuantizerFactory.getQuantizer(quantizationParams);

// Create the training request from the vector values
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs);
// Create the training request using the supplier
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(
knnVectorValuesSupplier,
liveDocs
);

// Train the quantizer and return the quantization state
return quantizer.train(trainingRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

package org.opensearch.knn.quantization.models.quantizationState;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.*;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;

import java.io.IOException;
Expand All @@ -21,10 +20,12 @@
* including the mean values used for quantization.
*/
@Getter
@NoArgsConstructor // No-argument constructor for deserialization
@Builder
@AllArgsConstructor
@NoArgsConstructor(force = true)
public final class OneBitScalarQuantizationState implements QuantizationState {
private ScalarQuantizationParams quantizationParams;
@NonNull
private final ScalarQuantizationParams quantizationParams;
/**
* Mean thresholds used in the quantization process.
* Each threshold value corresponds to a dimension of the vector being quantized.
Expand All @@ -33,7 +34,27 @@ public final class OneBitScalarQuantizationState implements QuantizationState {
* If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0],
* The quantized vector will be [0, 1, 1].
*/
private float[] meanThresholds;
@NonNull
private final float[] meanThresholds;

/**
* Represents the mean of all values below the threshold for each dimension.
*/
@Builder.Default
private float[] belowThresholdMeans = null;

/**
* Represents the mean of all values above the threshold for each dimension.
*/
@Builder.Default
private float[] aboveThresholdMeans = null;
@Builder.Default
private double averageL2L1Ratio = 0.0;
/**
* Rotation matrix used when L2/L1 ratio > 0.6
*/
@Builder.Default
private float[][] rotationMatrix = null;

@Override
public ScalarQuantizationParams getQuantizationParams() {
Expand All @@ -48,9 +69,23 @@ public ScalarQuantizationParams getQuantizationParams() {
*/
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(Version.CURRENT.id); // Write the version
out.writeVInt(Version.CURRENT.id); // Write the versionF
quantizationParams.writeTo(out);
out.writeFloatArray(meanThresholds);
out.writeOptionalArray(belowThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(belowThresholdMeans) } : null);
// Serialize aboveThresholdMeans using writeOptionalArray
out.writeOptionalArray(aboveThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(aboveThresholdMeans) } : null);
out.writeOptionalDouble(averageL2L1Ratio);
// Write rotation matrix
if (rotationMatrix != null) {
out.writeBoolean(true);
out.writeVInt(rotationMatrix.length);
for (float[] row : rotationMatrix) {
out.writeFloatArray(row);
}
} else {
out.writeBoolean(false);
}
}

/**
Expand All @@ -63,6 +98,23 @@ public OneBitScalarQuantizationState(StreamInput in) throws IOException {
int version = in.readVInt(); // Read the version
this.quantizationParams = new ScalarQuantizationParams(in, version);
this.meanThresholds = in.readFloatArray();
if (Version.fromId(version).onOrAfter(Version.V_3_0_0)) {
// Deserialize belowThresholdMeans using readOptionalArray
FloatArrayWrapper[] wrappedBelowThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new);
this.belowThresholdMeans = wrappedBelowThresholdMeans != null ? wrappedBelowThresholdMeans[0].getArray() : null;
// Deserialize aboveThresholdMeans using readOptionalArray
FloatArrayWrapper[] wrappedAboveThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new);
this.aboveThresholdMeans = wrappedAboveThresholdMeans != null ? wrappedAboveThresholdMeans[0].getArray() : null;
this.averageL2L1Ratio = in.readOptionalDouble();
// Read rotation matrix
if (in.readBoolean()) {
int dimensions = in.readVInt();
this.rotationMatrix = new float[dimensions][];
for (int i = 0; i < dimensions; i++) {
this.rotationMatrix[i] = in.readFloatArray();
}
}
}
}

/**
Expand Down Expand Up @@ -139,6 +191,41 @@ public long ramBytesUsed() {
long size = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class);
size += RamUsageEstimator.shallowSizeOf(quantizationParams);
size += RamUsageEstimator.sizeOf(meanThresholds);
if (belowThresholdMeans != null) {
size += RamUsageEstimator.sizeOf(belowThresholdMeans);
}
if (aboveThresholdMeans != null) {
size += RamUsageEstimator.sizeOf(aboveThresholdMeans);
}
if (rotationMatrix != null) {
size += RamUsageEstimator.shallowSizeOf(rotationMatrix);
// Add size of each row array
for (float[] row : rotationMatrix) {
size += RamUsageEstimator.sizeOf(row);
}
}
return size;
}

private class FloatArrayWrapper implements Writeable {
private final float[] array;

public FloatArrayWrapper(float[] array) {
this.array = array;
}

// Constructor that matches Writeable.Reader<T>
public FloatArrayWrapper(StreamInput in) throws IOException {
this.array = in.readFloatArray();
}

public float[] getArray() {
return array;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeFloatArray(array);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ public abstract class TrainingRequest<T> {
* @return the vector corresponding to the specified document ID.
*/
public abstract T getVectorAtThePosition(int position) throws IOException;

public abstract void resetVectorValues();
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) {
@Override
public QuantizationState train(final TrainingRequest<float[]> trainingRequest) throws IOException {
int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds);
return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds);
return QuantizerHelper.calculateQuantizationState(
trainingRequest,
sampledDocIds,
new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)
);
}

/**
Expand All @@ -73,7 +76,7 @@ public QuantizationState train(final TrainingRequest<float[]> trainingRequest) t
* @param output the QuantizationOutput object to store the quantized representation of the vector.
*/
@Override
public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput<byte[]> output) {
public void quantize(float[] vector, final QuantizationState state, final QuantizationOutput<byte[]> output) {
if (vector == null) {
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
Expand All @@ -84,6 +87,10 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds.length != vectorLength) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
float[][] rotationMatrix = binaryState.getRotationMatrix();
if (rotationMatrix != null) {
vector = RandomGaussianRotation.applyRotation(vector, rotationMatrix);
}
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector());
}
Expand Down
Loading

0 comments on commit 72ce48f

Please sign in to comment.