Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
karanataryn committed Feb 20, 2025
2 parents 1bb8b81 + 49f6bf9 commit 4e7b101
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 26 deletions.
31 changes: 25 additions & 6 deletions lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,14 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
mapping = Map(self.plan, f=f, **resource_args)
return DocSet(self.context, mapping)

def kmeans(self, K: int, iterations: int = 20, init_mode: str = "random", epsilon: float = 1e-4):
def kmeans(
self,
K: int,
iterations: int = 20,
init_mode: str = "random",
epsilon: float = 1e-4,
field_name: Optional[str] = None,
):
"""
Apply kmeans over embedding field
Expand All @@ -929,23 +936,35 @@ def kmeans(self, K: int, iterations: int = 20, init_mode: str = "random", epsilo
iterations: the max iteration runs before converge
init_mode: how the initial centroids are select
epsilon: the condition for determining if it's converged
field_name: the field used to run kmeans, use default embedding if it's None
Return a list of max K centroids
"""

def filter_meta(row):
doc = Document.from_row(row)
return not isinstance(doc, MetadataDocument)

def init_embedding(row):
doc = Document.from_row(row)
return {"vector": doc.embedding, "cluster": -1}
return (
{"vector": doc.embedding, "cluster": -1}
if field_name is None
else {"vector": doc[field_name], "cluster": -1}
)

embeddings = self.plan.execute().map(init_embedding).materialize()
embeddings = self.plan.execute().filter(filter_meta).map(init_embedding)

initial_centroids = KMeans.init(embeddings, K, init_mode)
centroids = KMeans.update(embeddings, initial_centroids, iterations, epsilon)
return centroids

def clustering(self, centroids, cluster_field_name, **resource_args) -> "DocSet":
def clustering(self, centroids, cluster_field_name, field_name=None, **resource_args) -> "DocSet":
# TODO, need to add field for do the clustering
def cluster(doc: Document) -> Document:
idx = KMeans.closest(doc.embedding, centroids)
doc[cluster_field_name] = idx
if not isinstance(doc, MetadataDocument):
embedding = doc[field_name] if field_name else doc.embedding
idx = KMeans.closest(embedding, centroids)
doc[cluster_field_name] = idx
return doc

from sycamore.transforms import Map
Expand Down
9 changes: 7 additions & 2 deletions lib/sycamore/sycamore/grouped_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ray.data.aggregate import AggregateFn

from sycamore import DocSet
from sycamore.data import Document
from sycamore.data import Document, MetadataDocument


class GroupedData:
Expand All @@ -14,7 +14,12 @@ def __init__(self, docset: DocSet, key):

def aggregate(self, f: "AggregateFn") -> DocSet:
dataset = self._docset.plan.execute()
grouped = dataset.map(Document.from_row).groupby(self._key)

def filter_meta(row):
doc = Document.from_row(row)
return not isinstance(doc, MetadataDocument)

grouped = dataset.filter(filter_meta).map(Document.from_row).groupby(self._key)
aggregated = grouped.aggregate(f)

def to_doc(row: dict):
Expand Down
17 changes: 10 additions & 7 deletions lib/sycamore/sycamore/llms/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum
from typing import Any, Optional, Union
import os
import io

from sycamore.llms.llms import LLM
from sycamore.llms.prompts.prompts import RenderedPrompt
Expand All @@ -22,10 +23,10 @@ class GeminiModels(Enum):
"""Represents available Gemini models. More info: https://googleapis.github.io/python-genai/"""

# Note that the models available on a given Gemini account may vary.
GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash-exp", is_chat=True)
GEMINI_2_FLASH = GeminiModel(name="gemini-2.0-flash", is_chat=True)
GEMINI_2_FLASH_LITE = GeminiModel(name="gemini-2.0-flash-lite-preview-02-05", is_chat=True)
GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp", is_chat=True)
GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp", is_chat=True)
GEMINI_2_FLASH_THINKING = GeminiModel(name="gemini-2.0-flash-thinking-exp-01-21", is_chat=True)
GEMINI_2_PRO = GeminiModel(name="gemini-2.0-pro-exp-02-05", is_chat=True)

