Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: text and character limits on Cohere embedding API #376

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions libs/aws/langchain_aws/embeddings/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Generator, List, Optional

import numpy as np
from langchain_core.embeddings import Embeddings
Expand Down Expand Up @@ -151,13 +151,20 @@ def _cohere_multi_embedding(self, texts: List[str]) -> List[float]:
"""Call out to Cohere Bedrock embedding endpoint with multiple inputs."""
# replace newlines, which can negatively affect performance.
texts = [text.replace(os.linesep, " ") for text in texts]
results = []

# Iterate through the list of strings in batches
for text_batch in _batch_cohere_embedding_texts(texts):
batch_embeddings = self._invoke_model(
input_body={
"input_type": "search_document",
"texts": text_batch,
}
).get("embeddings")

return self._invoke_model(
input_body={
"input_type": "search_document",
"texts": texts,
}
).get("embeddings")
results += batch_embeddings

return results

def _invoke_model(self, input_body: Dict[str, Any] = {}) -> Dict[str, Any]:
if self.model_kwargs:
Expand Down Expand Up @@ -262,3 +269,39 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
result = await asyncio.gather(*[self.aembed_query(text) for text in texts])

return list(result)


def _batch_cohere_embedding_texts(texts: List[str]) -> Generator[List[str], None, None]:
"""Batches a set of texts into chunks that are acceptable for the Cohere embedding API:
chunks of at most 96 items, or 2048 characters."""

# Cohere embeddings want a maximum of 96 items and 2048 characters
max_items = 96
max_chars = 2048

# Initialize batches
current_batch = []
current_chars = 0

for text in texts:
text_len = len(text)

if text_len > max_chars:
raise ValueError(
"The Cohere embedding API does not support texts longer than 2048 characters."
)

# Check if adding the current string would exceed the limits
if len(current_batch) >= max_items or current_chars + text_len > max_chars:
# Process the current batch if limits are exceeded
yield current_batch
# Start a new batch
current_batch = []
current_chars = 0

# Otherwise, add the string to the current batch
current_batch.append(text)
current_chars += text_len

if current_batch:
yield current_batch
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from langchain_aws import BedrockEmbeddings
from langchain_aws.embeddings import bedrock


@pytest.fixture
Expand Down Expand Up @@ -126,3 +127,35 @@ def test_bedrock_cohere_embedding_documents_multiple(cohere_embeddings_v3) -> No
assert len(output[0]) == 1024
assert len(output[1]) == 1024
assert len(output[2]) == 1024


@pytest.mark.scheduled
def test_bedrock_cohere_batching() -> None:
    """Exercise the Cohere batching helper against its item and char limits."""
    # 200 short texts split on the 96-item limit: 96 + 96 + 8 -> 3 batches.
    batches = list(bedrock._batch_cohere_embedding_texts([f"{val}" for val in range(200)]))
    assert len(batches) == 3

    # Texts split when the combined character count would pass 2048.
    expected = [["foo", "bar"], ["a" * 2045, "baz"]]
    assert list(bedrock._batch_cohere_embedding_texts(["foo", "bar", "a" * 2045, "baz"])) == expected

    # A single text of exactly 2048 characters is still accepted.
    exact = "a" * 2048
    assert list(bedrock._batch_cohere_embedding_texts([exact])) == [[exact]]

    # One character over the limit must raise.
    with pytest.raises(ValueError):
        list(bedrock._batch_cohere_embedding_texts(["a" * 2049]))


@pytest.mark.scheduled
def test_bedrock_cohere_embedding_large_document_set(cohere_embeddings_v3) -> None:
    """Embedding 200 documents must return one 1024-dim vector per document.

    200 inputs exceed the Cohere API's 96-item batch limit, so this verifies
    the batching path end to end against Bedrock.
    """
    lots_of_documents = 200
    documents = [f"text_{val}" for val in range(lots_of_documents)]
    output = cohere_embeddings_v3.embed_documents(documents)
    assert len(output) == lots_of_documents
    # Check every embedding, not just the first few — a batching bug could
    # corrupt rows anywhere in the result.
    assert all(len(embedding) == 1024 for embedding in output)