Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

node label filters #265

Merged
merged 7 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,10 @@ async def add_episode_endpoint(episode_data: EpisodeData):
# Find relevant nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
*[
get_relevant_nodes(self.driver, SearchFilters(), [node])
for node in extracted_nodes
]
)
)

Expand Down Expand Up @@ -732,8 +735,8 @@ async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_no
self.llm_client,
[source_node, target_node],
[
await get_relevant_nodes(self.driver, [source_node]),
await get_relevant_nodes(self.driver, [target_node]),
await get_relevant_nodes(self.driver, SearchFilters(), [source_node]),
await get_relevant_nodes(self.driver, SearchFilters(), [target_node]),
],
)

Expand Down
14 changes: 10 additions & 4 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def search(
query_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
Expand Down Expand Up @@ -233,6 +234,7 @@ async def node_search(
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
Expand All @@ -243,19 +245,23 @@ async def node_search(
search_results: list[list[EntityNode]] = list(
await semaphore_gather(
*[
node_fulltext_search(driver, query, group_ids, 2 * limit),
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
node_similarity_search(
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
),
node_bfs_search(
driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
),
node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
]
)
)

if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append(
await node_bfs_search(driver, origin_node_uuids, config.bfs_max_depth, 2 * limit)
await node_bfs_search(
driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit
)
)

search_result_uuids = [[node.uuid for node in result] for result in search_results]
Expand Down
27 changes: 23 additions & 4 deletions graphiti_core/search/search_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,37 @@ class DateFilter(BaseModel):


class SearchFilters(BaseModel):
    # Optional Cypher label filter: when set, the generated node query chains
    # the labels with ':' (e.g. 'n:Entity:Person'), so a matched node must
    # carry every label in the list.
    node_labels: list[str] | None = Field(
        default=None, description='List of node labels to filter on'
    )
    # Date filters on edge timestamp fields. Shape is a list of lists of
    # DateFilter; the inner lists are iterated as 'or_list' by the query
    # constructor — presumably outer groups are OR-ed and inner comparisons
    # AND-ed, but confirm against edge_search_filter_query_constructor.
    valid_at: list[list[DateFilter]] | None = Field(default=None)
    invalid_at: list[list[DateFilter]] | None = Field(default=None)
    created_at: list[list[DateFilter]] | None = Field(default=None)
    expired_at: list[list[DateFilter]] | None = Field(default=None)


def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
def node_search_filter_query_constructor(
    filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
    """Build the Cypher filter fragment and parameters for node searches.

    Returns a (query_fragment, params) pair. The fragment is empty when no
    node filters are set; otherwise it appends an ' AND n:Label1:Label2'
    clause so matched nodes must carry every requested label. The params
    dict is currently always empty (labels are inlined, not parameterized).
    """
    filter_params: dict[str, Any] = {}
    filter_query: LiteralString = ''

    if filters.node_labels is not None:
        # Chaining labels with ':' makes the Cypher match require all of them.
        filter_query += ' AND n:' + ':'.join(filters.node_labels)

    return filter_query, filter_params


def edge_search_filter_query_constructor(
filters: SearchFilters,
) -> tuple[LiteralString, dict[str, Any]]:
filter_query: LiteralString = ''
filter_params: dict[str, Any] = {}

if filters.valid_at is not None:
valid_at_filter = 'AND ('
valid_at_filter = ' AND ('
for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list):
filter_params['valid_at_' + str(j)] = date_filter.date
Expand All @@ -75,7 +94,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
filter_query += valid_at_filter

if filters.invalid_at is not None:
invalid_at_filter = 'AND ('
invalid_at_filter = ' AND ('
for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list):
filter_params['invalid_at_' + str(j)] = date_filter.date
Expand All @@ -100,7 +119,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri
filter_query += invalid_at_filter

if filters.created_at is not None:
created_at_filter = 'AND ('
created_at_filter = ' AND ('
for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list):
filter_params['created_at_' + str(j)] = date_filter.date
Expand Down
75 changes: 55 additions & 20 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
get_community_node_from_record,
get_entity_node_from_record,
)
from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
from graphiti_core.search.search_filters import (
SearchFilters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -148,7 +152,7 @@ async def edge_fulltext_search(
if fuzzy_query == '':
return []

filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

cypher_query = Query(
"""
Expand Down Expand Up @@ -207,7 +211,7 @@ async def edge_similarity_search(

query_params: dict[str, Any] = {}

filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
query_params.update(filter_params)

group_filter_query: LiteralString = ''
Expand All @@ -225,8 +229,8 @@ async def edge_similarity_search(

query: LiteralString = (
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
"""
+ group_filter_query
+ filter_query
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
Expand Down Expand Up @@ -278,7 +282,7 @@ async def edge_bfs_search(
if bfs_origin_node_uuids is None:
return []

filter_query, filter_params = search_filter_query_constructor(search_filter)
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)

query = Query(
"""
Expand Down Expand Up @@ -325,6 +329,7 @@ async def edge_bfs_search(
async def node_fulltext_search(
driver: AsyncDriver,
query: str,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
Expand All @@ -333,10 +338,17 @@ async def node_fulltext_search(
if fuzzy_query == '':
return []

filter_query, filter_params = node_search_filter_query_constructor(search_filter)

records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
YIELD node AS n, score
YIELD node AS node, score
MATCH (n:Entity)
WHERE n.uuid = node.uuid
"""
+ filter_query
+ """
RETURN
n.uuid AS uuid,
n.group_id AS group_id,
Expand All @@ -349,6 +361,7 @@ async def node_fulltext_search(
ORDER BY score DESC
LIMIT $limit
""",
filter_params,
query=fuzzy_query,
group_ids=group_ids,
limit=limit,
Expand All @@ -363,6 +376,7 @@ async def node_fulltext_search(
async def node_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit=RELEVANT_SCHEMA_LIMIT,
min_score: float = DEFAULT_MIN_SCORE,
Expand All @@ -379,12 +393,16 @@ async def node_similarity_search(
group_filter_query += 'WHERE n.group_id IN $group_ids'
query_params['group_ids'] = group_ids

filter_query, filter_params = node_search_filter_query_constructor(search_filter)
query_params.update(filter_params)

records, _, _ = await driver.execute_query(
runtime_query
+ """
MATCH (n:Entity)
"""
+ group_filter_query
+ filter_query
+ """
WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
WHERE score > $min_score
Expand Down Expand Up @@ -416,28 +434,36 @@ async def node_similarity_search(
async def node_bfs_search(
driver: AsyncDriver,
bfs_origin_node_uuids: list[str] | None,
search_filter: SearchFilters,
bfs_max_depth: int,
limit: int,
) -> list[EntityNode]:
# vector similarity search over entity names
if bfs_origin_node_uuids is None:
return []

filter_query, filter_params = node_search_filter_query_constructor(search_filter)

records, _, _ = await driver.execute_query(
"""
UNWIND $bfs_origin_node_uuids AS origin_uuid
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
WHERE n.group_id = origin.group_id
"""
+ filter_query
+ """
RETURN DISTINCT
n.uuid As uuid,
n.group_id AS group_id,
n.name AS name,
n.name_embedding AS name_embedding,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
LIMIT $limit
""",
filter_params,
bfs_origin_node_uuids=bfs_origin_node_uuids,
depth=bfs_max_depth,
limit=limit,
Expand Down Expand Up @@ -539,6 +565,7 @@ async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
search_filter: SearchFilters,
group_ids: list[str] | None = None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
Expand Down Expand Up @@ -583,8 +610,14 @@ async def hybrid_node_search(
start = time()
results: list[list[EntityNode]] = list(
await semaphore_gather(
*[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries],
*[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings],
*[
node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
for q in queries
],
*[
node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
for e in embeddings
],
)
)

Expand All @@ -604,6 +637,7 @@ async def hybrid_node_search(

async def get_relevant_nodes(
driver: AsyncDriver,
search_filter: SearchFilters,
nodes: list[EntityNode],
) -> list[EntityNode]:
"""
Expand Down Expand Up @@ -635,6 +669,7 @@ async def get_relevant_nodes(
[node.name for node in nodes],
[node.name_embedding for node in nodes if node.name_embedding is not None],
driver,
search_filter,
[node.group_id for node in nodes],
)

Expand Down
3 changes: 2 additions & 1 deletion graphiti_core/utils/bulk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
EPISODIC_NODE_SAVE_BULK,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.edge_operations import (
Expand Down Expand Up @@ -188,7 +189,7 @@ async def dedupe_nodes_bulk(

existing_nodes_chunks: list[list[EntityNode]] = list(
await semaphore_gather(
*[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks]
*[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks]
)
)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "graphiti-core"
version = "0.7.0"
version = "0.7.1"
description = "A temporal graph building library"
authors = [
"Paul Paliychuk <[email protected]>",
Expand Down
Loading