Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Date filters #240

Merged
merged 9 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
EDGE_HYBRID_SEARCH_NODE_DISTANCE,
EDGE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
RELEVANT_SCHEMA_LIMIT,
get_communities_by_nodes,
Expand Down Expand Up @@ -625,6 +626,7 @@ async def search(
center_node_uuid: str | None = None,
group_ids: list[str] | None = None,
num_results=DEFAULT_SEARCH_LIMIT,
search_filter: SearchFilters | None = None,
) -> list[EntityEdge]:
"""
Perform a hybrid search on the knowledge graph.
Expand Down Expand Up @@ -670,6 +672,7 @@ async def search(
query,
group_ids,
search_config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
)
).edges
Expand All @@ -683,6 +686,7 @@ async def _search(
group_ids: list[str] | None = None,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
search_filter: SearchFilters | None = None,
) -> SearchResults:
return await search(
self.driver,
Expand All @@ -691,6 +695,7 @@ async def _search(
query,
group_ids,
config,
search_filter if search_filter is not None else SearchFilters(),
center_node_uuid,
bfs_origin_node_uuids,
)
Expand Down
23 changes: 19 additions & 4 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
SearchConfig,
SearchResults,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
Expand All @@ -64,6 +65,7 @@ async def search(
query: str,
group_ids: list[str] | None,
config: SearchConfig,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
) -> SearchResults:
Expand All @@ -86,6 +88,7 @@ async def search(
query_vector,
group_ids,
config.edge_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
Expand Down Expand Up @@ -133,6 +136,7 @@ async def edge_search(
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
Expand All @@ -143,19 +147,30 @@ async def edge_search(
search_results: list[list[EntityEdge]] = list(
await semaphore_gather(
*[
edge_fulltext_search(driver, query, group_ids, 2 * limit),
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
edge_similarity_search(
driver, query_vector, None, None, group_ids, 2 * limit, config.sim_min_score
driver,
query_vector,
None,
None,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
),
edge_bfs_search(
driver, bfs_origin_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
),
edge_bfs_search(driver, bfs_origin_node_uuids, config.bfs_max_depth, 2 * limit),
]
)
)

if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
search_results.append(
await edge_bfs_search(driver, source_node_uuids, config.bfs_max_depth, 2 * limit)
await edge_bfs_search(
driver, source_node_uuids, config.bfs_max_depth, search_filter, 2 * limit
)
)

edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
Expand Down
152 changes: 152 additions & 0 deletions graphiti_core/search/search_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from datetime import datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field
from typing_extensions import LiteralString


class ComparisonOperator(Enum):
    """Comparison operators for date filters; each value is the operator's query symbol."""

    # Equality / inequality.
    equals = '='
    not_equals = '<>'

    # Strict ordering.
    greater_than = '>'
    less_than = '<'

    # Inclusive ordering.
    greater_than_equal = '>='
    less_than_equal = '<='


class DateFilter(BaseModel):
    # One date comparison: a datetime plus the operator used to compare a field against it.
    # Documented with comments rather than a docstring so the model's generated schema
    # description is left untouched.

    # The datetime that the filtered field is compared against.
    date: datetime = Field(description='A datetime to filter on')
    # The operator applied between the filtered field and `date`.
    comparison_operator: ComparisonOperator = Field(
        description='Comparison operator for date filter'
    )


class SearchFilters(BaseModel):
    # Optional date filters, one per field. Each field holds a list of lists of
    # DateFilter; None (the default) means no filter is supplied for that field.
    # Documented with comments rather than a docstring so the model's generated schema
    # description is left untouched.
    valid_at: list[list[DateFilter]] | None = None
    invalid_at: list[list[DateFilter]] | None = None
    created_at: list[list[DateFilter]] | None = None
    expired_at: list[list[DateFilter]] | None = None


def _date_filter_clause(field_name: str, or_groups: list[list[Any]]) -> tuple[str, dict[str, Any]]:
    """
    Build the query fragment and parameters for the date filters of one field.

    Args:
        field_name: Name of the property on `r` being filtered (e.g. 'valid_at').
        or_groups: Outer list of groups that are OR-ed together; the filters
            inside each group are AND-ed together.

    Returns:
        A `(clause, params)` tuple. `clause` looks like
        'AND ((r.f >= $f_0_0) AND (r.f < $f_0_1) OR (r.f = $f_1_0))', or '' when
        there are no filters at all. `params` maps each generated parameter name
        to its datetime.
    """
    params: dict[str, Any] = {}
    group_queries: list[str] = []

    for i, and_group in enumerate(or_groups):
        conditions: list[str] = []
        for j, date_filter in enumerate(and_group):
            # Both indices go into the name so filters in different OR-groups
            # never overwrite each other's parameter.
            param_name = f'{field_name}_{i}_{j}'
            params[param_name] = date_filter.date
            conditions.append(
                f'(r.{field_name} {date_filter.comparison_operator.value} ${param_name})'
            )

        # Empty groups are ignored so they cannot produce a dangling ' OR '.
        if conditions:
            group_queries.append(' AND '.join(conditions))

    if not group_queries:
        # Nothing to filter on: emit no clause rather than an unclosed 'AND ('.
        return '', {}

    # AND binds tighter than OR, so the groups need no extra parentheses.
    return 'AND (' + ' OR '.join(group_queries) + ')', params


def search_filter_query_constructor(filters: SearchFilters) -> tuple[LiteralString, dict[str, Any]]:
    """
    Translate a SearchFilters object into a query fragment and its parameters.

    For each of `valid_at`, `invalid_at`, `created_at` and `expired_at` that is
    not None, one 'AND (...)' clause is produced comparing `r.<field>` against
    bound parameters. Within a field, the outer list is OR-ed and each inner
    list is AND-ed. Clauses for different fields are separated by a space.

    Args:
        filters: The filters to translate.

    Returns:
        A `(filter_query, filter_params)` tuple. `filter_query` is '' when no
        filters are set; otherwise it starts with 'AND' so it can be appended to
        an existing condition. `filter_params` maps every parameter referenced in
        `filter_query` (named '<field>_<group index>_<filter index>') to its
        datetime value.
    """
    clauses: list[str] = []
    filter_params: dict[str, Any] = {}

    for field_name in ('valid_at', 'invalid_at', 'created_at', 'expired_at'):
        or_groups = getattr(filters, field_name)
        if or_groups is None:
            continue

        clause, params = _date_filter_clause(field_name, or_groups)
        if clause:
            clauses.append(clause)
            filter_params.update(params)

    # Only fixed field names, enum operator symbols and generated parameter names
    # are interpolated into the query text; user-supplied dates travel as parameters.
    filter_query: LiteralString = ' '.join(clauses)

    return filter_query, filter_params
Loading