@classmethod
def from_name(cls, name: str):
Expand Down Expand Up @@ -86,7 +87,7 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
if prompt.response_format:
config["response_mime_type"] = "application/json"
config["response_schema"] = prompt.response_format
content_list = []
content_list: list[types.Content] = []
for message in prompt.messages:
if message.role == "system":
config["system_message"] = message.content
Expand All @@ -95,13 +96,15 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
content = types.Content(parts=[types.Part.from_text(text=message.content)], role=role)
if message.images:
for image in message.images:
image_bytes = image.convert("RGB").tobytes()
content.parts.append(types.Part.from_bytes(image_bytes, media_type="image/png"))
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
content.parts.append(types.Part.from_bytes(data=image_bytes, mime_type="image/png"))
content_list.append(content)
kwargs["config"] = None
if config:
kwargs["config"] = types.GenerateContentConfig(**config)
kwargs["content"] = content
kwargs["content"] = content_list
return kwargs

def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
Expand Down
18 changes: 14 additions & 4 deletions lib/sycamore/sycamore/query/execution/sycamore_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sycamore.query.execution.physical_operator import PhysicalOperator
from sycamore.query.operators.math import Math
from sycamore.query.operators.sort import Sort
from sycamore.query.operators.top_k import TopK
from sycamore.query.operators.top_k import TopK, GroupByCount
from sycamore.query.operators.field_in import FieldIn
from sycamore.query.execution.physical_operator import MathOperator
from sycamore.query.execution.sycamore_operator import (
Expand All @@ -35,6 +35,7 @@
SycamoreLimit,
SycamoreFieldIn,
SycamoreQueryVectorDatabase,
SycamoreGroupByCount,
)

