diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index df12d98d..ffed107d 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -351,7 +351,10 @@ async def add_episode_endpoint(episode_data: EpisodeData): # Find relevant nodes already in the graph existing_nodes_lists: list[list[EntityNode]] = list( await semaphore_gather( - *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes] + *[ + get_relevant_nodes(self.driver, SearchFilters(), [node]) + for node in extracted_nodes + ] ) ) @@ -732,8 +735,8 @@ async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_no self.llm_client, [source_node, target_node], [ - await get_relevant_nodes(self.driver, [source_node]), - await get_relevant_nodes(self.driver, [target_node]), + await get_relevant_nodes(self.driver, SearchFilters(), [source_node]), + await get_relevant_nodes(self.driver, SearchFilters(), [target_node]), ], ) diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 71076b9f..4eb9a53b 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -100,6 +100,7 @@ async def search( query_vector, group_ids, config.node_config, + search_filter, center_node_uuid, bfs_origin_node_uuids, config.limit, @@ -233,6 +234,7 @@ async def node_search( query_vector: list[float], group_ids: list[str] | None, config: NodeSearchConfig | None, + search_filter: SearchFilters, center_node_uuid: str | None = None, bfs_origin_node_uuids: list[str] | None = None, limit=DEFAULT_SEARCH_LIMIT, @@ -243,11 +245,13 @@ async def node_search( search_results: list[list[EntityNode]] = list( await semaphore_gather( *[ - node_fulltext_search(driver, query, group_ids, 2 * limit), + node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit), node_similarity_search( - driver, query_vector, group_ids, 2 * limit, config.sim_min_score + driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score + ), + node_bfs_search( + driver, bfs_origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit ), - node_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit), ] ) ) @@ -255,7 +259,9 @@ async def node_search( if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None: origin_node_uuids = [node.uuid for result in search_results for node in result] search_results.append( - await node_bfs_search(driver, origin_node_uuids, config.bfs_max_depth, 2 * limit) + await node_bfs_search( + driver, origin_node_uuids, search_filter, config.bfs_max_depth, 2 * limit + ) ) search_result_uuids = [[node.uuid for node in result] for result in search_results] diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py index 95710805..2d77e641 100644 --- a/graphiti_core/search/search_filters.py +++ b/graphiti_core/search/search_filters.py @@ -39,18 +39,37 @@ class DateFilter(BaseModel): class SearchFilters(BaseModel): + node_labels: list[str] | None = Field( + default=None, description='List of node labels to filter on' + ) valid_at: list[list[DateFilter]] | None = Field(default=None) invalid_at: list[list[DateFilter]] | None = Field(default=None) created_at: list[list[DateFilter]] | None = Field(default=None) expired_at: list[list[DateFilter]] | None = Field(default=None) -def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]: +def node_search_filter_query_constructor( + filters: SearchFilters, +) -> tuple[LiteralString, dict[str, Any]]: + filter_query: LiteralString = '' + filter_params: dict[str, Any] = {} + + if filters.node_labels is not None: + node_labels = ':'.join(filters.node_labels) + node_label_filter = ' AND n:' + node_labels + filter_query += node_label_filter + + return filter_query, filter_params + + +def edge_search_filter_query_constructor( + filters: SearchFilters, +) -> tuple[LiteralString, dict[str, Any]]: filter_query: LiteralString = '' filter_params: dict[str, Any] = {} if filters.valid_at is not None: - valid_at_filter = 'AND (' + valid_at_filter = ' AND (' for i, or_list in enumerate(filters.valid_at): for j, date_filter in enumerate(or_list): filter_params['valid_at_' + str(j)] = date_filter.date @@ -75,7 +94,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri filter_query += valid_at_filter if filters.invalid_at is not None: - invalid_at_filter = 'AND (' + invalid_at_filter = ' AND (' for i, or_list in enumerate(filters.invalid_at): for j, date_filter in enumerate(or_list): filter_params['invalid_at_' + str(j)] = date_filter.date @@ -100,7 +119,7 @@ def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralStri filter_query += invalid_at_filter if filters.created_at is not None: - created_at_filter = 'AND (' + created_at_filter = ' AND (' for i, or_list in enumerate(filters.created_at): for j, date_filter in enumerate(or_list): filter_params['created_at_' + str(j)] = date_filter.date diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index ef9b6eb2..578e46f2 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -38,7 +38,11 @@ get_community_node_from_record, get_entity_node_from_record, ) -from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor +from graphiti_core.search.search_filters import ( + SearchFilters, + edge_search_filter_query_constructor, + node_search_filter_query_constructor, +) logger = logging.getLogger(__name__) @@ -148,7 +152,7 @@ async def edge_fulltext_search( if fuzzy_query == '': return [] - filter_query, filter_params = search_filter_query_constructor(search_filter) + filter_query, filter_params = edge_search_filter_query_constructor(search_filter) cypher_query = Query( """ @@ -207,7 +211,7 @@ async def edge_similarity_search( query_params: dict[str, Any] = {} - filter_query, filter_params = search_filter_query_constructor(search_filter) + filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query_params.update(filter_params) group_filter_query: LiteralString = '' @@ -225,8 +229,8 @@ async def edge_similarity_search( query: LiteralString = ( """ - MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) - """ + MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity) + """ + group_filter_query + filter_query + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score @@ -278,7 +282,7 @@ async def edge_bfs_search( if bfs_origin_node_uuids is None: return [] - filter_query, filter_params = search_filter_query_constructor(search_filter) + filter_query, filter_params = edge_search_filter_query_constructor(search_filter) query = Query( """ @@ -325,6 +329,7 @@ async def edge_bfs_search( async def node_fulltext_search( driver: AsyncDriver, query: str, + search_filter: SearchFilters, group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: @@ -333,10 +338,17 @@ async def node_fulltext_search( if fuzzy_query == '': return [] + filter_query, filter_params = node_search_filter_query_constructor(search_filter) + records, _, _ = await driver.execute_query( """ CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit}) - YIELD node AS n, score + YIELD node AS node, score + MATCH (n:Entity) + WHERE n.uuid = node.uuid + """ + + filter_query + + """ RETURN n.uuid AS uuid, n.group_id AS group_id, @@ -349,6 +361,7 @@ async def node_fulltext_search( ORDER BY score DESC LIMIT $limit """, + filter_params, query=fuzzy_query, group_ids=group_ids, limit=limit, @@ -363,6 +376,7 @@ async def node_fulltext_search( async def node_similarity_search( driver: AsyncDriver, search_vector: list[float], + search_filter: SearchFilters, group_ids: list[str] | None = None, limit=RELEVANT_SCHEMA_LIMIT, min_score: float = DEFAULT_MIN_SCORE, @@ -379,12 +393,16 @@ async def node_similarity_search( group_filter_query += 'WHERE n.group_id IN $group_ids' query_params['group_ids'] = group_ids + filter_query, filter_params = node_search_filter_query_constructor(search_filter) + query_params.update(filter_params) + records, _, _ = await driver.execute_query( runtime_query + """ MATCH (n:Entity) """ + group_filter_query + + filter_query + """ WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score WHERE score > $min_score @@ -416,6 +434,7 @@ async def node_similarity_search( async def node_bfs_search( driver: AsyncDriver, bfs_origin_node_uuids: list[str] | None, + search_filter: SearchFilters, bfs_max_depth: int, limit: int, ) -> list[EntityNode]: @@ -423,21 +442,28 @@ async def node_bfs_search( if bfs_origin_node_uuids is None: return [] + filter_query, filter_params = node_search_filter_query_constructor(search_filter) + records, _, _ = await driver.execute_query( """ UNWIND $bfs_origin_node_uuids AS origin_uuid MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity) - RETURN DISTINCT - n.uuid As uuid, - n.group_id AS group_id, - n.name AS name, - n.name_embedding AS name_embedding, - n.created_at AS created_at, - n.summary AS summary, - labels(n) AS labels, - properties(n) AS attributes - LIMIT $limit - """, + WHERE n.group_id = origin.group_id + """ + + filter_query + + """ + RETURN DISTINCT + n.uuid As uuid, + n.group_id AS group_id, + n.name AS name, + n.name_embedding AS name_embedding, + n.created_at AS created_at, + n.summary AS summary, + labels(n) AS labels, + properties(n) AS attributes + LIMIT $limit + """, + filter_params, bfs_origin_node_uuids=bfs_origin_node_uuids, depth=bfs_max_depth, limit=limit, @@ -539,6 +565,7 @@ async def hybrid_node_search( queries: list[str], embeddings: list[list[float]], driver: AsyncDriver, + search_filter: SearchFilters, group_ids: list[str] | None = None, limit: int = RELEVANT_SCHEMA_LIMIT, ) -> list[EntityNode]: @@ -583,8 +610,14 @@ async def hybrid_node_search( start = time() results: list[list[EntityNode]] = list( await semaphore_gather( - *[node_fulltext_search(driver, q, group_ids, 2 * limit) for q in queries], - *[node_similarity_search(driver, e, group_ids, 2 * limit) for e in embeddings], + *[ + node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit) + for q in queries + ], + *[ + node_similarity_search(driver, e, search_filter, group_ids, 2 * limit) + for e in embeddings + ], ) ) @@ -604,6 +637,7 @@ async def hybrid_node_search( async def get_relevant_nodes( driver: AsyncDriver, + search_filter: SearchFilters, nodes: list[EntityNode], ) -> list[EntityNode]: """ @@ -635,6 +669,7 @@ async def get_relevant_nodes( [node.name for node in nodes], [node.name_embedding for node in nodes if node.name_embedding is not None], driver, + search_filter, [node.group_id for node in nodes], ) diff --git a/graphiti_core/utils/bulk_utils.py b/graphiti_core/utils/bulk_utils.py index b2340d63..3766dba7 100644 --- a/graphiti_core/utils/bulk_utils.py +++ b/graphiti_core/utils/bulk_utils.py @@ -37,6 +37,7 @@ EPISODIC_NODE_SAVE_BULK, ) from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode +from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import get_relevant_edges, get_relevant_nodes from graphiti_core.utils.datetime_utils import utc_now from graphiti_core.utils.maintenance.edge_operations import ( @@ -188,7 +189,7 @@ async def dedupe_nodes_bulk( existing_nodes_chunks: list[list[EntityNode]] = list( await semaphore_gather( - *[get_relevant_nodes(driver, node_chunk) for node_chunk in node_chunks] + *[get_relevant_nodes(driver, SearchFilters(), node_chunk) for node_chunk in node_chunks] ) ) diff --git a/pyproject.toml b/pyproject.toml index 790a3a72..3a919422 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "graphiti-core" -version = "0.7.0" +version = "0.7.1" description = "A temporal graph building library" authors = [ "Paul Paliychuk ", diff --git a/tests/utils/search/search_utils_test.py b/tests/utils/search/search_utils_test.py index 0a260919..02b9ac7d 100644 --- a/tests/utils/search/search_utils_test.py +++ b/tests/utils/search/search_utils_test.py @@ -3,6 +3,7 @@ import pytest from graphiti_core.nodes import EntityNode +from graphiti_core.search.search_filters import SearchFilters from graphiti_core.search.search_utils import hybrid_node_search @@ -13,7 +14,7 @@ async def test_hybrid_node_search_deduplication(): # Mock the node_fulltext_search and entity_similarity_search functions with patch( - 'graphiti_core.search.search_utils.node_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: @@ -30,7 +31,7 @@ async def test_hybrid_node_search_deduplication(): # Call the function with test data queries = ['Alice', 'Bob'] embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] - results = await hybrid_node_search(queries, embeddings, mock_driver) + results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters()) # Assertions assert len(results) == 3 @@ -47,7 +48,7 @@ async def test_hybrid_node_search_empty_results(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.node_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: @@ -56,7 +57,7 @@ async def test_hybrid_node_search_empty_results(): queries = ['NonExistent'] embeddings = [[0.1, 0.2, 0.3]] - results = await hybrid_node_search(queries, embeddings, mock_driver) + results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters()) assert len(results) == 0 @@ -66,7 +67,7 @@ async def test_hybrid_node_search_only_fulltext(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.node_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: @@ -77,7 +78,7 @@ async def test_hybrid_node_search_only_fulltext(): queries = ['Alice'] embeddings = [] - results = await hybrid_node_search(queries, embeddings, mock_driver) + results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters()) assert len(results) == 1 assert results[0].name == 'Alice' @@ -90,7 +91,7 @@ async def test_hybrid_node_search_with_limit(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.node_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: @@ -111,7 +112,9 @@ async def test_hybrid_node_search_with_limit(): queries = ['Test'] embeddings = [[0.1, 0.2, 0.3]] limit = 1 - results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) + results = await hybrid_node_search( + queries, embeddings, mock_driver, SearchFilters(), ['1'], limit + ) # We expect 4 results because the limit is applied per search method # before deduplication, and we're not actually limiting the results @@ -120,8 +123,10 @@ async def test_hybrid_node_search_with_limit(): assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 # Verify that the limit was passed to the search functions - mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 2) - mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 2) + mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 2) + mock_similarity_search.assert_called_with( + mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 2 + ) @pytest.mark.asyncio @@ -129,7 +134,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): mock_driver = AsyncMock() with patch( - 'graphiti_core.search.search_utils.node_fulltext_search' + 'graphiti_core.search.search_utils.node_fulltext_search' ) as mock_fulltext_search, patch( 'graphiti_core.search.search_utils.node_similarity_search' ) as mock_similarity_search: @@ -145,7 +150,9 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): queries = ['Test'] embeddings = [[0.1, 0.2, 0.3]] limit = 2 - results = await hybrid_node_search(queries, embeddings, mock_driver, ['1'], limit) + results = await hybrid_node_search( + queries, embeddings, mock_driver, SearchFilters(), ['1'], limit + ) # We expect 3 results because: # 1. The limit of 2 is applied to each search method @@ -155,5 +162,7 @@ async def test_hybrid_node_search_with_limit_and_duplicates(): assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'} assert mock_fulltext_search.call_count == 1 assert mock_similarity_search.call_count == 1 - mock_fulltext_search.assert_called_with(mock_driver, 'Test', ['1'], 4) - mock_similarity_search.assert_called_with(mock_driver, [0.1, 0.2, 0.3], ['1'], 4) + mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 4) + mock_similarity_search.assert_called_with( + mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4 + )