diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py
index df4e65c7..83688392 100644
--- a/graphiti_core/graphiti.py
+++ b/graphiti_core/graphiti.py
@@ -35,6 +35,7 @@
     EDGE_HYBRID_SEARCH_NODE_DISTANCE,
     EDGE_HYBRID_SEARCH_RRF,
 )
+from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import (
     RELEVANT_SCHEMA_LIMIT,
     get_communities_by_nodes,
@@ -625,6 +626,7 @@ async def search(
         center_node_uuid: str | None = None,
         group_ids: list[str] | None = None,
         num_results=DEFAULT_SEARCH_LIMIT,
+        search_filter: SearchFilters | None = None,
     ) -> list[EntityEdge]:
         """
         Perform a hybrid search on the knowledge graph.
@@ -670,6 +672,7 @@ async def search(
                 query,
                 group_ids,
                 search_config,
+                search_filter if search_filter is not None else SearchFilters(),
                 center_node_uuid,
             )
         ).edges
@@ -683,6 +686,7 @@ async def _search(
         group_ids: list[str] | None = None,
         center_node_uuid: str | None = None,
         bfs_origin_node_uuids: list[str] | None = None,
+        search_filter: SearchFilters | None = None,
     ) -> SearchResults:
         return await search(
             self.driver,
@@ -691,6 +695,7 @@ async def _search(
             query,
             group_ids,
             config,
+            search_filter if search_filter is not None else SearchFilters(),
             center_node_uuid,
             bfs_origin_node_uuids,
         )
diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py
index c1b134c3..71076b9f 100644
--- a/graphiti_core/search/search.py
+++ b/graphiti_core/search/search.py
@@ -39,6 +39,7 @@
     SearchConfig,
     SearchResults,
 )
+from graphiti_core.search.search_filters import SearchFilters
 from graphiti_core.search.search_utils import (
     community_fulltext_search,
     community_similarity_search,
@@ -64,6 +65,7 @@ async def search(
     query: str,
     group_ids: list[str] | None,
     config: SearchConfig,
+    search_filter: SearchFilters,
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
 ) -> SearchResults:
@@ -86,6 +88,7 @@ async def search(
                 query_vector,
                 group_ids,
                 config.edge_config,
+                search_filter,
                 center_node_uuid,
                 bfs_origin_node_uuids,
                 config.limit,
@@ -133,6 +136,7 @@ async def edge_search(
     query_vector: list[float],
     group_ids: list[str] | None,
     config: EdgeSearchConfig | None,
+    search_filter: SearchFilters,
     center_node_uuid: str | None = None,
     bfs_origin_node_uuids: list[str] | None = None,
     limit=DEFAULT_SEARCH_LIMIT,
@@ -143,11 +147,20 @@ async def edge_search(
     search_results: list[list[EntityEdge]] = list(
         await semaphore_gather(
             *[
-                edge_fulltext_search(driver, query, group_ids, 2 * limit),
+                edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
                 edge_similarity_search(
-                    driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
+                    driver,
+                    query_vector,
+                    None,
+                    None,
+                    search_filter,
+                    group_ids,
+                    2 * limit,
+                    config.sim_min_score,
+                ),
+                edge_bfs_search(
+                    driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
                 ),
-                edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
             ]
         )
     )
@@ -155,7 +168,9 @@ async def edge_search(
     if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
         source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
         search_results.append(
-            await edge_bfs_search(driver, source_node_uuids, config.bfs_max_depth, 2 * limit)
+            await edge_bfs_search(
+                driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
+            )
         )
 
     edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
diff --git a/graphiti_core/search/search_filters.py b/graphiti_core/search/search_filters.py
new file mode 100644
index 00000000..95710805
--- /dev/null
+++ b/graphiti_core/search/search_filters.py
@@ -0,0 +1,93 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from datetime import datetime
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel, Field
+from typing_extensions import LiteralString
+
+
+# Comparison operators (Cypher syntax) that a DateFilter may apply.
+class ComparisonOperator(Enum):
+    equals = '='
+    not_equals = '<>'
+    greater_than = '>'
+    less_than = '<'
+    greater_than_equal = '>='
+    less_than_equal = '<='
+
+
+# One comparison of an edge datetime property against a fixed datetime.
+class DateFilter(BaseModel):
+    date: datetime = Field(description='A datetime to filter on')
+    comparison_operator: ComparisonOperator = Field(
+        description='Comparison operator for date filter'
+    )
+
+
+# Date constraints for edge search. Each field holds a list of groups: filters
+# inside one group are ANDed together and the groups are ORed. A field left as
+# None (or empty) adds no constraint.
+class SearchFilters(BaseModel):
+    valid_at: list[list[DateFilter]] | None = Field(default=None)
+    invalid_at: list[list[DateFilter]] | None = Field(default=None)
+    created_at: list[list[DateFilter]] | None = Field(default=None)
+    expired_at: list[list[DateFilter]] | None = Field(default=None)
+
+
+# Edge properties that SearchFilters can constrain, in the order they are emitted.
+_DATE_FILTER_FIELDS = ('valid_at', 'invalid_at', 'created_at', 'expired_at')
+
+
+def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
+    """Build the Cypher predicate fragment and query parameters for `filters`.
+
+    Returns a `(filter_query, filter_params)` tuple. `filter_query` is empty when
+    nothing is constrained; otherwise it holds one `AND (...)` clause per
+    constrained field, each starting on its own line, so it has to be appended
+    right after an existing WHERE predicate of a query that binds the edge as `r`.
+    `filter_params` maps every `$<field>_<i>_<j>` parameter used in the fragment
+    to its datetime.
+    """
+    filter_clauses: list[str] = []
+    filter_params: dict[str, Any] = {}
+
+    for field_name in _DATE_FILTER_FIELDS:
+        or_groups: list[list[DateFilter]] | None = getattr(filters, field_name)
+        if not or_groups:
+            continue
+
+        or_clauses: list[str] = []
+        for i, and_group in enumerate(or_groups):
+            and_clauses: list[str] = []
+            for j, date_filter in enumerate(and_group):
+                # Both indices go into the name so that filters from different
+                # groups never overwrite each other's parameter.
+                param_name = f'{field_name}_{i}_{j}'
+                filter_params[param_name] = date_filter.date
+                comparison = date_filter.comparison_operator.value
+                and_clauses.append(f'(r.{field_name} {comparison} ${param_name})')
+
+            # An empty group contributes nothing; `()` would be invalid Cypher.
+            if and_clauses:
+                or_clauses.append('(' + ' AND '.join(and_clauses) + ')')
+
+        if or_clauses:
+            filter_clauses.append('\nAND (' + ' OR '.join(or_clauses) + ')')
+
+    return ''.join(filter_clauses), filter_params
diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py
index 42c52ab7..c4d44fd1 100644
--- a/graphiti_core/search/search_utils.py
+++ b/graphiti_core/search/search_utils.py
@@ -38,6 +38,7 @@
     get_community_node_from_record,
     get_entity_node_from_record,
 )
+from graphiti_core.search.search_filters import SearchFilters, search_filter_query_constructor
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +137,7 @@ async def get_communities_by_nodes(
 async def edge_fulltext_search(
     driver: AsyncDriver,
     query: str,
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
@@ -144,28 +146,36 @@ async def edge_fulltext_search(
     if fuzzy_query == '':
         return []
 
-    cypher_query = Query("""
+    filter_query, filter_params = search_filter_query_constructor(search_filter)
+
+    cypher_query = Query(
+        """
               CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
-              YIELD relationship AS r, score
-              WITH r, score, startNode(r) AS n, endNode(r) AS m
-              RETURN
-                    r.uuid AS uuid,
-                    r.group_id AS group_id,
-                    n.uuid AS source_node_uuid,
-                    m.uuid AS target_node_uuid,
-                    r.created_at AS created_at,
-                    r.name AS name,
-                    r.fact AS fact,
-                    r.fact_embedding AS fact_embedding,
-                    r.episodes AS episodes,
-                    r.expired_at AS expired_at,
-                    r.valid_at AS valid_at,
-                    r.invalid_at AS invalid_at
-              ORDER BY score DESC LIMIT $limit
-                """)
+              YIELD relationship AS r, score
+              WITH r, score, startNode(r) AS n, endNode(r) AS m
+              WHERE ($group_ids IS NULL OR r.group_id IN $group_ids)"""
+        + filter_query
+        + """
+              RETURN
+                    r.uuid AS uuid,
+                    r.group_id AS group_id,
+                    n.uuid AS source_node_uuid,
+                    m.uuid AS target_node_uuid,
+                    r.created_at AS created_at,
+                    r.name AS name,
+                    r.fact AS fact,
+                    r.fact_embedding AS fact_embedding,
+                    r.episodes AS episodes,
+                    r.expired_at AS expired_at,
+                    r.valid_at AS valid_at,
+                    r.invalid_at AS invalid_at
+              ORDER BY score DESC LIMIT $limit
+        """
+    )
 
     records, _, _ = await driver.execute_query(
         cypher_query,
+        filter_params,
         query=fuzzy_query,
         group_ids=group_ids,
         limit=limit,
@@ -183,6 +193,7 @@ async def edge_similarity_search(
     search_vector: list[float],
     source_node_uuid: str | None,
     target_node_uuid: str | None,
+    search_filter: SearchFilters,
     group_ids: list[str] | None = None,
     limit: int = RELEVANT_SCHEMA_LIMIT,
     min_score: float = DEFAULT_MIN_SCORE,
@@ -194,6 +205,9 @@ async def edge_similarity_search(
 
     query_params: dict[str, Any] = {}
 
+    filter_query, filter_params = search_filter_query_constructor(search_filter)
+    query_params.update(filter_params)
+
     group_filter_query: LiteralString = ''
     if group_ids is not None:
         group_filter_query += 'WHERE r.group_id IN $group_ids'
@@ -209,9 +223,12 @@ async def edge_similarity_search(
 
     query: LiteralString = (
         """
-        MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
-        """
+                MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
+                """
         + group_filter_query
+        # filter_query starts with AND: open a WHERE clause if the group filter did not.
+        + ('' if group_filter_query or not filter_query else 'WHERE true')
+        + filter_query
         + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
                 WHERE score > $min_score
                 RETURN
@@ -254,17 +271,25 @@ async def edge_bfs_search(
     driver: AsyncDriver,
     bfs_origin_node_uuids: list[str] | None,
     bfs_max_depth: int,
+    search_filter: SearchFilters,
     limit: int,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
     if bfs_origin_node_uuids is None:
         return []
 
-    query = Query("""
+    filter_query, filter_params = search_filter_query_constructor(search_filter)
+
+    query = Query(
+        """
         UNWIND $bfs_origin_node_uuids AS origin_uuid
         MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
         UNWIND relationships(path) AS rel
-        MATCH ()-[r:RELATES_TO {uuid: rel.uuid}]-()
+        MATCH ()-[r:RELATES_TO]-()
+        WHERE r.uuid = rel.uuid
+        """
+        + filter_query
+        + """
         RETURN DISTINCT
                 r.uuid AS uuid,
                 r.group_id AS group_id,
@@ -279,10 +304,12 @@ async def edge_bfs_search(
                 r.valid_at AS valid_at,
                 r.invalid_at AS invalid_at
         LIMIT $limit
-        """)
+        """
+    )
 
     records, _, _ = await driver.execute_query(
         query,
+        filter_params,
         bfs_origin_node_uuids=bfs_origin_node_uuids,
         depth=bfs_max_depth,
         limit=limit,
@@ -626,6 +653,7 @@ async def get_relevant_edges(
             edge.fact_embedding,
             source_node_uuid,
             target_node_uuid,
+            SearchFilters(),
             [edge.group_id],
             limit,
         )