diff --git a/CHANGELOG.md b/CHANGELOG.md
index 76eeb1e447..dcb714a58d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,4 +27,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
* Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913)
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
-* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
\ No newline at end of file
+* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
+* Quantization Framework For Disk Optimized Vector Search and Implementation of Binary 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889)
diff --git a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java
new file mode 100644
index 0000000000..4a2a17a574
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+/**
+ * The QuantizationType enum represents the different types of quantization
+ * that can be applied in the KNN.
+ *
+ *
+ * - SPACE_QUANTIZATION: This type of quantization focuses on the space
+ * or the representation of the data vectors. It is commonly used for techniques
+ * that reduce the dimensionality or discretize the data space.
+ * - VALUE_QUANTIZATION: This type of quantization focuses on the values
+ * within the data vectors. It involves mapping continuous values into discrete
+ * values, which can be useful for compressing data or reducing the precision
+ * of the representation.
+ *
+ */
+public enum QuantizationType {
+ /**
+ * Represents space quantization, typically involving dimensionality reduction
+ * or space partitioning techniques.
+ */
+ SPACE,
+
+ /**
+ * Represents value quantization, typically involving the conversion of continuous
+ * values into discrete ones.
+ */
+ VALUE,
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java
new file mode 100644
index 0000000000..88290c6a86
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+/**
+ * The SQTypes enum defines the various scalar quantization types that can be used
+ * in the KNN for vector quantization.
+ * Each type corresponds to a different bit-width representation of the quantized values.
+ */
+public enum ScalarQuantizationType {
+ /**
+ * ONE_BIT quantization uses a single bit per coordinate.
+ */
+ ONE_BIT,
+
+ /**
+ * TWO_BIT quantization uses two bits per coordinate.
+ */
+ TWO_BIT,
+
+ /**
+ * FOUR_BIT quantization uses four bits per coordinate.
+ */
+ FOUR_BIT,
+
+ /**
+ * UNSUPPORTED_TYPE is used to denote quantization types that are not supported.
+ * This can be used as a placeholder or default value.
+ */
+ UNSUPPORTED_TYPE
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java
new file mode 100644
index 0000000000..43db46cf6e
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java
@@ -0,0 +1,18 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+/**
+ * The ValueQuantizationType enum defines the types of value quantization techniques
+ * that can be applied in the KNN.
+ */
+public enum ValueQuantizationType {
+ /**
+ * SQ (Scalar Quantization) represents a method where each coordinate of the vector is quantized
+ * independently.
+ */
+ SCALAR
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java
new file mode 100644
index 0000000000..985efd4cd1
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * The QuantizerFactory class is responsible for creating instances of {@link Quantizer}
+ * based on the provided {@link QuantizationParams}. It uses a registry to look up the
+ * appropriate quantizer implementation for the given quantization parameters.
+ */
+public final class QuantizerFactory {
+ private static final AtomicBoolean isRegistered = new AtomicBoolean(false);
+
+ // Private constructor to prevent instantiation
+ private QuantizerFactory() {}
+
+ /**
+ * Ensures that default quantizers are registered.
+ */
+ private static void ensureRegistered() {
+ if (!isRegistered.get()) {
+ synchronized (QuantizerFactory.class) {
+ if (!isRegistered.get()) {
+ QuantizerRegistrar.registerDefaultQuantizers();
+ isRegistered.set(true);
+ }
+ }
+ }
+ }
+
+ /**
+ * Retrieves a quantizer instance based on the provided quantization parameters.
+ *
+ * @param params the quantization parameters used to determine the appropriate quantizer
+ * @param the type of quantization parameters, extending {@link QuantizationParams}
+ * @param the type of the quantized output
+ * @return an instance of {@link Quantizer} corresponding to the provided parameters
+ */
+ public static Quantizer
getQuantizer(final P params) {
+ if (params == null) {
+ throw new IllegalArgumentException("Quantization parameters must not be null.");
+ }
+ // Lazy Registration instead of static block as class level;
+ ensureRegistered();
+ return QuantizerRegistry.getQuantizer(params);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java
new file mode 100644
index 0000000000..c8a2eb2bf6
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+
+/**
+ * The QuantizerRegistrar class is responsible for registering default quantizers.
+ * This class ensures that the registration happens only once in a thread-safe manner.
+ */
+final class QuantizerRegistrar {
+
+ // Private constructor to prevent instantiation
+ private QuantizerRegistrar() {}
+
+ /**
+ * Registers default quantizers if not already registered.
+ *
+ * This method is synchronized to ensure that registration occurs only once,
+ * even in a multi-threaded environment.
+ *
+ */
+ public static synchronized void registerDefaultQuantizers() {
+ // Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT
+ QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new);
+ // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2
+ QuantizerRegistry.register(
+ SQParams.class,
+ QuantizationType.VALUE,
+ ScalarQuantizationType.TWO_BIT,
+ () -> new MultiBitScalarQuantizer(2)
+ );
+ // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4
+ QuantizerRegistry.register(
+ SQParams.class,
+ QuantizationType.VALUE,
+ ScalarQuantizationType.FOUR_BIT,
+ () -> new MultiBitScalarQuantizer(4)
+ );
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java
new file mode 100644
index 0000000000..1243d79eff
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Supplier;
+
+/**
+ * The QuantizerRegistry class is responsible for managing the registration and retrieval
+ * of quantizer instances. Quantizers are registered with specific quantization parameters
+ * and type identifiers, allowing for efficient lookup and instantiation.
+ */
+final class QuantizerRegistry {
+
+ // Private constructor to prevent instantiation
+ private QuantizerRegistry() {}
+
+ // ConcurrentHashMap for thread-safe access
+ private static final Map>> registry = new ConcurrentHashMap<>();
+
+ /**
+ * Registers a quantizer with the registry.
+ *
+ * @param paramClass the class of the quantization parameters
+ * @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION)
+ * @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT)
+ * @param quantizerSupplier a supplier that provides instances of the quantizer
+ * @param the type of quantization parameters
+ */
+ public static
void register(
+ final Class
paramClass,
+ final QuantizationType quantizationType,
+ final ScalarQuantizationType sqType,
+ final Supplier extends Quantizer, ?>> quantizerSupplier
+ ) {
+ String identifier = createIdentifier(quantizationType, sqType);
+ // Ensure that the quantizer for this identifier is registered only once
+ registry.computeIfAbsent(identifier, key -> quantizerSupplier);
+ }
+
+ /**
+ * Retrieves a quantizer instance based on the provided quantization parameters.
+ *
+ * @param params the quantization parameters used to determine the appropriate quantizer
+ * @param
the type of quantization parameters
+ * @param the type of the quantized output
+ * @return an instance of {@link Quantizer} corresponding to the provided parameters
+ * @throws IllegalArgumentException if no quantizer is registered for the given parameters
+ */
+ public static Quantizer
getQuantizer(final P params) {
+ String identifier = params.getTypeIdentifier();
+ Supplier extends Quantizer, ?>> supplier = registry.get(identifier);
+ if (supplier == null) {
+ throw new IllegalArgumentException(
+ "No quantizer registered for type identifier: " + identifier + ". Available quantizers: " + registry.keySet()
+ );
+ }
+ @SuppressWarnings("unchecked")
+ Quantizer
quantizer = (Quantizer
) supplier.get();
+ return quantizer;
+ }
+
+ /**
+ * Creates a unique identifier for the quantizer based on the quantization type and specific quantization subtype.
+ *
+ * @param quantizationType the quantization type
+ * @param sqType the specific quantization subtype
+ * @return a string identifier
+ */
+ private static String createIdentifier(final QuantizationType quantizationType, final ScalarQuantizationType sqType) {
+ return quantizationType.name() + "_" + sqType.name();
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
new file mode 100644
index 0000000000..18077182fc
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationOutput;
+
+/**
+ * The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
+ * It implements the QuantizationOutput interface to handle byte arrays specifically.
+ */
+public class BinaryQuantizationOutput implements QuantizationOutput {
+ private final byte[] quantizedVector;
+
+ /**
+ * Constructs a BinaryQuantizationOutput instance with the specified quantized vector.
+ *
+ * @param quantizedVector the quantized vector represented as a byte array.
+ */
+ public BinaryQuantizationOutput(final byte[] quantizedVector) {
+ if (quantizedVector == null) {
+ throw new IllegalArgumentException("Quantized vector cannot be null");
+ }
+ this.quantizedVector = quantizedVector;
+ }
+
+ @Override
+ public byte[] getQuantizedVector() {
+ return quantizedVector;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
new file mode 100644
index 0000000000..c5c5fd21f6
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationOutput;
+
+/**
+ * The QuantizationOutput interface defines the contract for quantization output data.
+ *
+ * @param The type of the quantized data.
+ */
+public interface QuantizationOutput {
+ /**
+ * Returns the quantized vector.
+ *
+ * @return the quantized data.
+ */
+ T getQuantizedVector();
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
new file mode 100644
index 0000000000..2c982a3064
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationParams;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+
+import java.io.Serializable;
+
+/**
+ * Interface for quantization parameters.
+ * This interface defines the basic contract for all quantization parameter types.
+ * It provides methods to retrieve the quantization type and a unique type identifier.
+ * Implementations of this interface are expected to provide specific configurations
+ * for various quantization strategies.
+ */
+public interface QuantizationParams extends Serializable {
+
+ /**
+ * Gets the quantization type associated with the parameters.
+ * The quantization type defines the overall strategy or method used
+ * for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION.
+ *
+ * @return the {@link QuantizationType} indicating the quantization method.
+ */
+ QuantizationType getQuantizationType();
+
+ /**
+ * Provides a unique identifier for the quantization parameters.
+ * This identifier is typically a combination of the quantization type
+ * and additional specifics, and it serves to distinguish between different
+ * configurations or modes of quantization.
+ *
+ * @return a string representing the unique type identifier.
+ */
+ String getTypeIdentifier();
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java
new file mode 100644
index 0000000000..0b6bbc9885
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationParams;
+
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+
+import java.util.Objects;
+
+/**
+ * The SQParams class represents the parameters specific to scalar quantization (SQ).
+ * This class implements the QuantizationParams interface and includes the type of scalar quantization.
+ */
+public class SQParams implements QuantizationParams {
+ private final ScalarQuantizationType sqType;
+
+ /**
+ * Constructs an SQParams instance with the specified scalar quantization type.
+ *
+ * @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT).
+ */
+ public SQParams(final ScalarQuantizationType sqType) {
+ this.sqType = sqType;
+ }
+
+ /**
+ * Returns the quantization type associated with these parameters.
+ *
+ * @return The quantization type, always VALUE_QUANTIZATION for SQParams.
+ */
+ @Override
+ public QuantizationType getQuantizationType() {
+ return QuantizationType.VALUE;
+ }
+
+ /**
+ * Returns the scalar quantization type.
+ *
+ * @return The specific scalar quantization type.
+ */
+ public ScalarQuantizationType getSqType() {
+ return sqType;
+ }
+
+ /**
+ * Provides a unique type identifier for the SQParams, combining the quantization type and SQ type.
+ * This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
+ *
+ * @return A string representing the unique type identifier.
+ */
+ @Override
+ public String getTypeIdentifier() {
+ return getQuantizationType().name() + "_" + sqType.name();
+ }
+
+ /**
+ * Compares this object to the specified object. The result is true if and only if the argument is not null and is
+ * an SQParams object that represents the same scalar quantization type.
+ *
+ * @param o The object to compare this SQParams against.
+ * @return true if the given object represents an SQParams equivalent to this instance, false otherwise.
+ */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ SQParams sqParams = (SQParams) o;
+ return sqType == sqParams.sqType;
+ }
+
+ /**
+ * Returns a hash code value for this SQParams instance.
+ *
+ * @return A hash code value for this SQParams instance.
+ */
+ @Override
+ public int hashCode() {
+ return Objects.hash(sqType);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
new file mode 100644
index 0000000000..acc8c2f009
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+
+import java.io.IOException;
+
+/**
+ * DefaultQuantizationState is used as a fallback state when no training is required or if training fails.
+ * It can be utilized by any quantizer to represent a default state.
+ */
+public class DefaultQuantizationState implements QuantizationState {
+
+ private final QuantizationParams params;
+
+ /**
+ * Constructs a DefaultQuantizationState with the given quantization parameters.
+ *
+ * @param params the quantization parameters.
+ */
+ public DefaultQuantizationState(final QuantizationParams params) {
+ this.params = params;
+ }
+
+ /**
+ * Returns the quantization parameters associated with this state.
+ *
+ * @return the quantization parameters.
+ */
+ @Override
+ public QuantizationParams getQuantizationParams() {
+ return params;
+ }
+
+ /**
+ * Serializes the quantization state to a byte array.
+ *
+ * @return a byte array representing the serialized state.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this, null);
+ }
+
+ /**
+ * Deserializes a DefaultQuantizationState from a byte array.
+ *
+ * @param bytes the byte array containing the serialized state.
+ * @return the deserialized DefaultQuantizationState.
+ * @throws IOException if an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException if the class of the serialized object cannot be found.
+ */
+ public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
+ return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(
+ bytes,
+ (parentParams, specificData) -> new DefaultQuantizationState((SQParams) parentParams)
+ );
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
new file mode 100644
index 0000000000..58834dd2c9
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+
+import java.io.IOException;
+
+/**
+ * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization,
+ * including the thresholds used for quantization.
+ */
+public final class MultiBitScalarQuantizationState implements QuantizationState {
+ private final SQParams quantizationParams;
+ private final float[][] thresholds;
+
+ /**
+ * Constructs a MultiBitScalarQuantizationState with the given quantization parameters and thresholds.
+ *
+ * @param quantizationParams the scalar quantization parameters.
+ * @param thresholds the threshold values for multi-bit quantization, organized as a 2D array
+ * where each row corresponds to a different bit level.
+ */
+ public MultiBitScalarQuantizationState(final SQParams quantizationParams, final float[][] thresholds) {
+ this.quantizationParams = quantizationParams;
+ this.thresholds = thresholds;
+ }
+
+ @Override
+ public SQParams getQuantizationParams() {
+ return quantizationParams;
+ }
+
+ /**
+ * Returns the thresholds used in the quantization process.
+ *
+ * @return a 2D array of threshold values.
+ */
+ public float[][] getThresholds() {
+ return thresholds;
+ }
+
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this, thresholds);
+ }
+
+ public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
+ return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(
+ bytes,
+ (parentParams, thresholds) -> new MultiBitScalarQuantizationState((SQParams) parentParams, (float[][]) thresholds)
+ );
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
new file mode 100644
index 0000000000..9b4bad56a4
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.util.QuantizationStateSerializer;
+
+import java.io.IOException;
+
+/**
+ * OneBitScalarQuantizationState represents the state of one-bit scalar quantization,
+ * including the mean values used for quantization.
+ */
+public final class OneBitScalarQuantizationState implements QuantizationState {
+ private final SQParams quantizationParams;
+ private final float[] meanThresholds;
+
+ /**
+ * Constructs a OneBitScalarQuantizationState with the given quantization parameters and mean values.
+ *
+ * @param quantizationParams the scalar quantization parameters.
+ * @param mean the mean values for each dimension.
+ */
+ public OneBitScalarQuantizationState(final SQParams quantizationParams, final float[] mean) {
+ this.quantizationParams = quantizationParams;
+ this.meanThresholds = mean;
+ }
+
+ @Override
+ public SQParams getQuantizationParams() {
+ return quantizationParams;
+ }
+
+ /**
+ * Returns the mean values used in the quantization process.
+ *
+ * @return an array of mean values.
+ */
+ public float[] getMeanThresholds() {
+ return meanThresholds;
+ }
+
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this, meanThresholds);
+ }
+
+ public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
+ return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(
+ bytes,
+ (parentParams, meanThresholds) -> new OneBitScalarQuantizationState((SQParams) parentParams, (float[]) meanThresholds)
+ );
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
new file mode 100644
index 0000000000..c17ff0641b
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ * QuantizationState interface represents the state of a quantization process, including the parameters used.
+ * This interface provides methods for serializing and deserializing the state.
+ */
+public interface QuantizationState extends Serializable {
+ /**
+ * Returns the quantization parameters associated with this state.
+ *
+ * @return the quantization parameters.
+ */
+ QuantizationParams getQuantizationParams();
+
+ /**
+ * Serializes the quantization state to a byte array.
+ *
+ * @return a byte array representing the serialized state.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ byte[] toByteArray() throws IOException;
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
new file mode 100644
index 0000000000..14689ea476
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.requests;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+
+/**
+ * TrainingRequest represents a request for training a quantizer.
+ *
+ * @param the type of vectors to be trained.
+ */
+public abstract class TrainingRequest {
+ private final QuantizationParams params;
+ private final int totalNumberOfVectors;
+ private int[] sampledIndices;
+
+ /**
+ * Constructs a TrainingRequest with the given parameters and total number of vectors.
+ *
+ * @param params the quantization parameters.
+ * @param totalNumberOfVectors the total number of vectors.
+ */
+ protected TrainingRequest(final QuantizationParams params, final int totalNumberOfVectors) {
+ this.params = params;
+ this.totalNumberOfVectors = totalNumberOfVectors;
+ }
+
+ /**
+ * Returns the quantization parameters.
+ *
+ * @return the quantization parameters.
+ */
+ public QuantizationParams getParams() {
+ return params;
+ }
+
+ /**
+ * Returns the total number of vectors.
+ *
+ * @return the total number of vectors.
+ */
+ public int getTotalNumberOfVectors() {
+ return totalNumberOfVectors;
+ }
+
+ /**
+ * Sets the sampled indices for this training request.
+ *
+ * @param sampledIndices the sampled indices.
+ */
+ public void setSampledIndices(int[] sampledIndices) {
+ this.sampledIndices = sampledIndices;
+ }
+
+ /**
+ * Returns the sampled indices for this training request.
+ *
+ * @return the sampled indices.
+ */
+ public int[] getSampledIndices() {
+ return sampledIndices;
+ }
+
+ /**
+ * Returns the vector corresponding to the specified document ID.
+ *
+ * @param docId the document ID.
+ * @return the vector corresponding to the specified document ID.
+ */
+ public abstract T getVectorByDocId(int docId);
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
new file mode 100644
index 0000000000..0143a66147
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+import org.opensearch.knn.quantization.util.BitPacker;
+import org.opensearch.knn.quantization.util.QuantizerHelper;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension.
+ * It supports multiple bits per coordinate, allowing for finer granularity in quantization.
+ */
+public class MultiBitScalarQuantizer implements Quantizer {
+ private final int bitsPerCoordinate; // Number of bits used to quantize each dimension
+ private final int samplingSize; // Sampling size for training
+ private final Sampler sampler; // Sampler for training
+ private static final boolean IS_TRAINING_REQUIRED = true;
+ // Currently Lucene has sampling size as
+ // 25000 for segment level training , Keeping same
+ // to having consistent, Will revisit
+ // if this requires change
+ private static final int DEFAULT_SAMPLE_SIZE = 25000;
+
+ /**
+ * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate.
+ *
+ * @param bitsPerCoordinate the number of bits used per coordinate for quantization.
+ */
+ public MultiBitScalarQuantizer(final int bitsPerCoordinate) {
+ this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR));
+ }
+
+ /**
+ * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate and sampling size.
+ *
+ * @param bitsPerCoordinate the number of bits used per coordinate for quantization.
+ * @param samplingSize the number of samples to use for training.
+ * @param sampler the sampler to use for training.
+ */
+ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSize, final Sampler sampler) {
+ if (bitsPerCoordinate < 2) {
+ throw new IllegalArgumentException("bitsPerCoordinate must be greater than or equal to 2 for multibit quantizer.");
+ }
+ this.bitsPerCoordinate = bitsPerCoordinate;
+ this.samplingSize = samplingSize;
+ this.sampler = sampler;
+ }
+
+ /**
+ * Trains the quantizer based on the provided training request, which should be of type SamplingTrainingRequest.
+ * The training process calculates the mean and standard deviation for each dimension and then determines
+ * threshold values for quantization based on these statistics.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @return a MultiBitScalarQuantizationState containing the computed thresholds.
+ */
+ @Override
+ public QuantizationState train(final TrainingRequest trainingRequest) {
+ SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest);
+ int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
+
+ int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length;
+ float[] meanArray = new float[dimension];
+ float[] stdDevArray = new float[dimension];
+ // Calculate sum, mean, and standard deviation in one pass
+ QuantizerHelper.calculateSumMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray);
+ float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension);
+ return new MultiBitScalarQuantizationState(params, thresholds);
+ }
+
+ /**
+ * Quantizes the provided vector using the provided quantization state, producing a quantized output.
+ * The vector is quantized based on the thresholds in the quantization state.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing threshold information.
+ * @return a BinaryQuantizationOutput containing the quantized data.
+ */
+ @Override
+ public QuantizationOutput quantize(final float[] vector, final QuantizationState state) {
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector to quantize must not be null.");
+ }
+ validateState(state);
+ MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) state;
+ float[][] thresholds = multiBitState.getThresholds();
+ if (thresholds == null || thresholds[0].length != vector.length) {
+ throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
+ }
+
+ List bitArrays = new ArrayList<>();
+ for (int i = 0; i < bitsPerCoordinate; i++) {
+ byte[] bitArray = new byte[vector.length];
+ for (int j = 0; j < vector.length; j++) {
+ bitArray[j] = (byte) (vector[j] > thresholds[i][j] ? 1 : 0);
+ }
+ bitArrays.add(bitArray);
+ }
+
+ return new BinaryQuantizationOutput(BitPacker.packBits(bitArrays));
+ }
+
+ /**
+ * Calculates the thresholds for quantization based on mean and standard deviation.
+ *
+ * @param meanArray the mean for each dimension.
+ * @param stdDevArray the standard deviation for each dimension.
+ * @param dimension the number of dimensions in the vectors.
+ * @return the thresholds for quantization.
+ */
+ private float[][] calculateThresholds(final float[] meanArray, final float[] stdDevArray, final int dimension) {
+ float[][] thresholds = new float[bitsPerCoordinate][dimension];
+ float coef = bitsPerCoordinate + 1;
+ for (int i = 0; i < bitsPerCoordinate; i++) {
+ float iCoef = -1 + 2 * (i + 1) / coef;
+ for (int j = 0; j < dimension; j++) {
+ thresholds[i][j] = meanArray[j] + iCoef * stdDevArray[j];
+ }
+ }
+ return thresholds;
+ }
+
+ /**
+ * Validates the quantization state to ensure it is of the expected type.
+ *
+ * @param state the quantization state to validate.
+ * @throws IllegalArgumentException if the state is not an instance of MultiBitScalarQuantizationState.
+ */
+ private void validateState(final QuantizationState state) {
+ if (!(state instanceof MultiBitScalarQuantizationState)) {
+ throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState.");
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
new file mode 100644
index 0000000000..2eaa07ce03
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+import org.opensearch.knn.quantization.util.BitPacker;
+import org.opensearch.knn.quantization.util.QuantizerHelper;
+
+import java.util.Collections;
+
+/**
+ * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension.
+ * It computes the mean of each dimension during training and then uses these means as thresholds
+ * for quantizing the vectors.
+ */
+public class OneBitScalarQuantizer implements Quantizer {
+ private final int samplingSize; // Sampling size for training
+ private static final boolean IS_TRAINING_REQUIRED = true;
+ private final Sampler sampler; // Sampler for training
+ // Currently Lucene has sampling size as
+ // 25000 for segment level training , Keeping same
+ // to having consistent, Will revisit
+ // if this requires change
+ private static final int DEFAULT_SAMPLE_SIZE = 25000;
+
+ /**
+ * Constructs a OneBitScalarQuantizer with a default sampling size of 25000.
+ */
+ public OneBitScalarQuantizer() {
+ this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR));
+ }
+
+ /**
+ * Constructs a OneBitScalarQuantizer with a specified sampling size.
+ *
+ * @param samplingSize the number of samples to use for training.
+ */
+ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) {
+
+ this.samplingSize = samplingSize;
+ this.sampler = sampler;
+ ;
+ }
+
+ /**
+ * Trains the quantizer by calculating the mean of each dimension from the sampled vectors.
+ * These means are used as thresholds in the quantization process.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @return a OneBitScalarQuantizationState containing the calculated means.
+ */
+ @Override
+ public QuantizationState train(final TrainingRequest trainingRequest) {
+ SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest);
+ int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
+ float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledIndices);
+ return new OneBitScalarQuantizationState(params, mean);
+ }
+
+ /**
+ * Quantizes the provided vector using the given quantization state.
+ * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing the means for each dimension.
+ * @return a BinaryQuantizationOutput containing the quantized data.
+ */
+ @Override
+ public QuantizationOutput quantize(final float[] vector, final QuantizationState state) {
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector to quantize must not be null.");
+ }
+ validateState(state);
+ OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
+ float[] thresholds = binaryState.getMeanThresholds();
+ if (thresholds == null || thresholds.length != vector.length) {
+ throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
+ }
+ byte[] quantizedVector = new byte[vector.length];
+ for (int i = 0; i < vector.length; i++) {
+ quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0);
+ }
+ return new BinaryQuantizationOutput(BitPacker.packBits(Collections.singletonList(quantizedVector)));
+ }
+
+ /**
+ * Validates the quantization state to ensure it is of the expected type.
+ *
+ * @param state the quantization state to validate.
+ * @throws IllegalArgumentException if the state is not an instance of OneBitScalarQuantizationState.
+ */
+ private void validateState(final QuantizationState state) {
+ if (!(state instanceof OneBitScalarQuantizationState)) {
+ throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState.");
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
new file mode 100644
index 0000000000..8231a8aa27
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+/**
+ * The Quantizer interface defines the methods required for training and quantizing vectors
+ * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks.
+ * It supports training to determine quantization parameters and quantizing data vectors
+ * based on these parameters.
+ *
+ * @param The type of the vector or data to be quantized.
+ * @param The type of the quantized output, typically a compressed or encoded representation.
+ */
+public interface Quantizer {
+
+ /**
+ * Trains the quantizer based on the provided training request. The training process typically
+ * involves learning parameters that can be used to quantize vectors.
+ *
+ * @param trainingRequest the request containing data and parameters for training.
+ * @return a QuantizationState containing the learned parameters.
+ */
+ QuantizationState train(TrainingRequest trainingRequest);
+
+ /**
+ * Quantizes the provided vector using the specified quantization state.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing parameters for quantization.
+ * @return a QuantizationOutput containing the quantized representation of the vector.
+ */
+ QuantizationOutput quantize(T vector, QuantizationState state);
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
new file mode 100644
index 0000000000..da5327def2
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import java.util.Arrays;
+import java.util.Random;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.stream.IntStream;
+
+/**
+ * ReservoirSampler implements the Sampler interface and provides a method for sampling
+ * a specified number of indices from a total number of vectors using the reservoir sampling algorithm.
+ * This algorithm is particularly useful for randomly sampling a subset of data from a larger set
+ * when the total size of the dataset is unknown or very large.
+ */
+final class ReservoirSampler implements Sampler {
+
+ private final Random random;
+
+ /**
+ * Constructs a ReservoirSampler with a new Random instance.
+ */
+ public ReservoirSampler() {
+ this(ThreadLocalRandom.current());
+ }
+
+ /**
+ * Constructs a ReservoirSampler with a specified random seed for reproducibility.
+ *
+ * @param seed the seed for the random number generator.
+ */
+ public ReservoirSampler(final long seed) {
+ this(new Random(seed));
+ }
+
+ /**
+ * Constructs a ReservoirSampler with a specified Random instance.
+ *
+ * @param random the Random instance for generating random numbers.
+ */
+ public ReservoirSampler(final Random random) {
+ this.random = random;
+ }
+
+ /**
+ * Samples indices from the range [0, totalNumberOfVectors).
+ * If the total number of vectors is less than or equal to the sample size, it returns all indices.
+ * Otherwise, it uses the reservoir sampling algorithm to select a random subset.
+ *
+ * @param totalNumberOfVectors the total number of vectors to sample from.
+ * @param sampleSize the number of indices to sample.
+ * @return an array of sampled indices.
+ */
+ @Override
+ public int[] sample(final int totalNumberOfVectors, final int sampleSize) {
+ if (totalNumberOfVectors <= sampleSize) {
+ return IntStream.range(0, totalNumberOfVectors).toArray();
+ }
+ return reservoirSampleIndices(totalNumberOfVectors, sampleSize);
+ }
+
+ /**
+ * Applies the reservoir sampling algorithm to select a random sample of indices.
+ * This method ensures that each index in the range [0, numVectors) has an equal probability
+ * of being included in the sample.
+ *
+ * @param numVectors the total number of vectors.
+ * @param sampleSize the number of indices to sample.
+ * @return an array of sampled indices.
+ */
+ private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) {
+ int[] indices = IntStream.range(0, sampleSize).toArray();
+ for (int i = sampleSize; i < numVectors; i++) {
+ int j = random.nextInt(i + 1);
+ if (j < sampleSize) {
+ indices[j] = i;
+ }
+ }
+ Arrays.sort(indices);
+ return indices;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
new file mode 100644
index 0000000000..9021073b4e
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
@@ -0,0 +1,10 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+public interface Sampler {
+ int[] sample(int totalNumberOfVectors, int sampleSize);
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
new file mode 100644
index 0000000000..be228fe6f9
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+/**
+ * SamplingFactory is a factory class for creating instances of Sampler.
+ * It uses the factory design pattern to encapsulate the creation logic for different types of samplers.
+ */
+public final class SamplingFactory {
+
+ /**
+ * Private constructor to prevent instantiation of this class.
+ * The class is not meant to be instantiated, as it provides static methods only.
+ */
+ private SamplingFactory() {
+
+ }
+
+ /**
+ * SamplerType is an enumeration of the different types of samplers that can be created by the factory.
+ */
+ public enum SamplerType {
+ RESERVOIR, // Represents a reservoir sampling strategy
+ // Add more enum values here for additional sampler types
+ }
+
+ /**
+ * Creates and returns a Sampler instance based on the specified SamplerType.
+ *
+ * @param samplerType the type of sampler to create.
+ * @return a Sampler instance.
+ * @throws IllegalArgumentException if the sampler type is not supported.
+ */
+ public static Sampler getSampler(final SamplerType samplerType) {
+ switch (samplerType) {
+ case RESERVOIR:
+ return new ReservoirSampler();
+ // Add more cases for different samplers here
+ default:
+ throw new IllegalArgumentException("Unsupported sampler type: " + samplerType);
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java
new file mode 100644
index 0000000000..5d99a892fb
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import lombok.experimental.UtilityClass;
+
+import java.util.List;
+
+/**
+ * Utility class for bit packing operations.
+ * Provides methods for packing arrays of bits into byte arrays for efficient storage or transmission.
+ */
+@UtilityClass
+public class BitPacker {
+
+ /**
+ * Packs the list of bit arrays into a single byte array.
+ * Each byte in the resulting array contains up to 8 bits from the bit arrays, packed from left to right.
+ *
+ * @param bitArrays the list of bit arrays to be packed. Each bit array should contain only 0s and 1s.
+ * @return a byte array containing the packed bits.
+ * @throws IllegalArgumentException if the bitArrays list is empty, if any bit array is null, or if bit arrays have inconsistent lengths.
+ */
+ public static byte[] packBits(List bitArrays) {
+ if (bitArrays.isEmpty()) {
+ throw new IllegalArgumentException("The list of bit arrays cannot be empty.");
+ }
+
+ int bitArrayLength = bitArrays.get(0).length;
+ int bitLength = bitArrays.size() * bitArrayLength;
+ int byteLength = (bitLength + 7) / 8;
+ byte[] packedArray = new byte[byteLength];
+
+ int bitPosition = 0;
+ for (byte[] bitArray : bitArrays) {
+ if (bitArray == null) {
+ throw new IllegalArgumentException("Bit array cannot be null.");
+ }
+ if (bitArray.length != bitArrayLength) {
+ throw new IllegalArgumentException("All bit arrays must have the same length.");
+ }
+
+ for (byte bit : bitArray) {
+ int byteIndex = bitPosition / 8;
+ int bitIndex = 7 - (bitPosition % 8);
+ if (bit == 1) {
+ packedArray[byteIndex] |= (1 << bitIndex);
+ }
+ bitPosition++;
+ }
+ }
+
+ return packedArray;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java
new file mode 100644
index 0000000000..89b3b67bd8
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import lombok.experimental.UtilityClass;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+
+import java.io.ByteArrayOutputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.io.IOException;
+import java.io.ByteArrayInputStream;
+import java.io.ObjectInputStream;
+
+/**
+ * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing
+ * QuantizationState objects along with their specific data.
+ */
+@UtilityClass
+public class QuantizationStateSerializer {
+
+ /**
+ * A functional interface for deserializing specific data associated with a QuantizationState.
+ */
+ @FunctionalInterface
+ public interface SerializableDeserializer {
+ QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData);
+ }
+
+ /**
+ * Serializes the QuantizationState and specific data into a byte array.
+ *
+ * @param state The QuantizationState to serialize.
+ * @param specificData The specific data related to the state, to be serialized.
+ * @return A byte array representing the serialized state and specific data.
+ * @throws IOException If an I/O error occurs during serialization.
+ */
+ public static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException {
+ byte[] parentBytes = serializeParentParams(state.getQuantizationParams());
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) {
+ out.writeInt(parentBytes.length); // Write the length of the parent bytes
+ out.write(parentBytes); // Write the parent bytes
+ out.writeObject(specificData); // Write the specific data
+ out.flush();
+ return bos.toByteArray();
+ }
+ }
+
+ /**
+ * Deserializes a QuantizationState and its specific data from a byte array.
+ *
+ * @param bytes The byte array containing the serialized data.
+ * @param specificDataDeserializer The deserializer for the specific data associated with the state.
+ * @return The deserialized QuantizationState including its specific data.
+ * @throws IOException If an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException If the class of the serialized object cannot be found.
+ */
+ public static QuantizationState deserialize(byte[] bytes, SerializableDeserializer specificDataDeserializer) throws IOException,
+ ClassNotFoundException {
+ try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) {
+ int parentLength = in.readInt();
+ // Read the length of the parent bytes
+ byte[] parentBytes = new byte[parentLength];
+ in.readFully(parentBytes); // Read the parent bytes
+ QuantizationParams parentParams = deserializeParentParams(parentBytes); // Deserialize the parent params
+ Serializable specificData = (Serializable) in.readObject(); // Read the specific data
+ return specificDataDeserializer.deserialize(parentParams, specificData);
+ }
+ }
+
+ /**
+ * Serializes the parent parameters of the QuantizationState into a byte array.
+ *
+ * @param params The QuantizationParams to serialize.
+ * @return A byte array representing the serialized parent parameters.
+ * @throws IOException If an I/O error occurs during serialization.
+ */
+ private static byte[] serializeParentParams(QuantizationParams params) throws IOException {
+ try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) {
+ out.writeObject(params);
+ out.flush();
+ return bos.toByteArray();
+ }
+ }
+
+ /**
+ * Deserializes the parent parameters of the QuantizationState from a byte array.
+ *
+ * @param bytes The byte array containing the serialized parent parameters.
+ * @return The deserialized QuantizationParams.
+ * @throws IOException If an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException If the class of the serialized object cannot be found.
+ */
+ private static QuantizationParams deserializeParentParams(byte[] bytes) throws IOException, ClassNotFoundException {
+ try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) {
+ return (QuantizationParams) in.readObject();
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java
new file mode 100644
index 0000000000..adc4e34c43
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import lombok.experimental.UtilityClass;
+
+/**
+ * Utility class providing common methods for quantizer operations, such as parameter validation and
+ * extraction. This class is designed to be used with various quantizer implementations that require
+ * consistent handling of training requests and sampled indices.
+ */
+@UtilityClass
+public class QuantizerHelper {
+
+ /**
+ * Validates the provided training request to ensure it contains non-null quantization parameters.
+ * Extracts and returns the SQParams from the training request.
+ *
+ * @param trainingRequest the training request to validate and extract parameters from.
+ * @return the extracted SQParams.
+ * @throws IllegalArgumentException if the SQParams are null.
+ */
+ public static SQParams validateAndExtractParams(TrainingRequest> trainingRequest) {
+ QuantizationParams params = trainingRequest.getParams();
+ if (params == null || !(params instanceof SQParams)) {
+ throw new IllegalArgumentException("Quantization parameters must not be null and must be of type SQParams.");
+ }
+ return (SQParams) params;
+ }
+
+ /**
+ * Calculates the mean vector from a set of sampled vectors.
+ *
+ * This method takes a {@link TrainingRequest} object and an array of sampled indices,
+ * retrieves the vectors corresponding to these indices, and calculates the mean vector.
+ * Each element of the mean vector is computed as the average of the corresponding elements
+ * of the sampled vectors.
+ *
+ * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices.
+ * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation.
+ * @return A float array representing the mean vector of the sampled vectors.
+ * @throws IllegalArgumentException If any of the vectors at the sampled indices are null.
+ * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors.
+ */
+ public static float[] calculateMean(TrainingRequest samplingRequest, int[] sampledIndices) {
+ int totalSamples = sampledIndices.length;
+ float[] mean = null;
+ for (int index : sampledIndices) {
+ float[] vector = samplingRequest.getVectorByDocId(index);
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector at sampled index " + index + " is null.");
+ }
+ if (mean == null) {
+ mean = new float[vector.length];
+ }
+ for (int j = 0; j < vector.length; j++) {
+ mean[j] += vector[j];
+ }
+ }
+ if (mean == null) {
+ throw new IllegalStateException("Mean array should not be null after processing vectors.");
+ }
+ for (int j = 0; j < mean.length; j++) {
+ mean[j] /= totalSamples;
+ }
+ return mean;
+ }
+
+ /**
+ * Calculates the sum, sum of squares, mean, and standard deviation for each dimension in a single pass.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @param sampledIndices the indices of the sampled vectors.
+ * @param meanArray the array to store the sum and then the mean of each dimension.
+ * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension.
+ */
+ public static void calculateSumMeanAndStdDev(
+ TrainingRequest trainingRequest,
+ int[] sampledIndices,
+ float[] meanArray,
+ float[] stdDevArray
+ ) {
+ int totalSamples = sampledIndices.length;
+ int dimension = meanArray.length;
+
+ // Single pass to calculate sum and sum of squares
+ for (int index : sampledIndices) {
+ float[] vector = trainingRequest.getVectorByDocId(index);
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector at sampled index " + index + " is null.");
+ }
+ for (int j = 0; j < dimension; j++) {
+ meanArray[j] += vector[j];
+ stdDevArray[j] += vector[j] * vector[j];
+ }
+ }
+
+ // Calculate mean and standard deviation in one pass
+ for (int j = 0; j < dimension; j++) {
+ meanArray[j] = meanArray[j] / totalSamples;
+ stdDevArray[j] = (float) Math.sqrt((stdDevArray[j] / totalSamples) - (meanArray[j] * meanArray[j]));
+ }
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java
new file mode 100644
index 0000000000..ab6828cbf6
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class QuantizationTypeTests extends KNNTestCase {
+
+ public void testQuantizationTypeValues() {
+ QuantizationType[] expectedValues = { QuantizationType.SPACE, QuantizationType.VALUE };
+ assertArrayEquals(expectedValues, QuantizationType.values());
+ }
+
+ public void testQuantizationTypeValueOf() {
+ assertEquals(QuantizationType.SPACE, QuantizationType.valueOf("SPACE_QUANTIZATION"));
+ assertEquals(QuantizationType.VALUE, QuantizationType.valueOf("VALUE_QUANTIZATION"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java
new file mode 100644
index 0000000000..98bc38e6e6
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class SQTypesTests extends KNNTestCase {
+ public void testSQTypesValues() {
+ ScalarQuantizationType[] expectedValues = {
+ ScalarQuantizationType.FP16,
+ ScalarQuantizationType.FP8,
+ ScalarQuantizationType.INT8,
+ ScalarQuantizationType.INT6,
+ ScalarQuantizationType.INT4,
+ ScalarQuantizationType.ONE_BIT,
+ ScalarQuantizationType.TWO_BIT,
+ ScalarQuantizationType.FOUR_BIT,
+ ScalarQuantizationType.UNSUPPORTED_TYPE };
+ assertArrayEquals(expectedValues, ScalarQuantizationType.values());
+ }
+
+ public void testSQTypesValueOf() {
+ assertEquals(ScalarQuantizationType.FP16, ScalarQuantizationType.valueOf("FP16"));
+ assertEquals(ScalarQuantizationType.INT8, ScalarQuantizationType.valueOf("INT8"));
+ assertEquals(ScalarQuantizationType.INT6, ScalarQuantizationType.valueOf("INT6"));
+ assertEquals(ScalarQuantizationType.INT4, ScalarQuantizationType.valueOf("INT4"));
+ assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT"));
+ assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT"));
+ assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT"));
+ assertEquals(ScalarQuantizationType.UNSUPPORTED_TYPE, ScalarQuantizationType.valueOf("UNSUPPORTED_TYPE"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java
new file mode 100644
index 0000000000..72ebd751fa
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class ValueQuantizationTypeTests extends KNNTestCase {
+ public void testValueQuantizationTypeValues() {
+ ValueQuantizationType[] expectedValues = { ValueQuantizationType.SCALAR };
+ assertArrayEquals(expectedValues, ValueQuantizationType.values());
+ }
+
+ public void testValueQuantizationTypeValueOf() {
+ assertEquals(ValueQuantizationType.SCALAR, ValueQuantizationType.valueOf("SQ"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
new file mode 100644
index 0000000000..42fc18eba7
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.junit.Before;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.lang.reflect.Field;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class QuantizerFactoryTests extends KNNTestCase {
+
+ @Before
+ public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException {
+ Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered");
+ isRegisteredField.setAccessible(true);
+ AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null);
+ isRegistered.set(false);
+ }
+
+ public void test_Lazy_Registration() {
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ assertFalse(isRegisteredFieldAccessible());
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof OneBitScalarQuantizer);
+ assertTrue(isRegisteredFieldAccessible());
+ }
+
+ public void testGetQuantizer_withOneBitSQParams() {
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof OneBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withTwoBitSQParams() {
+ SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof MultiBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withFourBitSQParams() {
+ SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ Quantizer, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ assertTrue(quantizer instanceof MultiBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withUnsupportedType() {
+ SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE);
+ try {
+ QuantizerFactory.getQuantizer(params);
+ fail("Expected IllegalArgumentException");
+ } catch (IllegalArgumentException e) {
+ assertTrue(e.getMessage().contains("No quantizer registered for type identifier"));
+ }
+ }
+
+ public void testGetQuantizer_withNullParams() {
+ try {
+ QuantizerFactory.getQuantizer(null);
+ fail("Expected IllegalArgumentException");
+ } catch (IllegalArgumentException e) {
+ assertEquals("Quantization parameters must not be null.", e.getMessage());
+ }
+ }
+
+ public void testConcurrentRegistration() throws InterruptedException {
+ Runnable task = () -> {
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ QuantizerFactory.getQuantizer(params);
+ };
+
+ Thread thread1 = new Thread(task);
+ Thread thread2 = new Thread(task);
+ thread1.start();
+ thread2.start();
+ thread1.join();
+ thread2.join();
+ assertTrue(isRegisteredFieldAccessible());
+ }
+
+ private boolean isRegisteredFieldAccessible() {
+ try {
+ Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered");
+ isRegisteredField.setAccessible(true);
+ AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null);
+ return isRegistered.get();
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ fail("Failed to access isRegistered field.");
+ return false;
+ }
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
new file mode 100644
index 0000000000..7f53dae8c2
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.junit.BeforeClass;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.QuantizationType;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+public class QuantizerRegistryTests extends KNNTestCase {
+
+ @BeforeClass
+ public static void setup() {
+ // Register the quantizers for testing with enums
+ QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new);
+ QuantizerRegistry.register(
+ SQParams.class,
+ QuantizationType.VALUE,
+ ScalarQuantizationType.TWO_BIT,
+ () -> new MultiBitScalarQuantizer(2)
+ );
+ QuantizerRegistry.register(
+ SQParams.class,
+ QuantizationType.VALUE,
+ ScalarQuantizationType.FOUR_BIT,
+ () -> new MultiBitScalarQuantizer(4)
+ );
+ }
+
+ public void testRegisterAndGetQuantizer() {
+ // Test for OneBitScalarQuantizer
+ SQParams oneBitParams = new SQParams(ScalarQuantizationType.ONE_BIT);
+ Quantizer, ?> oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
+ assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer);
+
+ // Test for MultiBitScalarQuantizer (2-bit)
+ SQParams twoBitParams = new SQParams(ScalarQuantizationType.TWO_BIT);
+ Quantizer, ?> twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
+ assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer);
+
+ // Test for MultiBitScalarQuantizer (4-bit)
+ SQParams fourBitParams = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ Quantizer, ?> fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
+ assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer);
+ }
+
+ public void testGetQuantizer_withUnsupportedTypeIdentifier() {
+ // Create SQParams with an unsupported type identifier
+ SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); // Assuming UNSUPPORTED_TYPE is not registered
+
+ // Expect IllegalArgumentException when requesting a quantizer with unsupported params
+ IllegalArgumentException exception = assertThrows(
+ IllegalArgumentException.class,
+ () -> { QuantizerRegistry.getQuantizer(params); }
+ );
+
+ assertTrue(exception.getMessage().contains("No quantizer registered for type identifier"));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
new file mode 100644
index 0000000000..ebe6bf6bd9
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizationState;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+
+import java.io.IOException;
+
+public class QuantizationStateSerializerTests extends KNNTestCase {
+
+ public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ float[] mean = new float[] { 0.1f, 0.2f, 0.3f };
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+
+ byte[] serialized = state.toByteArray();
+ OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized);
+
+ assertArrayEquals(mean, deserialized.getMeanThresholds(), 0.0f);
+ assertEquals(params, deserialized.getQuantizationParams());
+ }
+
+ public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ float[][] thresholds = new float[][] { { 0.1f, 0.2f, 0.3f }, { 0.4f, 0.5f, 0.6f } };
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ byte[] serialized = state.toByteArray();
+ MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized);
+
+ assertArrayEquals(thresholds, deserialized.getThresholds());
+ assertEquals(params, deserialized.getQuantizationParams());
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
new file mode 100644
index 0000000000..54e304732b
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.quantizationState;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+
+import java.io.IOException;
+
+public class QuantizationStateTests extends KNNTestCase {
+
+ public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ float[] mean = { 1.0f, 2.0f, 3.0f };
+
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+
+ byte[] serializedState = state.toByteArray();
+ OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState);
+ float delta = 0.0001f;
+
+ assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta);
+ assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ }
+
+ public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } };
+
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ byte[] serializedState = state.toByteArray();
+ MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState);
+ float delta = 0.0001f;
+
+ for (int i = 0; i < thresholds.length; i++) {
+ assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i], delta);
+ }
+ assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ }
+
+ public void testDefaultQuantizationStateSerialization() throws IOException, ClassNotFoundException {
+ SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE);
+
+ DefaultQuantizationState state = new DefaultQuantizationState(params);
+
+ byte[] serializedState = state.toByteArray();
+ DefaultQuantizationState deserializedState = DefaultQuantizationState.fromByteArray(serializedState);
+
+ assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType());
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
new file mode 100644
index 0000000000..46ac2f2a28
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+public class MultiBitScalarQuantizerTests extends KNNTestCase {
+
+ public void testTrain_twoBit() {
+ float[][] vectors = {
+ { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
+ { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+ { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ int[] sampledIndices = { 0, 1, 2 };
+ SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ TrainingRequest request = new MockTrainingRequest(params, vectors);
+ request.setSampledIndices(sampledIndices);
+ QuantizationState state = twoBitQuantizer.train(request);
+
+ assertTrue(state instanceof MultiBitScalarQuantizationState);
+ MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state;
+ assertNotNull(mbState.getThresholds());
+ assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds
+ }
+
+ public void testTrain_fourBit() {
+ MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4);
+ float[][] vectors = {
+ { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
+ { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+ { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
+ int[] sampledIndices = { 0, 1, 2 };
+ SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ TrainingRequest request = new MockTrainingRequest(params, vectors);
+ request.setSampledIndices(sampledIndices);
+ QuantizationState state = fourBitQuantizer.train(request);
+
+ assertTrue(state instanceof MultiBitScalarQuantizationState);
+ MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state;
+ assertNotNull(mbState.getThresholds());
+ assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds
+ }
+
+ public void testQuantize_twoBit() {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
+ float[][] thresholds = { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
+ SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT);
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ QuantizationOutput output = twoBitQuantizer.quantize(vector, state);
+ assertNotNull(output.getQuantizedVector());
+ assertEquals(2, output.getQuantizedVector().length);
+ }
+
+ public void testQuantize_fourBit() {
+ MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4);
+ float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
+ float[][] thresholds = {
+ { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+ { 1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f },
+ { 1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f },
+ { 1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f } };
+ SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT);
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ QuantizationOutput output = fourBitQuantizer.quantize(vector, state);
+ assertEquals(4, output.getQuantizedVector().length);
+ assertNotNull(output.getQuantizedVector());
+ }
+
+ public void testQuantize_withNullVector() {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> twoBitQuantizer.quantize(
+ null,
+ new MultiBitScalarQuantizationState(new SQParams(ScalarQuantizationType.TWO_BIT), new float[2][8])
+ )
+ );
+ }
+
+ public void testQuantize_withInvalidState() {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
+ QuantizationState invalidState = new MockInvalidQuantizationState();
+ expectThrows(IllegalArgumentException.class, () -> twoBitQuantizer.quantize(vector, invalidState));
+ }
+
+ public void testQuantize_withDefaultQuantizationState() {
+ MultiBitScalarQuantizer quantizer = new MultiBitScalarQuantizer(2);
+ float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
+ DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT));
+
+ expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state));
+ }
+
+ // Mock classes for testing
+ private static class MockTrainingRequest extends TrainingRequest {
+ private final float[][] vectors;
+
+ public MockTrainingRequest(SQParams params, float[][] vectors) {
+ super(params, vectors.length);
+ this.vectors = vectors;
+ }
+
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ }
+
+ private static class MockInvalidQuantizationState implements QuantizationState {
+ @Override
+ public SQParams getQuantizationParams() {
+ return new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return new byte[0];
+ }
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
new file mode 100644
index 0000000000..9a37ca7995
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
+import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+import org.opensearch.knn.quantization.util.QuantizerHelper;
+
+public class OneBitScalarQuantizerTests extends KNNTestCase {
+
+ public void testTrain_withTrainingRequired() {
+ float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } };
+
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) {
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ QuantizationState state = quantizer.train(originalRequest);
+
+ assertTrue(state instanceof OneBitScalarQuantizationState);
+ float[] mean = ((OneBitScalarQuantizationState) state).getMeanThresholds();
+ assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f);
+ }
+
+ public void testQuantize_withState() {
+ float[] vector = { 3.0f, 6.0f, 9.0f };
+ float[] thresholds = { 4.0f, 5.0f, 6.0f };
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds);
+
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ QuantizationOutput output = quantizer.quantize(vector, state);
+
+ assertNotNull(output);
+ byte[] expectedPackedBits = new byte[] { 0b01100000 }; // or 96 in decimal
+ assertArrayEquals(expectedPackedBits, output.getQuantizedVector());
+ }
+
+ public void testQuantize_withDefaultQuantizationState() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = { 3.0f, 6.0f, 9.0f };
+ DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT));
+
+ expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state));
+ }
+
+ public void testQuantize_withNullVector() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(
+ new SQParams(ScalarQuantizationType.ONE_BIT),
+ new float[] { 0.0f }
+ );
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state));
+ }
+
+ public void testQuantize_withInvalidState() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = { 1.0f, 2.0f, 3.0f };
+ QuantizationState invalidState = new QuantizationState() {
+ @Override
+ public SQParams getQuantizationParams() {
+ return new SQParams(ScalarQuantizationType.ONE_BIT);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return new byte[0];
+ }
+ };
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState));
+ }
+
+ public void testQuantize_withMismatchedDimensions() {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = { 1.0f, 2.0f, 3.0f };
+ float[] thresholds = { 4.0f, 5.0f };
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds);
+
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state));
+ }
+
+ public void testCalculateMean() {
+ float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } };
+
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) {
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(vectors.length, 3);
+ float[] mean = QuantizerHelper.calculateMean(samplingRequest, sampledIndices);
+ assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f);
+ }
+
+ public void testCalculateMean_withNullVector() {
+ float[][] vectors = { { 1.0f, 2.0f, 3.0f }, null, { 7.0f, 8.0f, 9.0f } };
+
+ SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) {
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(vectors.length, 3);
+ expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMean(samplingRequest, sampledIndices));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
new file mode 100644
index 0000000000..4d33452890
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
+public class ReservoirSamplerTests extends KNNTestCase {
+
+ public void testSampleLessThanSampleSize() {
+ ReservoirSampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 5;
+ int sampleSize = 10;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
+ assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
+ }
+
+ public void testSampleEqualToSampleSize() {
+ ReservoirSampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 10;
+ int sampleSize = 10;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
+ assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
+ }
+
+ public void testSampleGreaterThanSampleSize() {
+ ReservoirSampler sampler = new ReservoirSampler(12345); // Fixed seed for reproducibility
+ int totalNumberOfVectors = 100;
+ int sampleSize = 10;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(sampleSize, sampledIndices.length);
+ assertTrue(Arrays.stream(sampledIndices).allMatch(i -> i >= 0 && i < totalNumberOfVectors));
+ }
+
+ public void testSampleReproducibility() {
+ long seed = 12345L;
+ ReservoirSampler sampler1 = new ReservoirSampler(seed);
+ ReservoirSampler sampler2 = new ReservoirSampler(seed);
+ int totalNumberOfVectors = 100;
+ int sampleSize = 10;
+
+ int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
+ int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
+
+ assertArrayEquals(sampledIndices1, sampledIndices2);
+ }
+
+ public void testSampleRandomness() {
+ ReservoirSampler sampler1 = new ReservoirSampler();
+ ReservoirSampler sampler2 = new ReservoirSampler();
+ int totalNumberOfVectors = 100;
+ int sampleSize = 10;
+
+ int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
+ int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
+
+ assertNotEquals(Arrays.toString(sampledIndices1), Arrays.toString(sampledIndices2));
+ }
+
+ public void testEdgeCaseZeroVectors() {
+ ReservoirSampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 0;
+ int sampleSize = 10;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(0, sampledIndices.length);
+ }
+
+ public void testEdgeCaseZeroSampleSize() {
+ ReservoirSampler sampler = new ReservoirSampler();
+ int totalNumberOfVectors = 10;
+ int sampleSize = 0;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals(0, sampledIndices.length);
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
new file mode 100644
index 0000000000..ca72c1c5e5
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class SamplingFactoryTests extends KNNTestCase {
+ public void testGetSampler_withReservoir() {
+ Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+ assertTrue(sampler instanceof ReservoirSampler);
+ }
+
+ public void testGetSampler_withUnsupportedType() {
+ expectThrows(NullPointerException.class, () -> SamplingFactory.getSampler(null)); // This should throw an exception
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
new file mode 100644
index 0000000000..c91c7177b1
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import org.opensearch.knn.KNNTestCase;
+
+import java.util.Arrays;
+import java.util.List;
+
+public class BitPackingUtilsTests extends KNNTestCase {
+
+ public void testPackBits() {
+ List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, new byte[] { 1, 0, 1, 0, 0, 1, 0, 0 });
+
+ byte[] expectedPackedArray = new byte[] { (byte) 0b01011011, (byte) 0b10100100 };
+ byte[] packedArray = BitPacker.packBits(bitArrays);
+
+ assertArrayEquals(expectedPackedArray, packedArray);
+ }
+
+ public void testPackBitsEmptyList() {
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(Arrays.asList()); });
+ assertEquals("The list of bit arrays cannot be empty.", exception.getMessage());
+ }
+
+ public void testPackBitsNullBitArray() {
+ List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, null);
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); });
+ assertEquals("Bit array cannot be null.", exception.getMessage());
+ }
+
+ public void testPackBitsInconsistentLength() {
+ List bitArrays = Arrays.asList(new byte[] { 0, 1, 0, 1, 1, 0, 1, 1 }, new byte[] { 1, 0, 1 });
+
+ IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); });
+ assertEquals("All bit arrays must have the same length.", exception.getMessage());
+ }
+
+ public void testPackBitsEdgeCaseSingleBitArray() {
+ List bitArrays = Arrays.asList(new byte[] { 1 });
+
+ byte[] expectedPackedArray = new byte[] { (byte) 0b10000000 };
+ byte[] packedArray = BitPacker.packBits(bitArrays);
+
+ assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray);
+ }
+
+ public void testPackBitsEdgeCaseSingleBit() {
+ List bitArrays = Arrays.asList(new byte[] { 1, 0, 1, 0, 1, 0, 1, 0 }, new byte[] { 1, 1, 1, 1, 1, 1, 1, 1 });
+
+ byte[] expectedPackedArray = new byte[] { (byte) 0b10101010, (byte) 0b11111111 };
+ byte[] packedArray = BitPacker.packBits(bitArrays);
+
+ assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray);
+ }
+
+ public void testPackBits_emptyArray() {
+ List bitArrays = Arrays.asList();
+ expectThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(bitArrays); });
+ ;
+ }
+}