diff --git a/libs/aws/langchain_aws/embeddings/bedrock.py b/libs/aws/langchain_aws/embeddings/bedrock.py
index 9c0f3a93..955fd778 100644
--- a/libs/aws/langchain_aws/embeddings/bedrock.py
+++ b/libs/aws/langchain_aws/embeddings/bedrock.py
@@ -2,7 +2,7 @@
 import json
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 import numpy as np
 from langchain_core.embeddings import Embeddings
@@ -151,13 +151,20 @@ def _cohere_multi_embedding(self, texts: List[str]) -> List[float]:
         """Call out to Cohere Bedrock embedding endpoint with multiple inputs."""
         # replace newlines, which can negatively affect performance.
         texts = [text.replace(os.linesep, " ") for text in texts]
+        results = []
+
+        # Iterate through the list of strings in batches
+        for text_batch in _batch_cohere_embedding_texts(texts):
+            batch_embeddings = self._invoke_model(
+                input_body={
+                    "input_type": "search_document",
+                    "texts": text_batch,
+                }
+            ).get("embeddings")
 
-        return self._invoke_model(
-            input_body={
-                "input_type": "search_document",
-                "texts": texts,
-            }
-        ).get("embeddings")
+            results += batch_embeddings
+
+        return results
 
     def _invoke_model(self, input_body: Dict[str, Any] = {}) -> Dict[str, Any]:
         if self.model_kwargs:
@@ -262,3 +269,39 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
 
         return list(result)
+
+
+def _batch_cohere_embedding_texts(texts: List[str]) -> Generator[List[str], None, None]:
+    """Batch a list of texts into chunks accepted by the Cohere embedding API:
+    at most 96 texts and 2048 total characters per chunk."""
+
+    # Cohere embeddings want a maximum of 96 items and 2048 characters
+    max_items = 96
+    max_chars = 2048
+
+    # Initialize batches
+    current_batch = []
+    current_chars = 0
+
+    for text in texts:
+        text_len = len(text)
+
+        if text_len > max_chars:
+            raise ValueError(
+                "The Cohere embedding API does not support texts longer than 2048 characters."
+            )
+
+        # Check if adding the current string would exceed the limits
+        if len(current_batch) >= max_items or current_chars + text_len > max_chars:
+            # Yield the current batch before starting a new one
+            yield current_batch
+            # Start a new batch
+            current_batch = []
+            current_chars = 0
+
+        # Add the text to the current (possibly new) batch
+        current_batch.append(text)
+        current_chars += text_len
+
+    if current_batch:
+        yield current_batch
diff --git a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
index b95002b5..06e2e9b8 100644
--- a/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
+++ b/libs/aws/tests/integration_tests/embeddings/test_bedrock_embeddings.py
@@ -3,6 +3,7 @@
 import pytest
 
 from langchain_aws import BedrockEmbeddings
+from langchain_aws.embeddings import bedrock
 
 
 @pytest.fixture
@@ -126,3 +127,35 @@ def test_bedrock_cohere_embedding_documents_multiple(cohere_embeddings_v3) -> No
     assert len(output[0]) == 1024
     assert len(output[1]) == 1024
     assert len(output[2]) == 1024
+
+
+@pytest.mark.scheduled
+def test_bedrock_cohere_batching() -> None:
+    # 200 short texts should be split into three batches of at most 96 items
+    documents = [f"{val}" for val in range(200)]
+    assert len(list(bedrock._batch_cohere_embedding_texts(documents))) == 3
+
+    # The 2048-character limit should also force a split
+    large_char_batch = ["foo", "bar", "a" * 2045, "baz"]
+    assert list(bedrock._batch_cohere_embedding_texts(large_char_batch)) == [
+        ["foo", "bar"],
+        ["a" * 2045, "baz"],
+    ]
+
+    # Should be fine with exactly 2048 characters
+    assert list(bedrock._batch_cohere_embedding_texts(["a" * 2048])) == [["a" * 2048]]
+
+    # But a single text longer than 2048 characters should raise an error
+    with pytest.raises(ValueError):
+        list(bedrock._batch_cohere_embedding_texts(["a" * 2049]))
+
+
+@pytest.mark.scheduled
+def test_bedrock_cohere_embedding_large_document_set(cohere_embeddings_v3) -> None:
+    lots_of_documents = 200
+    documents = [f"text_{val}" for val in range(lots_of_documents)]
+    output = cohere_embeddings_v3.embed_documents(documents)
+    assert len(output) == 200
+    assert len(output[0]) == 1024
+    assert len(output[1]) == 1024
+    assert len(output[2]) == 1024
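Illustrative usage (not part of the patch): a minimal sketch of how the new
_batch_cohere_embedding_texts helper splits its input under the 96-item and
2048-character limits, assuming the change above has been applied.

    from langchain_aws.embeddings.bedrock import _batch_cohere_embedding_texts

    # 200 short texts are split into three batches of at most 96 items each.
    documents = [f"{val}" for val in range(200)]
    batches = list(_batch_cohere_embedding_texts(documents))
    print([len(b) for b in batches])  # [96, 96, 8]

    # The 2048-character budget per batch also forces a split.
    batches = list(_batch_cohere_embedding_texts(["foo", "bar", "a" * 2045, "baz"]))
    print([len(b) for b in batches])  # [2, 2] -> [["foo", "bar"], ["a" * 2045, "baz"]]

    # A single text longer than 2048 characters is rejected with a ValueError.
    try:
        list(_batch_cohere_embedding_texts(["a" * 2049]))
    except ValueError as err:
        print(err)

Each batch is then sent to Bedrock as a separate invocation by
_cohere_multi_embedding, and the per-batch embeddings are concatenated so
callers still receive one embedding per input text.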