Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

node label filters #265

Merged
merged 7 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,10 @@ async def add_episode_endpoint(episode_data: EpisodeData):
# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
*[
get_relevant_nodes(self.driver, SearchFilters(), [node])
for node in extracted_nodes
]
)
)

Expand Down Expand Up @@ -732,8 +735,8 @@ async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_no
self.llm_client,
[source_node, target_node],
[
await get_relevant_nodes(self.driver, [source_node]),
await get_relevant_nodes(self.driver, [target_node]),
await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
],
)

Expand Down
10 changes: 7 additions & 3 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
query_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
Expand Down Expand Up @@ -233,6 +234,7 @@
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
Expand All @@ -243,11 +245,13 @@
search_results: list[list[EntityNode]] = list(
await semaphore_gather(
*[
node_fulltext_search(driver, query, group_ids, 2 * limit),
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
node_similarity_search(
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
),
node_bfs_search(
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
),
node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
]
)
)
Expand All @@ -255,7 +259,7 @@
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append(
await node_bfs_search(driver, origin_node_uuids, config.bfs_max_depth, 2 * limit)

Check failure on line 262 in graphiti_core/search/search.py

View workflow job for this annotation

GitHub Actions / mypy

call-arg

Missing positional argument "limit" in call to "node_bfs_search"

Check failure on line 262 in graphiti_core/search/search.py

View workflow job for this annotation

GitHub Actions / mypy

arg-type

Argument 3 to "node_bfs_search" has incompatible type "int"; expected "SearchFilters"
)

search_result_uuids = [[node.uuid for node in result] for result in search_results]
Expand Down
21 changes: 20 additions & 1 deletion graphiti_core/search/search_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,32 @@ class DateFilter(BaseModel):


class SearchFilters(BaseModel):
node_labels: list[str] | None = Field(
default=None, description='List of node labels to filter on'
)
valid_at: list[list[DateFilter]] | None = Field(default=None)
invalid_at: list[list[DateFilter]] | None = Field(default=None)
created_at: list[list[DateFilter]] | None = Field(default=None)
expired_at: list[list[DateFilter]] | None = Field(default=None)


def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
def node_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = ''
filter_params: dict[str, Any] = {}

if filters.node_labels is not None:
node_labels = ':'.join(filters.node_labels)
node_label_filter = 'AND n:' + node_labels
filter_query += node_label_filter

return filter_query, filter_params


def edge_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = ''
filter_params: dict[str, Any] = {}

Expand Down
75 changes: 55 additions & 20 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
get_community_node_from_record,
get_entity_node_from_record,
)
from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
from graphiti_core.search.search_filters import (
SearchFilters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -148,7 +152,7 @@ async def edge_fulltext_search(
if fuzzy_query == '':
return []

filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

cypher_query = Query(
"""
Expand Down Expand Up @@ -207,7 +211,7 @@ async def edge_similarity_search(

query_params: dict[str, Any] = {}

filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query_params.update(filter_params)

group_filter_query: LiteralString = ''
Expand All @@ -225,8 +229,8 @@ async def edge_similarity_search(

query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
Expand Down Expand Up @@ -278,7 +282,7 @@ async def edge_bfs_search(
if bfs_origin_node_uuids is None:
return []

filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

query = Query(
"""
Expand Down Expand Up @@ -325,6 +329,7 @@ async def edge_bfs_search(
async def node_fulltext_search(
driver: AsyncDriver,
query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
Expand All @@ -333,10 +338,17 @@ async def node_fulltext_search(
if fuzzy_query == '':
return []

filter_query, filter_params = node_search_filter_query_constructor(search_filter)

records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
YIELD node AS node, score
MATCH (n:Entity)
WHERE n.uuid = node.uuid
"""
+ filter_query
+ """
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
Expand All @@ -349,6 +361,7 @@ async def node_fulltext_search(
ORDER BY score DESC
LIMIT $limit
""",
filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
Expand All @@ -363,6 +376,7 @@ async def node_fulltext_search(
async def node_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
Expand All @@ -379,12 +393,16 @@ async def node_similarity_search(
group_filter_query += 'WHERE n.group_id IN $group_ids'
query_params['group_ids'] = group_ids

filter_query, filter_params = node_search_filter_query_constructor(search_filter)
query_params.update(filter_params)

records, _, _ = await driver.execute_query(
runtime_query
+ """
MATCH (n:Entity)
"""
+ group_filter_query
+ filter_query
+ """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
Expand Down Expand Up @@ -416,28 +434,36 @@ async def node_similarity_search(
async def node_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
search_filter: SearchFilters,
bfs_max_depth: int,
limit: int,
) -> list[EntityNode]:
# vector similarity search over entity names
if bfs_origin_node_uuids is None:
return []

filter_query, filter_params = node_search_filter_query_constructor(search_filter)

records, _, _ = await driver.execute_query(
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
WHERE n.group_id = origin.group_id
"""
+ filter_query
+ """
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
Expand Down Expand Up @@ -539,6 +565,7 @@ async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
Expand Down Expand Up @@ -583,8 +610,14 @@ async def hybrid_node_search(
start = time()
results: list[list[EntityNode]] = list(
await semaphore_gather(
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
*[
node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
for q in queries
],
*[
node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
for e in embeddings
],
)
)

Expand All @@ -604,6 +637,7 @@ async def hybrid_node_search(

async def get_relevant_nodes(
driver: AsyncDriver,
search_filter: SearchFilters,
nodes: list[EntityNode],
) -> list[EntityNode]:
"""
Expand Down Expand Up @@ -635,6 +669,7 @@ async def get_relevant_nodes(
[node.name for node in nodes],
[node.name_embedding for node in nodes if node.name_embedding is not None],
driver,
search_filter,
[node.group_id for node in nodes],
)

Expand Down
Loading