From a1a49beb0acc4fb5da04b695e8a8c05495e5ca7f Mon Sep 17 00:00:00 2001 From: NolanTrem <34580718+NolanTrem@users.noreply.github.com> Date: Tue, 4 Feb 2025 10:31:09 -0800 Subject: [PATCH] Add response model --- py/core/base/api/models/__init__.py | 7 ++- py/core/main/api/v3/graph_router.py | 80 +++++++++++++++++++++++-- py/core/providers/database/graphs.py | 2 +- py/sdk/sync_methods/graphs.py | 3 +- py/shared/api/models/__init__.py | 5 +- py/shared/api/models/graph/responses.py | 30 +++++++--- 6 files changed, 111 insertions(+), 16 deletions(-) diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index fcd969d07..f239ea244 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -10,11 +10,12 @@ WrappedBooleanResponse, WrappedGenericMessageResponse, ) -from shared.api.models.graph.responses import ( # TODO: Need to review anything above this +from shared.api.models.graph.responses import ( Community, Entity, GraphResponse, Relationship, + Traversal, WrappedCommunitiesResponse, WrappedCommunityResponse, WrappedEntitiesResponse, @@ -23,6 +24,7 @@ WrappedGraphsResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, + WrappedTraversalResponse, ) from shared.api.models.ingestion.responses import ( IngestionResponse, @@ -101,10 +103,11 @@ "WrappedRelationshipsResponse", "WrappedCommunityResponse", "WrappedCommunitiesResponse", - # TODO: Need to review anything above this "GraphResponse", + "Traversal", "WrappedGraphResponse", "WrappedGraphsResponse", + "WrappedTraversalResponse", # Management Responses "PromptResponse", "ServerStats", diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 43554005a..3a74056d9 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -16,11 +16,11 @@ WrappedCommunityResponse, WrappedEntitiesResponse, WrappedEntityResponse, - WrappedGenericMessageResponse, WrappedGraphResponse, WrappedGraphsResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, + WrappedTraversalResponse, ) from core.utils import ( generate_default_user_collection_id, @@ -2168,7 +2168,48 @@ async def pull( @self.router.get( "/graphs/{collection_id}/dijkstra", dependencies=[Depends(self.rate_limit_dependency)], - summary="Dijsktra", + summary="Run Dijkstra's algorithm on the graph, finding the shortest path between two entities.", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient() + # when using auth, do client.login(...) + + response = client.graphs.dijkstra( + collection_id="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09", + source_id="8883f201-2a31-5ae1-b490-3c99210d8950", + target_id="2883ce39-39fb-59e5-ab8d-3e1b6df010ce" + ) + """ + ), + }, + { + "lang": "JavaScript", + "source": textwrap.dedent( + """ + const { r2rClient } = require("r2r-js"); + + const client = new r2rClient(); + + async function main() { + const response = await client.graphs.dijkstra({ + collectionId="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09", + sourceId="8883f201-2a31-5ae1-b490-3c99210d8950", + targetId="2883ce39-39fb-59e5-ab8d-3e1b6df010ce" + }); + } + + main(); + """ + ), + }, + ] + }, ) @self.base_endpoint async def dijkstra( @@ -2182,8 +2223,39 @@ async def dijkstra( ..., description="The ID of the target entity." ), auth_user=Depends(self.providers.auth.auth_wrapper()), - ): - # TODO: Auth + ) -> WrappedTraversalResponse: + # Check user permissions for graph + collections_overview_response = ( + await self.services.management.collections_overview( + user_ids=[auth_user.id], + collection_ids=[collection_id], + offset=0, + limit=1, + ) + )["results"] + if len(collections_overview_response) == 0: + raise R2RException("Collection not found.", 404) + + if ( + not auth_user.is_superuser + and collections_overview_response[0].owner_id != auth_user.id + ): + raise R2RException("Only superusers can `pull` a graph.", 403) + + if collection_id not in auth_user.collection_ids: + raise R2RException( + "The currently authenticated user does not have access to the collection associated with the given graph.", + 403, + ) + + list_graphs_response = await self.services.graph.list_graphs( + graph_ids=[collection_id], + offset=0, + limit=1, + ) + if len(list_graphs_response["results"]) == 0: + raise R2RException("Graph not found", 404) + return await self.services.graph.dijkstra( # type: ignore graph_id=collection_id, source_id=source_id, diff --git a/py/core/providers/database/graphs.py b/py/core/providers/database/graphs.py index 6bcce926d..d40d48a06 100644 --- a/py/core/providers/database/graphs.py +++ b/py/core/providers/database/graphs.py @@ -2961,7 +2961,7 @@ async def dijkstra_shortest_path( return path except R2RException: - return {"results": {"path": [], "total_cost": 0, "num_hops": 0}} + return {"path": [], "total_cost": 0, "num_hops": 0} except Exception as e: raise HTTPException( status_code=500, detail=f"Error finding path: {str(e)}" diff --git a/py/sdk/sync_methods/graphs.py b/py/sdk/sync_methods/graphs.py index 77ad02e4d..e0b1eea0d 100644 --- a/py/sdk/sync_methods/graphs.py +++ b/py/sdk/sync_methods/graphs.py @@ -11,6 +11,7 @@ WrappedGraphsResponse, WrappedRelationshipResponse, WrappedRelationshipsResponse, + WrappedTraversalResponse, ) _list = list # Required for type hinting since we have a list method @@ -603,7 +604,7 @@ def dijkstra( collection_id: str | UUID, source_id: str | UUID, target_id: str | UUID, - ): + ) -> WrappedTraversalResponse: """ Get the shortest path between two entities in a graph. diff --git a/py/shared/api/models/__init__.py b/py/shared/api/models/__init__.py index 53aa4c699..7963bf8d7 100644 --- a/py/shared/api/models/__init__.py +++ b/py/shared/api/models/__init__.py @@ -12,8 +12,10 @@ ) from shared.api.models.graph.responses import ( GraphResponse, + Traversal, WrappedGraphResponse, WrappedGraphsResponse, + WrappedTraversalResponse, ) from shared.api.models.ingestion.responses import ( IngestionResponse, @@ -72,10 +74,11 @@ "WrappedIngestionResponse", "WrappedUpdateResponse", "WrappedMetadataUpdateResponse", - # TODO: Need to review anything above this "GraphResponse", + "Traversal", "WrappedGraphResponse", "WrappedGraphsResponse", + "WrappedTraversalResponse", # Management Responses "PromptResponse", "ServerStats", diff --git a/py/shared/api/models/graph/responses.py b/py/shared/api/models/graph/responses.py index a22728330..94387ae32 100644 --- a/py/shared/api/models/graph/responses.py +++ b/py/shared/api/models/graph/responses.py @@ -7,13 +7,6 @@ from shared.abstractions.graph import Community, Entity, Relationship from shared.api.models.base import PaginatedR2RResult, R2RResults -WrappedEntityResponse = R2RResults[Entity] -WrappedEntitiesResponse = PaginatedR2RResult[list[Entity]] -WrappedRelationshipResponse = R2RResults[Relationship] -WrappedRelationshipsResponse = PaginatedR2RResult[list[Relationship]] -WrappedCommunityResponse = R2RResults[Community] -WrappedCommunitiesResponse = PaginatedR2RResult[list[Community]] - class GraphResponse(BaseModel): id: UUID @@ -26,6 +19,29 @@ class GraphResponse(BaseModel): document_ids: list[UUID] +class Traversal(BaseModel): + type: str + id: UUID + name: Optional[str] + + +class TraversalResponse(BaseModel): + path: list[Traversal] + total_cost: float + num_hops: int + + # Graph Responses +WrappedCommunityResponse = R2RResults[Community] +WrappedCommunitiesResponse = PaginatedR2RResult[list[Community]] + +WrappedEntityResponse = R2RResults[Entity] +WrappedEntitiesResponse = PaginatedR2RResult[list[Entity]] + WrappedGraphResponse = R2RResults[GraphResponse] WrappedGraphsResponse = PaginatedR2RResult[list[GraphResponse]] + +WrappedRelationshipResponse = R2RResults[Relationship] +WrappedRelationshipsResponse = PaginatedR2RResult[list[Relationship]] + +WrappedTraversalResponse = R2RResults[TraversalResponse]