diff --git a/js/sdk/__tests__/ConversationsIntegrationUser.test.ts b/js/sdk/__tests__/ConversationsIntegrationUser.test.ts new file mode 100644 index 000000000..caecdb6a9 --- /dev/null +++ b/js/sdk/__tests__/ConversationsIntegrationUser.test.ts @@ -0,0 +1,260 @@ +import { r2rClient } from "../src/index"; +import { describe, test, beforeAll, expect } from "@jest/globals"; + +const baseUrl = "http://localhost:7272"; + +describe("r2rClient V3 Collections Integration Tests", () => { + let client: r2rClient; + let user1Client: r2rClient; + let user2Client: r2rClient; + let user1Id: string; + let user2Id: string; + let conversationId: string; + let user1ConversationId: string; + let user2ConversationId: string; + + beforeAll(async () => { + client = new r2rClient(baseUrl); + user1Client = new r2rClient(baseUrl); + user2Client = new r2rClient(baseUrl); + + await client.users.login({ + email: "admin@example.com", + password: "change_me_immediately", + }); + }); + + test("Register user 1", async () => { + const response = await client.users.create({ + email: "user1@example.com", + password: "change_me_immediately", + }); + + user1Id = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.is_superuser).toBe(false); + expect(response.results.name).toBe(null); + }); + + test("Login as a user 1", async () => { + const response = await user1Client.users.login({ + email: "user1@example.com", + password: "change_me_immediately", + }); + expect(response.results).toBeDefined(); + }); + + test("Register user 2", async () => { + const response = await client.users.create({ + email: "user2@example.com", + password: "change_me_immediately", + }); + + user2Id = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.is_superuser).toBe(false); + expect(response.results.name).toBe(null); + }); + + test("Login as a user 2", async () => { + const response = await user2Client.users.login({ + email: "user2@example.com", + password: "change_me_immediately", + }); + expect(response.results).toBeDefined(); + }); + + test("Get the health of the system", async () => { + const response = await client.system.health(); + expect(response.results).toBeDefined(); + }); + + test("Get the health of the system as user 1", async () => { + const response = await user1Client.system.health(); + expect(response.results).toBeDefined(); + }); + + test("Get the health of the system as user 2", async () => { + const response = await user2Client.system.health(); + expect(response.results).toBeDefined(); + }); + + test("List all conversations", async () => { + const response = await client.conversations.list(); + + expect(response.results).toBeDefined(); + expect(response.results).toEqual([]); + expect(response.total_entries).toBe(0); + }); + + test("List all conversations as user 1", async () => { + const response = await user1Client.conversations.list(); + + expect(response.results).toBeDefined(); + expect(response.results).toEqual([]); + expect(response.total_entries).toBe(0); + }); + + test("List all conversations as user 2", async () => { + const response = await user2Client.conversations.list(); + + expect(response.results).toBeDefined(); + expect(response.results).toEqual([]); + expect(response.total_entries).toBe(0); + }); + + test("Create a conversation with a name", async () => { + const response = await client.conversations.create({ + name: "Test Conversation", + }); + conversationId = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("Test Conversation"); + }); + + test("Create a conversation with a name as user 1", async () => { + const response = await user1Client.conversations.create({ + name: "User 1 Conversation", + }); + user1ConversationId = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("User 1 Conversation"); + }); + + test("Create a conversation with a name as user 2", async () => { + const response = await user2Client.conversations.create({ + name: "User 2 Conversation", + }); + user2ConversationId = response.results.id; + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("User 2 Conversation"); + }); + + test("Update a conversation name", async () => { + const response = await client.conversations.update({ + id: conversationId, + name: "Updated Name", + }); + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("Updated Name"); + }); + + test("Update a conversation name as user 1", async () => { + const response = await user1Client.conversations.update({ + id: user1ConversationId, + name: "User 1 Updated Name", + }); + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("User 1 Updated Name"); + }); + + test("Update a conversation name as user 2", async () => { + const response = await user2Client.conversations.update({ + id: user2ConversationId, + name: "User 2 Updated Name", + }); + expect(response.results).toBeDefined(); + expect(response.results.name).toBe("User 2 Updated Name"); + }); + + test("Add a message to a conversation", async () => { + const response = await client.conversations.addMessage({ + id: conversationId, + content: "Hello, world!", + role: "user", + }); + expect(response.results).toBeDefined(); + }); + + test("Add a message to a conversation as user 1", async () => { + const response = await user1Client.conversations.addMessage({ + id: user1ConversationId, + content: "Hello, world!", + role: "user", + }); + expect(response.results).toBeDefined(); + }); + + test("Add a message to a conversation as user 2", async () => { + const response = await user2Client.conversations.addMessage({ + id: user2ConversationId, + content: "Hello, world!", + role: "user", + }); + expect(response.results).toBeDefined(); + }); + + test("User 1 should not be able to see user 2's conversation", async () => { + await expect( + user1Client.conversations.retrieve({ id: user2ConversationId }), + ).rejects.toThrow(/Status 404/); + }); + + test("User 2 should not be able to see user 1's conversation", async () => { + await expect( + user2Client.conversations.retrieve({ id: user1ConversationId }), + ).rejects.toThrow(/Status 404/); + }); + + test("User 1 should not see user 2's conversation when listing all conversations", async () => { + const response = await user1Client.conversations.list(); + expect(response.results).toHaveLength(1); + }); + + test("User 2 should not see user 1's conversation when listing all conversations", async () => { + const response = await user2Client.conversations.list(); + expect(response.results).toHaveLength(1); + }); + + test("The super user should see all conversations when listing all conversations", async () => { + const response = await client.conversations.list(); + expect(response.results).toHaveLength(3); + }); + + test("Delete a conversation", async () => { + const response = await client.conversations.delete({ id: conversationId }); + expect(response.results).toBeDefined(); + }); + + test("User 1 should not be able to delete user 2's conversation", async () => { + await expect( + user1Client.conversations.delete({ id: user2ConversationId }), + ).rejects.toThrow(/Status 404/); + }); + + test("User 2 should not be able to delete user 1's conversation", async () => { + await expect( + user2Client.conversations.delete({ id: user1ConversationId }), + ).rejects.toThrow(/Status 404/); + }); + + test("Delete a conversation as user 1", async () => { + const response = await user1Client.conversations.delete({ + id: user1ConversationId, + }); + expect(response.results).toBeDefined(); + }); + + test("Super user should be able to delete any conversation", async () => { + const response = await client.conversations.delete({ + id: user2ConversationId, + }); + expect(response.results).toBeDefined(); + }); + + test("Delete user 1", async () => { + const response = await client.users.delete({ + id: user1Id, + password: "change_me_immediately", + }); + expect(response.results).toBeDefined(); + }); + + test("Delete user 2", async () => { + const response = await client.users.delete({ + id: user2Id, + password: "change_me_immediately", + }); + expect(response.results).toBeDefined(); + }); +}); diff --git a/py/core/database/conversations.py b/py/core/database/conversations.py index 7bb1d4833..a5cdaface 100644 --- a/py/core/database/conversations.py +++ b/py/core/database/conversations.py @@ -72,82 +72,64 @@ async def create_conversation( detail=f"Failed to create conversation: {str(e)}", ) from e - async def verify_conversation_access( - self, conversation_id: UUID, user_id: UUID - ) -> bool: - query = f""" - SELECT 1 FROM {self._get_table_name("conversations")} - WHERE id = $1 AND (user_id IS NULL OR user_id = $2) - """ - row = await self.connection_manager.fetchrow_query( - query, [conversation_id, user_id] - ) - return row is not None - async def get_conversations_overview( self, offset: int, limit: int, - user_ids: Optional[UUID | list[UUID]] = None, + filter_user_ids: Optional[list[UUID]] = None, conversation_ids: Optional[list[UUID]] = None, ) -> dict[str, Any]: - # Construct conditions conditions = [] params: list = [] param_index = 1 - if user_ids is not None: - if isinstance(user_ids, UUID): - conditions.append(f"user_id = ${param_index}") - params.append(user_ids) - param_index += 1 - else: - # user_ids is a list of UUIDs - placeholders = ", ".join( - f"${i+param_index}" for i in range(len(user_ids)) - ) - conditions.append( - f"user_id = ANY(ARRAY[{placeholders}]::uuid[])" + if filter_user_ids: + conditions.append( + f""" + c.user_id IN ( + SELECT id + FROM {self.project_name}.users + WHERE id = ANY(${param_index}) ) - params.extend(user_ids) - param_index += len(user_ids) + """ + ) + params.append(filter_user_ids) + param_index += 1 if conversation_ids: - placeholders = ", ".join( - f"${i+param_index}" for i in range(len(conversation_ids)) - ) - conditions.append(f"id = ANY(ARRAY[{placeholders}]::uuid[])") - params.extend(conversation_ids) - param_index += len(conversation_ids) + conditions.append(f"c.id = ANY(${param_index})") + params.append(conversation_ids) + param_index += 1 where_clause = ( "WHERE " + " AND ".join(conditions) if conditions else "" ) - limit_clause = "" - if limit != -1: - limit_clause = f"LIMIT ${param_index}" - params.append(limit) - param_index += 1 - - offset_clause = f"OFFSET ${param_index}" - params.append(offset) - query = f""" WITH conversation_overview AS ( - SELECT id, extract(epoch from created_at) as created_at_epoch, user_id, name - FROM {self._get_table_name("conversations")} + SELECT c.id, + extract(epoch from c.created_at) as created_at_epoch, + c.user_id, + c.name + FROM {self._get_table_name("conversations")} c {where_clause} ), counted_overview AS ( SELECT *, - COUNT(*) OVER() AS total_entries + COUNT(*) OVER() AS total_entries FROM conversation_overview ) SELECT * FROM counted_overview ORDER BY created_at_epoch DESC - {limit_clause} {offset_clause} + OFFSET ${param_index} """ + params.append(offset) + param_index += 1 + + if limit != -1: + query += f" LIMIT ${param_index}" + params.append(limit) + results = await self.connection_manager.fetch_query(query, params) if not results: @@ -244,7 +226,8 @@ async def edit_message( row = await self.connection_manager.fetchrow_query(query, [message_id]) if not row: raise R2RException( - status_code=404, message=f"Message {message_id} not found." + status_code=404, + message=f"Message {message_id} not found.", ) old_content = json.loads(row["content"]) @@ -335,13 +318,33 @@ async def update_message_metadata( ) async def get_conversation( - self, conversation_id: UUID + self, + conversation_id: UUID, + filter_user_ids: Optional[list[UUID]] = None, ) -> list[MessageResponse]: - # Check conversation - conv_query = f"SELECT extract(epoch from created_at) AS created_at_epoch FROM {self._get_table_name('conversations')} WHERE id = $1" - conv_row = await self.connection_manager.fetchrow_query( - conv_query, [conversation_id] - ) + conditions = ["c.id = $1"] + params: list = [conversation_id] + + if filter_user_ids: + param_index = 2 + conditions.append( + f""" + c.user_id IN ( + SELECT id + FROM {self.project_name}.users + WHERE id = ANY(${param_index}) + ) + """ + ) + params.append(filter_user_ids) + + query = f""" + SELECT c.id, extract(epoch from c.created_at) AS created_at_epoch + FROM {self._get_table_name('conversations')} c + WHERE {' AND '.join(conditions)} + """ + + conv_row = await self.connection_manager.fetchrow_query(query, params) if not conv_row: raise R2RException( status_code=404, @@ -403,11 +406,34 @@ async def update_conversation( detail=f"Failed to update conversation: {str(e)}", ) from e - async def delete_conversation(self, conversation_id: UUID) -> None: - # Check if conversation exists - conv_query = f"SELECT 1 FROM {self._get_table_name('conversations')} WHERE id = $1" + async def delete_conversation( + self, + conversation_id: UUID, + filter_user_ids: Optional[list[UUID]] = None, + ) -> None: + conditions = ["c.id = $1"] + params: list = [conversation_id] + + if filter_user_ids: + param_index = 2 + conditions.append( + f""" + c.user_id IN ( + SELECT id + FROM {self.project_name}.users + WHERE id = ANY(${param_index}) + ) + """ + ) + params.append(filter_user_ids) + + conv_query = f""" + SELECT 1 + FROM {self._get_table_name('conversations')} c + WHERE {' AND '.join(conditions)} + """ conv_row = await self.connection_manager.fetchrow_query( - conv_query, [conversation_id] + conv_query, params ) if not conv_row: raise R2RException( diff --git a/py/core/main/api/v3/conversations_router.py b/py/core/main/api/v3/conversations_router.py index ad6e393af..47b2351b0 100644 --- a/py/core/main/api/v3/conversations_router.py +++ b/py/core/main/api/v3/conversations_router.py @@ -186,15 +186,20 @@ async def list_conversations( This endpoint returns a paginated list of conversations for the authenticated user. """ + requesting_user_id = ( + None if auth_user.is_superuser else [auth_user.id] + ) + conversation_uuids = [ UUID(conversation_id) for conversation_id in ids ] conversations_response = ( await self.services.management.conversations_overview( - conversation_ids=conversation_uuids, offset=offset, limit=limit, + conversation_ids=conversation_uuids, + user_ids=requesting_user_id, ) ) return conversations_response["results"], { # type: ignore @@ -272,8 +277,13 @@ async def get_conversation( This endpoint retrieves detailed information about a single conversation identified by its UUID. """ + requesting_user_id = ( + None if auth_user.is_superuser else [auth_user.id] + ) + conversation = await self.services.management.get_conversation( - conversation_id=id + conversation_id=id, + user_ids=requesting_user_id, ) return conversation @@ -430,8 +440,13 @@ async def delete_conversation( This endpoint deletes a conversation identified by its UUID. """ + requesting_user_id = ( + None if auth_user.is_superuser else [auth_user.id] + ) + await self.services.management.delete_conversation( - conversation_id=id + conversation_id=id, + user_ids=requesting_user_id, ) return GenericBooleanResponse(success=True) # type: ignore diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 494f44637..be3498d78 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -208,213 +208,6 @@ def transform_chunk_id_to_id( "deleted_document_ids": [str(d) for d in docs_to_delete], } - # @telemetry_event("Delete") - # async def delete( - # self, - # filters: dict[str, Any], - # *args, - # **kwargs, - # ): - # """ - # Takes a list of filters like - # "{key: {operator: value}, key: {operator: value}, ...}" - # and deletes entries matching the given filters from both vector and relational databases. - - # NOTE: This method is not atomic and may result in orphaned entries in the documents overview table. - # NOTE: This method assumes that filters delete entire contents of any touched documents. - # """ - # ### TODO - FIX THIS, ENSURE THAT DOCUMENTS OVERVIEW IS CLEARED - - # def validate_filters(filters: dict[str, Any]) -> None: - # ALLOWED_FILTERS = { - # "id", - # "collection_ids", - # "chunk_id", - # # TODO - Modify these checks such that they can be used PROPERLY for nested filters - # "$and", - # "$or", - # } - - # if not filters: - # raise R2RException( - # status_code=422, message="No filters provided" - # ) - - # for field in filters: - # if field not in ALLOWED_FILTERS: - # raise R2RException( - # status_code=422, - # message=f"Invalid filter field: {field}", - # ) - - # for field in ["document_id", "owner_id", "chunk_id"]: - # if field in filters: - # op = next(iter(filters[field].keys())) - # try: - # validate_uuid(filters[field][op]) - # except ValueError: - # raise R2RException( - # status_code=422, - # message=f"Invalid UUID: {filters[field][op]}", - # ) - - # if "collection_ids" in filters: - # op = next(iter(filters["collection_ids"].keys())) - # for id_str in filters["collection_ids"][op]: - # try: - # validate_uuid(id_str) - # except ValueError: - # raise R2RException( - # status_code=422, message=f"Invalid UUID: {id_str}" - # ) - - # validate_filters(filters) - - # logger.info(f"Deleting entries with filters: {filters}") - - # try: - - # def transform_chunk_id_to_id( - # filters: dict[str, Any] - # ) -> dict[str, Any]: - # if isinstance(filters, dict): - # transformed = {} - # for key, value in filters.items(): - # if key == "chunk_id": - # transformed["id"] = value - # elif key in ["$and", "$or"]: - # transformed[key] = [ - # transform_chunk_id_to_id(item) - # for item in value - # ] - # else: - # transformed[key] = transform_chunk_id_to_id(value) - # return transformed - # return filters - - # filters_xf = transform_chunk_id_to_id(copy(filters)) - - # await self.providers.database.chunks_handler.delete(filters) - - # vector_delete_results = ( - # await self.providers.database.chunks_handler.delete(filters_xf) - # ) - # except Exception as e: - # logger.error(f"Error deleting from vector database: {e}") - # vector_delete_results = {} - - # document_ids_to_purge: set[UUID] = set() - # if vector_delete_results: - # document_ids_to_purge.update( - # UUID(result.get("document_id")) - # for result in vector_delete_results.values() - # if result.get("document_id") - # ) - - # # TODO: This might be appropriate to move elsewhere and revisit filter logic in other methods - # def extract_filters(filters: dict[str, Any]) -> dict[str, list[str]]: - # relational_filters: dict = {} - - # def process_filter(filter_dict: dict[str, Any]): - # if "document_id" in filter_dict: - # relational_filters.setdefault( - # "filter_document_ids", [] - # ).append(filter_dict["document_id"]["$eq"]) - # if "owner_id" in filter_dict: - # relational_filters.setdefault( - # "filter_user_ids", [] - # ).append(filter_dict["owner_id"]["$eq"]) - # if "collection_ids" in filter_dict: - # relational_filters.setdefault( - # "filter_collection_ids", [] - # ).extend(filter_dict["collection_ids"]["$in"]) - - # # Handle nested conditions - # if "$and" in filters: - # for condition in filters["$and"]: - # process_filter(condition) - # elif "$or" in filters: - # for condition in filters["$or"]: - # process_filter(condition) - # else: - # process_filter(filters) - - # return relational_filters - - # relational_filters = extract_filters(filters) - # if relational_filters: - # try: - # documents_overview = ( - # await self.providers.database.documents_handler.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. - # offset=0, - # limit=1000, - # **relational_filters, # type: ignore - # ) - # )["results"] - # except Exception as e: - # logger.error( - # f"Error fetching documents from relational database: {e}" - # ) - # documents_overview = [] - - # if documents_overview: - # document_ids_to_purge.update( - # doc.id for doc in documents_overview - # ) - - # if not document_ids_to_purge: - # raise R2RException( - # status_code=404, message="No entries found for deletion." - # ) - - # for document_id in document_ids_to_purge: - # remaining_chunks = await self.providers.database.chunks_handler.list_document_chunks( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. - # document_id=document_id, - # offset=0, - # limit=1000, - # ) - # if remaining_chunks["total_entries"] == 0: - # try: - # await self.providers.database.chunks_handler.delete( - # {"document_id": {"$eq": document_id}} - # ) - # logger.info( - # f"Deleted document ID {document_id} from documents_overview." - # ) - # except Exception as e: - # logger.error( - # f"Error deleting document ID {document_id} from documents_overview: {e}" - # ) - # await self.providers.database.graphs_handler.entities.delete( - # parent_id=document_id, - # store_type="documents", # type: ignore - # ) - # await self.providers.database.graphs_handler.relationships.delete( - # parent_id=document_id, - # store_type="documents", # type: ignore - # ) - # await self.providers.database.documents_handler.delete( - # document_id=document_id - # ) - - # collections = await self.providers.database.collections_handler.get_collections_overview( - # offset=0, limit=1000, filter_document_ids=[document_id] - # ) - # # TODO - Loop over all collections - # for collection in collections["results"]: - # await self.providers.database.documents_handler.set_workflow_status( - # id=collection.id, - # status_type="graph_sync_status", - # status=KGEnrichmentStatus.OUTDATED, - # ) - # await self.providers.database.documents_handler.set_workflow_status( - # id=collection.id, - # status_type="graph_cluster_status", - # status=KGEnrichmentStatus.OUTDATED, - # ) - - # return None - @telemetry_event("DownloadFile") async def download_file( self, document_id: UUID @@ -433,8 +226,6 @@ async def documents_overview( user_ids: Optional[list[UUID]] = None, collection_ids: Optional[list[UUID]] = None, document_ids: Optional[list[UUID]] = None, - *args: Any, - **kwargs: Any, ): return await self.providers.database.documents_handler.get_documents_overview( offset=offset, @@ -451,8 +242,6 @@ async def list_document_chunks( offset: int, limit: int, include_vectors: bool = False, - *args, - **kwargs, ): return ( await self.providers.database.chunks_handler.list_document_chunks( @@ -804,18 +593,11 @@ async def delete_prompt(self, name: str) -> dict: async def get_conversation( self, conversation_id: UUID, - auth_user=None, + user_ids: Optional[list[UUID]] = None, ) -> Tuple[str, list[Message], list[dict]]: - return await self.providers.database.conversations_handler.get_conversation( # type: ignore - conversation_id=conversation_id - ) - - async def verify_conversation_access( - self, conversation_id: UUID, user_id: UUID - ) -> bool: - return await self.providers.database.conversations_handler.verify_conversation_access( + return await self.providers.database.conversations_handler.get_conversation( conversation_id=conversation_id, - user_id=user_id, + filter_user_ids=user_ids, ) @telemetry_event("CreateConversation") @@ -835,13 +617,12 @@ async def conversations_overview( offset: int, limit: int, conversation_ids: Optional[list[UUID]] = None, - user_ids: Optional[UUID | list[UUID]] = None, - auth_user=None, + user_ids: Optional[list[UUID]] = None, ) -> dict[str, list[dict] | int]: return await self.providers.database.conversations_handler.get_conversations_overview( offset=offset, limit=limit, - user_ids=user_ids, + filter_user_ids=user_ids, conversation_ids=conversation_ids, ) @@ -852,7 +633,6 @@ async def add_message( content: Message, parent_id: Optional[UUID] = None, metadata: Optional[dict] = None, - auth_user=None, ) -> str: return await self.providers.database.conversations_handler.add_message( conversation_id=conversation_id, @@ -867,7 +647,6 @@ async def edit_message( message_id: UUID, new_content: Optional[str] = None, additional_metadata: Optional[dict] = None, - auth_user=None, ) -> dict[str, Any]: return ( await self.providers.database.conversations_handler.edit_message( @@ -886,9 +665,14 @@ async def update_conversation( ) @telemetry_event("DeleteConversation") - async def delete_conversation(self, conversation_id: UUID) -> None: + async def delete_conversation( + self, + conversation_id: UUID, + user_ids: Optional[list[UUID]] = None, + ) -> None: await self.providers.database.conversations_handler.delete_conversation( - conversation_id=conversation_id + conversation_id=conversation_id, + filter_user_ids=user_ids, ) async def get_user_max_documents(self, user_id: UUID) -> int: diff --git a/py/sdk/v3/conversations.py b/py/sdk/v3/conversations.py index c4b4c3dc6..0788c25b0 100644 --- a/py/sdk/v3/conversations.py +++ b/py/sdk/v3/conversations.py @@ -184,7 +184,7 @@ async def update_message( Returns: dict: Result of the operation, including the new message ID and branch ID """ - data = {"content": content} + data: dict[str, Any] = {"content": content} if metadata: data["metadata"] = metadata return await self.client._make_request(