QuantizationFramework Changes
Vikasht34 committed Jul 8, 2024
1 parent e593397 commit 3fb3686
Showing 33 changed files with 1,165 additions and 10 deletions.
@@ -21,6 +21,11 @@
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.quantization.QuantizationManager;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

import java.io.IOException;
import java.util.ArrayList;
@@ -56,24 +61,28 @@ private static void createNativeIndex(
}

private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues<float[]> kNNVectorValues) throws IOException {
List<float[]> vectorList = new ArrayList<>();
List<byte[]> vectorList = new ArrayList<>();
List<Integer> docIdList = new ArrayList<>();
long vectorAddress = 0;
int dimension = 0;
long totalLiveDocs = kNNVectorValues.totalLiveDocs();
long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();
long vectorsPerTransfer = Integer.MIN_VALUE;

QuantizationParams params = getQuantizationParams(); // Implement this method to get appropriate params
Quantizer<float[], byte[]> quantizer = (Quantizer<float[], byte[]>) QuantizationManager.getInstance().getQuantizer(params);

KNNVectorValuesIterator iterator = kNNVectorValues.getVectorValuesIterator();

for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
float[] temp = kNNVectorValues.getVector();
// The copy of temp is required because for floatVectorValues the mapped floats are always read into
// the same heap memory location. Ref: OffHeapFloatVectorValues.vectorValue.
float[] vector = Arrays.copyOf(temp, temp.length);
byte[] quantizedVector = quantizer.quantize(vector).getQuantizedVector();
dimension = vector.length;
if (vectorsPerTransfer == Integer.MIN_VALUE) {
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
vectorsPerTransfer = (dimension * Byte.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
// This condition occurs when vectorsStreamingMemoryLimit is higher than the total number of values to transfer.
// Handling it here saves one extra trip to the JNI layer.
if (vectorsPerTransfer == 0) {
@@ -82,19 +91,19 @@ private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues<float[
}

if (vectorList.size() == vectorsPerTransfer) {
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * dimension);
// We should probably come up with a better way to reuse the vectorList memory we have already
// allocated. Re-creating the list leaves a lot of unused memory that will be garbage collected
// later on, but it still creates pressure on the JVM. We should revisit this.
vectorList = new ArrayList<>();
}

vectorList.add(vector);
vectorList.add(quantizedVector);
docIdList.add(doc);
}

if (vectorList.isEmpty() == false) {
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * dimension);
}
// SerializationMode.COLLECTION_OF_FLOATS is not actually used. I just added it to ensure the code
// works.
@@ -105,4 +114,9 @@ private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues<float[
SerializationMode.COLLECTION_OF_FLOATS
);
}

private static QuantizationParams getQuantizationParams() {
// Implement this method to return appropriate quantization parameters based on your use case
return new SQParams(SQTypes.ONE_BIT); // Example, modify as needed
}
}
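
For reference, here is a minimal sketch of the quantize-then-transfer flow the hunk above introduces. The helper name and its parameters are illustrative only; the quantizer resolution, the quantize call, and JNICommons.storeByteVectorData are the APIs used in this diff, and passing 0 as the address mirrors the initial vectorAddress value above.

private static long quantizeAndStore(List<float[]> floatVectors, long totalLiveDocs, int dimension) {
    // Resolve the same hard-coded one-bit quantizer this commit uses.
    QuantizationParams params = new SQParams(SQTypes.ONE_BIT);
    Quantizer<float[], byte[]> quantizer =
        (Quantizer<float[], byte[]>) QuantizationManager.getInstance().getQuantizer(params);

    List<byte[]> quantized = new ArrayList<>();
    for (float[] vector : floatVectors) {
        quantized.add(quantizer.quantize(vector).getQuantizedVector());
    }
    // 0 mirrors the initial vectorAddress in the method above, i.e. no off-heap block exists yet.
    return JNICommons.storeByteVectorData(0, quantized.toArray(new byte[][] {}), totalLiveDocs * dimension);
}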
@@ -133,7 +133,6 @@ public void close() throws IOException {
public long ramBytesUsed() {
return 0;
}

private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private final FieldInfo fieldInfo;
private final List<T> vectors;
20 changes: 17 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -43,6 +43,11 @@
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.quantization.QuantizationManager;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

import java.io.IOException;
import java.nio.file.Path;
@@ -154,6 +159,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return convertSearchResponseToScorer(docIdsToScoreMap);
}

private QuantizationParams getQuantizationParams() {
// Implement this method to return appropriate quantization parameters based on your use case
return new SQParams(SQTypes.ONE_BIT); // Example, modify as needed
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
return new FixedBitSet(0);
@@ -211,6 +221,9 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

QuantizationParams params = getQuantizationParams(); // Implement this method to get appropriate params
Quantizer<float[], byte[]> quantizer = (Quantizer<float[], byte[]>) QuantizationManager.getInstance().getQuantizer(params);

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());

if (fieldInfo == null) {
@@ -272,7 +285,7 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
spaceType,
knnEngine,
knnQuery.getIndexName(),
FieldInfoExtractor.getIndexDescription(fieldInfo)
"B" + FieldInfoExtractor.getIndexDescription(fieldInfo)
),
knnQuery.getIndexName(),
modelId
@@ -295,10 +308,11 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
throw new RuntimeException("Index has already been closed");
}
int[] parentIds = getParentIdsArray(context);
byte[] quantizedVector = quantizer.quantize(knnQuery.getQueryVector()).getQuantizedVector();
if (knnQuery.getK() > 0) {
results = JNIService.queryIndex(
results = JNIService.queryBinaryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
quantizedVector,
knnQuery.getK(),
knnEngine,
filterIds,
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
@@ -52,6 +52,11 @@ public static void createIndex(
}

if (KNNEngine.FAISS == knnEngine) {
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) {
String indexDesc = (String) parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER);
parameters.put(KNNConstants.INDEX_DESCRIPTION_PARAMETER, "B" + indexDesc);
}
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null
&& parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) {
FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters);
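
The new block above prefixes the FAISS index description with "B" so that the existing binary-index check right below it matches and createBinaryIndex is called. A hedged example, assuming FAISS_BINARY_INDEX_PREFIX is "B" (which the prefixing implies) and using a made-up description string:

String indexDesc = "HNSW16,Flat";      // hypothetical description value, not taken from this commit
String binaryDesc = "B" + indexDesc;   // "BHNSW16,Flat" starts with FAISS_BINARY_INDEX_PREFIX,
                                       // so FaissService.createBinaryIndex is chosen over createIndex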
@@ -0,0 +1,45 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization;

import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.opensearch.knn.quantization.sampler.Sampler;
import org.opensearch.knn.quantization.sampler.SamplingFactory;

public class QuantizationManager {
private static QuantizationManager instance;

private QuantizationManager() {}

public static QuantizationManager getInstance() {
if (instance == null) {
instance = new QuantizationManager();
}
return instance;
}
public <T, R> QuantizationState train(TrainingRequest<T> trainingRequest) {
Quantizer<T, R> quantizer = (Quantizer<T, R>) getQuantizer(trainingRequest.getParams());
int sampleSize = quantizer.getSamplingSize();
Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
TrainingRequest<T> sampledRequest = new SamplingTrainingRequest<>(trainingRequest, sampler, sampleSize);
return quantizer.train(sampledRequest);
}
public Quantizer<?, ?> getQuantizer(QuantizationParams params) {
return QuantizerFactory.getQuantizer(params);
}
}
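
A small usage sketch of the new manager, using only calls this commit already makes elsewhere; the vector values are made up.

// Resolve a one-bit scalar quantizer and quantize a single vector.
QuantizationParams params = new SQParams(SQTypes.ONE_BIT);
Quantizer<float[], byte[]> quantizer =
    (Quantizer<float[], byte[]>) QuantizationManager.getInstance().getQuantizer(params);
byte[] bits = quantizer.quantize(new float[] { 0.4f, -1.2f, 3.0f }).getQuantizedVector();

Note that getInstance() is a lazy, unsynchronized singleton, so two threads racing on the first call could each create an instance.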

@@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

public enum QuantizationType {
SPACE_QUANTIZATION,
VALUE_QUANTIZATION,
}
21 changes: 21 additions & 0 deletions src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java
@@ -0,0 +1,21 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

public enum SQTypes {
FP16,
INT8,
INT6,
INT4,
ONE_BIT,
TWO_BIT
}
@@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

public enum ValueQuantizationType {
SQ
}

@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.Quantizer;

public class QuantizerFactory {
static {
// Register all quantizers here
QuantizerRegistry.register(SQParams.class, SQTypes.ONE_BIT.name(), OneBitScalarQuantizer::new);
}

public static Quantizer<?, ?> getQuantizer(QuantizationParams params) {
if (params instanceof SQParams) {
SQParams sqParams = (SQParams) params;
return QuantizerRegistry.getQuantizer(params, sqParams.getSqType().name());
}
// Add more cases for other quantization parameters here
throw new IllegalArgumentException("Unsupported quantization parameters: " + params.getClass().getName());
}
}
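
Resolution goes from the params class to the SQ type name to the registered supplier, so the call below (reusing the parameters this commit already constructs) yields a OneBitScalarQuantizer:

Quantizer<?, ?> quantizer = QuantizerFactory.getQuantizer(new SQParams(SQTypes.ONE_BIT));
// Returned by the OneBitScalarQuantizer::new supplier registered in the static block above.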
@@ -0,0 +1,39 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

public class QuantizerRegistry {
private static final Map<Class<? extends QuantizationParams>, Map<String, Supplier<? extends Quantizer<?, ?>>>> registry = new HashMap<>();

public static <T extends QuantizationParams> void register(Class<T> paramClass, String typeIdentifier, Supplier<? extends Quantizer<?, ?>> quantizerSupplier) {
registry.computeIfAbsent(paramClass, k -> new HashMap<>()).put(typeIdentifier, quantizerSupplier);
}

public static Quantizer<?, ?> getQuantizer(QuantizationParams params, String typeIdentifier) {
Map<String, Supplier<? extends Quantizer<?, ?>>> typeMap = registry.get(params.getClass());
if (typeMap == null) {
throw new IllegalArgumentException("No quantizer registered for parameters: " + params.getClass().getName());
}
Supplier<? extends Quantizer<?, ?>> supplier = typeMap.get(typeIdentifier);
if (supplier == null) {
throw new IllegalArgumentException("No quantizer registered for type identifier: " + typeIdentifier);
}
return supplier.get();
}
}
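
Supporting another scalar-quantization type would just mean registering one more supplier. A sketch under the assumption that a TwoBitScalarQuantizer class exists; it is not part of this commit:

// Hypothetical: TwoBitScalarQuantizer is shown only to illustrate the registry contract.
QuantizerRegistry.register(SQParams.class, SQTypes.TWO_BIT.name(), TwoBitScalarQuantizer::new);
Quantizer<?, ?> twoBit = QuantizerRegistry.getQuantizer(new SQParams(SQTypes.TWO_BIT), SQTypes.TWO_BIT.name());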
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationOutput;

public class OneBitScalarQuantizationOutput extends QuantizationOutput<byte[]> {

public OneBitScalarQuantizationOutput(byte[] quantizedVector) {
super(quantizedVector);
}
}
@@ -0,0 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationOutput;

public abstract class QuantizationOutput<T> {
private final T quantizedVector;

public QuantizationOutput(T quantizedVector) {
this.quantizedVector = quantizedVector;
}

public T getQuantizedVector() {
return quantizedVector;
}
}