Skip to content

Commit

Permalink
Add response model
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Feb 4, 2025
1 parent 1de1fd4 commit a1a49be
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 16 deletions.
7 changes: 5 additions & 2 deletions py/core/base/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
WrappedBooleanResponse,
WrappedGenericMessageResponse,
)
from shared.api.models.graph.responses import ( # TODO: Need to review anything above this
from shared.api.models.graph.responses import (
Community,
Entity,
GraphResponse,
Relationship,
Traversal,
WrappedCommunitiesResponse,
WrappedCommunityResponse,
WrappedEntitiesResponse,
Expand All @@ -23,6 +24,7 @@
WrappedGraphsResponse,
WrappedRelationshipResponse,
WrappedRelationshipsResponse,
WrappedTraversalResponse,
)
from shared.api.models.ingestion.responses import (
IngestionResponse,
Expand Down Expand Up @@ -101,10 +103,11 @@
"WrappedRelationshipsResponse",
"WrappedCommunityResponse",
"WrappedCommunitiesResponse",
# TODO: Need to review anything above this
"GraphResponse",
"Traversal",
"WrappedGraphResponse",
"WrappedGraphsResponse",
"WrappedTraversalResponse",
# Management Responses
"PromptResponse",
"ServerStats",
Expand Down
80 changes: 76 additions & 4 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
WrappedCommunityResponse,
WrappedEntitiesResponse,
WrappedEntityResponse,
WrappedGenericMessageResponse,
WrappedGraphResponse,
WrappedGraphsResponse,
WrappedRelationshipResponse,
WrappedRelationshipsResponse,
WrappedTraversalResponse,
)
from core.utils import (
generate_default_user_collection_id,
Expand Down Expand Up @@ -2168,7 +2168,48 @@ async def pull(
@self.router.get(
"/graphs/{collection_id}/dijkstra",
dependencies=[Depends(self.rate_limit_dependency)],
summary="Dijsktra",
summary="Run Dijkstra's algorithm on the graph, finding the shortest path between two entities.",
openapi_extra={
"x-codeSamples": [
{
"lang": "Python",
"source": textwrap.dedent(
"""
from r2r import R2RClient
client = R2RClient()
# when using auth, do client.login(...)
response = client.graphs.dijkstra(
collection_id="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09",
source_id="8883f201-2a31-5ae1-b490-3c99210d8950",
target_id="2883ce39-39fb-59e5-ab8d-3e1b6df010ce"
)
"""
),
},
{
"lang": "JavaScript",
"source": textwrap.dedent(
"""
const { r2rClient } = require("r2r-js");
const client = new r2rClient();
async function main() {
const response = await client.graphs.dijkstra({
collectionId="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09",
sourceId="8883f201-2a31-5ae1-b490-3c99210d8950",
targetId="2883ce39-39fb-59e5-ab8d-3e1b6df010ce"
});
}
main();
"""
),
},
]
},
)
@self.base_endpoint
async def dijkstra(
Expand All @@ -2182,8 +2223,39 @@ async def dijkstra(
..., description="The ID of the target entity."
),
auth_user=Depends(self.providers.auth.auth_wrapper()),
):
# TODO: Auth
) -> WrappedTraversalResponse:
# Check user permissions for graph
collections_overview_response = (
await self.services.management.collections_overview(
user_ids=[auth_user.id],
collection_ids=[collection_id],
offset=0,
limit=1,
)
)["results"]
if len(collections_overview_response) == 0:
raise R2RException("Collection not found.", 404)

if (
not auth_user.is_superuser
and collections_overview_response[0].owner_id != auth_user.id
):
raise R2RException("Only superusers can `pull` a graph.", 403)

if collection_id not in auth_user.collection_ids:
raise R2RException(
"The currently authenticated user does not have access to the collection associated with the given graph.",
403,
)

list_graphs_response = await self.services.graph.list_graphs(
graph_ids=[collection_id],
offset=0,
limit=1,
)
if len(list_graphs_response["results"]) == 0:
raise R2RException("Graph not found", 404)

return await self.services.graph.dijkstra( # type: ignore
graph_id=collection_id,
source_id=source_id,
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2961,7 +2961,7 @@ async def dijkstra_shortest_path(
return path

except R2RException:
return {"results": {"path": [], "total_cost": 0, "num_hops": 0}}
return {"path": [], "total_cost": 0, "num_hops": 0}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error finding path: {str(e)}"
Expand Down
3 changes: 2 additions & 1 deletion py/sdk/sync_methods/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
WrappedGraphsResponse,
WrappedRelationshipResponse,
WrappedRelationshipsResponse,
WrappedTraversalResponse,
)

_list = list # Required for type hinting since we have a list method
Expand Down Expand Up @@ -603,7 +604,7 @@ def dijkstra(
collection_id: str | UUID,
source_id: str | UUID,
target_id: str | UUID,
):
) -> WrappedTraversalResponse:
"""
Get the shortest path between two entities in a graph.
Expand Down
5 changes: 4 additions & 1 deletion py/shared/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
)
from shared.api.models.graph.responses import (
GraphResponse,
Traversal,
WrappedGraphResponse,
WrappedGraphsResponse,
WrappedTraversalResponse,
)
from shared.api.models.ingestion.responses import (
IngestionResponse,
Expand Down Expand Up @@ -72,10 +74,11 @@
"WrappedIngestionResponse",
"WrappedUpdateResponse",
"WrappedMetadataUpdateResponse",
# TODO: Need to review anything above this
"GraphResponse",
"Traversal",
"WrappedGraphResponse",
"WrappedGraphsResponse",
"WrappedTraversalResponse",
# Management Responses
"PromptResponse",
"ServerStats",
Expand Down
30 changes: 23 additions & 7 deletions py/shared/api/models/graph/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,6 @@
from shared.abstractions.graph import Community, Entity, Relationship
from shared.api.models.base import PaginatedR2RResult, R2RResults

WrappedEntityResponse = R2RResults[Entity]
WrappedEntitiesResponse = PaginatedR2RResult[list[Entity]]
WrappedRelationshipResponse = R2RResults[Relationship]
WrappedRelationshipsResponse = PaginatedR2RResult[list[Relationship]]
WrappedCommunityResponse = R2RResults[Community]
WrappedCommunitiesResponse = PaginatedR2RResult[list[Community]]


class GraphResponse(BaseModel):
id: UUID
Expand All @@ -26,6 +19,29 @@ class GraphResponse(BaseModel):
document_ids: list[UUID]


class Traversal(BaseModel):
type: str
id: UUID
name: Optional[str]


class TraversalResponse(BaseModel):
path: list[Traversal]
total_cost: float
num_hops: int


# Graph Responses
WrappedCommunityResponse = R2RResults[Community]
WrappedCommunitiesResponse = PaginatedR2RResult[list[Community]]

WrappedEntityResponse = R2RResults[Entity]
WrappedEntitiesResponse = PaginatedR2RResult[list[Entity]]

WrappedGraphResponse = R2RResults[GraphResponse]
WrappedGraphsResponse = PaginatedR2RResult[list[GraphResponse]]

WrappedRelationshipResponse = R2RResults[Relationship]
WrappedRelationshipsResponse = PaginatedR2RResult[list[Relationship]]

WrappedTraversalResponse = R2RResults[TraversalResponse]

0 comments on commit a1a49be

Please sign in to comment.