log = structlog.get_logger(__name__)
Expand Down Expand Up @@ -184,6 +185,14 @@ def process_node(
inputs=inputs,
trace_dir=self.trace_dir,
)
elif isinstance(logical_node, GroupByCount):
operation = SycamoreGroupByCount(
context=self.context,
logical_node=logical_node,
query_id=query_id,
inputs=inputs,
trace_dir=self.trace_dir,
)
elif isinstance(logical_node, FieldIn):
operation = SycamoreFieldIn(
context=self.context,
Expand All @@ -206,9 +215,10 @@ def process_node(
else:
raise ValueError(f"Unsupported node type: {str(logical_node)}")

code, imports = operation.script(output_var=(self.OUTPUT_VAR_NAME if is_result_node else None))
self.imports += imports
self.node_id_to_code[logical_node.node_id] = code
if self.codegen_mode:
code, imports = operation.script(output_var=(self.OUTPUT_VAR_NAME if is_result_node else None))
self.imports += imports
self.node_id_to_code[logical_node.node_id] = code
self.node_id_to_node[logical_node.node_id] = logical_node

operation_result = "visited"
Expand Down
49 changes: 48 additions & 1 deletion lib/sycamore/sycamore/query/execution/sycamore_operator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from abc import abstractmethod
from typing import Any, Optional, List, Dict, Tuple

Expand All @@ -17,7 +18,7 @@
from sycamore.query.operators.llm_filter import LlmFilter
from sycamore.query.operators.summarize_data import SummarizeData
from sycamore.query.operators.query_database import QueryDatabase, QueryVectorDatabase
from sycamore.query.operators.top_k import TopK
from sycamore.query.operators.top_k import TopK, GroupByCount
from sycamore.query.operators.field_in import FieldIn
from sycamore.query.operators.sort import Sort

Expand Down Expand Up @@ -613,6 +614,52 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No
return result, []


class SycamoreGroupByCount(SycamoreOperator):
"""
Note: top_k clustering only operators on properties, it will not cluster on text_representation currently.
Return the Top-K values from a DocSet
"""

def __init__(
self,
context: Context,
logical_node: GroupByCount,
query_id: str,
inputs: Optional[List[Any]] = None,
trace_dir: Optional[str] = None,
) -> None:
super().__init__(context, logical_node, query_id, inputs, trace_dir=trace_dir)

def execute(self) -> Any:
assert self.inputs and len(self.inputs) == 1, "GroupByCount requires 1 input node"
assert isinstance(self.inputs[0], DocSet), "GroupByCount requires a DocSet input"
# load into local vars for Ray serialization magic
logical_node = self.logical_node
assert isinstance(logical_node, GroupByCount)

entity_name = logical_node.entity_name
embed_name = logical_node.embed_name

embedder = get_val_from_context(context=self.context, val_key="text_embedder", param_names=["opensearch"])
embedder = copy.copy(embedder)
assert embedder and isinstance(embedder, Embedder), "GroupByCount requires an Embedder in the context"
embedder.embed_name = (entity_name, embed_name)

cluster_field_name = logical_node.cluster_field_name
descending = logical_node.descending
K = logical_node.K

docset = self.inputs[0].embed(embedder)
centroids = docset.kmeans(K=K * 2, field_name=embed_name)
clustered = docset.clustering(centroids=centroids, cluster_field_name=cluster_field_name, field_name=embed_name)
result = clustered.groupby(cluster_field_name).count().sort(descending, "properties.count", 0).limit(K)

return result

def script(self, input_var: Optional[str] = None, output_var: Optional[str] = None) -> Tuple[str, List[str]]:
raise Exception("GroupByCount not implemented for codegen")


class SycamoreFieldIn(SycamoreOperator):
"""
Return 2 DocSets joined
Expand Down
21 changes: 21 additions & 0 deletions lib/sycamore/sycamore/query/operators/top_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,24 @@ class TopK(Node):
"""An instruction of what the groups should be about if llm_cluster is True. E.g. if the
purpose of this operation is to find the top 2 most frequent cities, llm_cluster_instruction
could be 'Form groups of different food'"""


class GroupByCount(Node):
"""Finds the top K frequent occurences of values for a particular field.
Returns a database with ONLY 2 FIELDS: "properties.key" (which corresponds to unique values of
*field*) and "properties.count" (which contains the counts corresponding to unique values
of *field*).
"""

entity_name: Optional[str] = None
embed_name: Optional[str] = None
"""The database field to find the top K occurences for."""

cluster_field_name: str = "centroids"

K: int = 5

descending: bool = False
"""If True, will return the top K most common occurrences. If False, will return the top K
least common occurrences."""
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import pytest

from sycamore import DocSet
from sycamore.query.operators.top_k import GroupByCount
from sycamore.query.strategy import QueryPlanStrategy

from sycamore.query.client import SycamoreQueryClient
from sycamore.query.operators.query_database import QueryVectorDatabase
from sycamore.query.operators.query_database import QueryVectorDatabase, QueryDatabase


class TestSycamoreQuery:
Expand Down Expand Up @@ -70,8 +71,8 @@ def test_dry_run(self, query_integration_test_index: str, dry_run: bool):
assert len(result.result) > 0
assert ray.is_initialized()

@pytest.mark.parametrize("codegen_mode", [True, False])
def test_vector_search(self, query_integration_test_index: str, codegen_mode: bool):
@pytest.mark.parametrize("codegen_mode", [False])
def test_vector_search(self, query_integration_test_index, codegen_mode: bool):
""" """

client = SycamoreQueryClient(query_plan_strategy=QueryPlanStrategy())
Expand All @@ -88,3 +89,31 @@ def test_vector_search(self, query_integration_test_index: str, codegen_mode: bo
assert isinstance(result.result, DocSet)
docs = result.result.take_all()
assert len(docs) > 0

def test_vector_search_2(self, query_integration_test_index: str):

client = SycamoreQueryClient(query_plan_strategy=QueryPlanStrategy())
schema = client.get_opensearch_schema("ntsb-accident-cause")
plan = client.generate_plan(
"What was the most common cause of accidents in the NTSB incident reports?",
"ntsb-accident-cause",
schema,
natural_language_response=False,
)
assert len(plan.nodes) == 2
assert isinstance(plan.nodes[0], QueryDatabase)
plan.nodes[1] = GroupByCount(
node_type="GroupByCount",
node_id=1,
description="Find the most common cause of accidents",
inputs=[0],
entity_name="properties.cause",
embed_name="cause",
cluster_field_name="centroids",
K=5,
descending=True,
)
result = client.run_plan(plan, codegen_mode=False)
assert isinstance(result.result, DocSet)
docs = result.result.take_all()
assert len(docs) > 0
5 changes: 3 additions & 2 deletions lib/sycamore/sycamore/transforms/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,16 @@ def converged(last_ones, next_ones, epsilon):
@staticmethod
def random_init(embeddings, K):
count = embeddings.count()
assert count > 0 and K < count
K = K if count > K else count
assert count > 0
fraction = min(2 * K / count, 1.0)

candidates = [list(c["vector"]) for c in embeddings.random_sample(fraction).take()]
candidates.sort()
from itertools import groupby

uniques = [key for key, _ in groupby(candidates)]
assert len(uniques) >= K
# assert len(uniques) >= K

centroids = random.sample(uniques, K)
return centroids
Expand Down
Loading

0 comments on commit 4e7b101

Please sign in to comment.