From 63a79171459729d4b7805dd8f050577e6a65fa4d Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Thu, 5 Dec 2024 21:09:49 -0800 Subject: [PATCH] up (#1664) * up * add ingestion settings too * up * up * up --- js/sdk/src/v3/clients/documents.ts | 4 + js/sdk/src/v3/clients/retrieval.ts | 12 ++ py/core/__init__.py | 1 + py/core/base/__init__.py | 2 + py/core/base/abstractions/__init__.py | 2 + py/core/base/providers/__init__.py | 8 +- py/core/base/providers/ingestion.py | 24 ++- py/core/main/api/v3/documents_router.py | 66 ++++++- py/core/main/api/v3/retrieval_router.py | 216 +++++++++++++++++---- py/core/main/services/retrieval_service.py | 1 + py/sdk/models.py | 3 + py/sdk/v3/documents.py | 11 +- py/sdk/v3/retrieval.py | 14 ++ py/shared/abstractions/__init__.py | 3 + py/shared/abstractions/document.py | 128 ++++++++++++ py/shared/abstractions/search.py | 33 ++++ 16 files changed, 481 insertions(+), 47 deletions(-) diff --git a/js/sdk/src/v3/clients/documents.ts b/js/sdk/src/v3/clients/documents.ts index ffc1171e4..f9b5ebd34 100644 --- a/js/sdk/src/v3/clients/documents.ts +++ b/js/sdk/src/v3/clients/documents.ts @@ -46,6 +46,7 @@ export class DocumentsClient { ingestionConfig?: Record; collectionIds?: string[]; runWithOrchestration?: boolean; + ingestionMode?: "hi-res" | "fast" | "custom"; }): Promise { const inputCount = [options.file, options.raw_text, options.chunks].filter( (x) => x !== undefined, @@ -128,6 +129,9 @@ export class DocumentsClient { String(options.runWithOrchestration), ); } + if (options.ingestionMode) { + formData.append("ingestion_mode", options.ingestionMode); + } formData.append("file_names", JSON.stringify(processedFiles)); diff --git a/js/sdk/src/v3/clients/retrieval.ts b/js/sdk/src/v3/clients/retrieval.ts index 405ae47b8..f758bc5ea 100644 --- a/js/sdk/src/v3/clients/retrieval.ts +++ b/js/sdk/src/v3/clients/retrieval.ts @@ -28,6 +28,7 @@ export class RetrievalClient { @feature("retrieval.search") async search(options: { query: string; + searchMode?: "advanced" | "basic" | "custom"; searchSettings?: SearchSettings | Record; }): Promise { const data = { @@ -35,6 +36,9 @@ export class RetrievalClient { ...(options.searchSettings && { search_settings: options.searchSettings, }), + ...(options.searchMode && { + search_mode: options.searchMode, + }), }; return await this.client.makeRequest("POST", "retrieval/search", { @@ -60,6 +64,7 @@ export class RetrievalClient { @feature("retrieval.rag") async rag(options: { query: string; + searchMode?: "advanced" | "basic" | "custom"; searchSettings?: SearchSettings | Record; ragGenerationConfig?: GenerationConfig | Record; taskPromptOverride?: string; @@ -67,6 +72,9 @@ export class RetrievalClient { }): Promise> { const data = { query: options.query, + ...(options.searchMode && { + search_mode: options.searchMode, + }), ...(options.searchSettings && { search_settings: options.searchSettings, }), @@ -155,6 +163,7 @@ export class RetrievalClient { @feature("retrieval.agent") async agent(options: { message: Message; + searchMode?: "advanced" | "basic" | "custom"; searchSettings?: SearchSettings | Record; ragGenerationConfig?: GenerationConfig | Record; taskPromptOverride?: string; @@ -164,6 +173,9 @@ export class RetrievalClient { }): Promise> { const data: Record = { message: options.message, + ...(options.searchMode && { + search_mode: options.searchMode, + }), ...(options.searchSettings && { search_settings: options.searchSettings, }), diff --git a/py/core/__init__.py b/py/core/__init__.py index 94c6dc9dc..18edeb3da 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -79,6 +79,7 @@ "GraphSearchSettings", "ChunkSearchResult", "SearchSettings", + "SearchMode", "HybridSearchSettings", # User abstractions "Token", diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index 1fbbafe96..e9973bda4 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -52,6 +52,7 @@ "ChunkSearchSettings", "ChunkSearchResult", "SearchSettings", + "SearchMode", "HybridSearchSettings", # User abstractions "Token", @@ -117,6 +118,7 @@ "EmbeddingConfig", "EmbeddingProvider", # Ingestion provider + "IngestionMode", "IngestionConfig", "IngestionProvider", "ChunkingStrategy", diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index edeb8b9d7..d71c267db 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -62,6 +62,7 @@ KGGlobalResult, KGRelationshipResult, KGSearchResultType, + SearchMode, SearchSettings, WebSearchResponse, ) @@ -133,6 +134,7 @@ "ChunkSearchSettings", "ChunkSearchResult", "SearchSettings", + "SearchMode", "HybridSearchSettings", # KG abstractions "KGCreationSettings", diff --git a/py/core/base/providers/__init__.py b/py/core/base/providers/__init__.py index 1825ebcc1..9b2e4cdad 100644 --- a/py/core/base/providers/__init__.py +++ b/py/core/base/providers/__init__.py @@ -18,7 +18,12 @@ ) from .email import EmailConfig, EmailProvider from .embedding import EmbeddingConfig, EmbeddingProvider -from .ingestion import ChunkingStrategy, IngestionConfig, IngestionProvider +from .ingestion import ( + ChunkingStrategy, + IngestionConfig, + IngestionMode, + IngestionProvider, +) from .llm import CompletionConfig, CompletionProvider from .orchestration import OrchestrationConfig, OrchestrationProvider, Workflow @@ -31,6 +36,7 @@ "Provider", "ProviderConfig", # Ingestion provider + "IngestionMode", "IngestionConfig", "IngestionProvider", "ChunkingStrategy", diff --git a/py/core/base/providers/ingestion.py b/py/core/base/providers/ingestion.py index 2d6d9947b..a4d27b2de 100644 --- a/py/core/base/providers/ingestion.py +++ b/py/core/base/providers/ingestion.py @@ -1,7 +1,6 @@ import logging from abc import ABC from enum import Enum -from typing import Optional from core.base.abstractions import ChunkEnrichmentSettings @@ -34,6 +33,8 @@ class IngestionConfig(ProviderConfig): chunks_for_document_summary: int = 128 document_summary_model: str = "openai/gpt-4o-mini" + parser_overrides: dict[str, str] = {} + @property def supported_providers(self) -> list[str]: return ["r2r", "unstructured_local", "unstructured_api"] @@ -42,6 +43,21 @@ def validate_config(self) -> None: if self.provider not in self.supported_providers: raise ValueError(f"Provider {self.provider} is not supported.") + @classmethod + def get_default(cls, mode: str, app) -> "IngestionConfig": + """Return default ingestion configuration for a given mode.""" + if mode == "hi-res": + # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`. + return cls(app=app, parser_overrides={"pdf": "zerox"}) + # elif mode == "fast": + # # Skip summaries and other enrichment steps for speed. + # return cls( + # app=app, + # ) + else: + # For `custom` or any unrecognized mode, return a base config + return cls(app=app) + class IngestionProvider(Provider, ABC): @@ -66,3 +82,9 @@ class ChunkingStrategy(str, Enum): CHARACTER = "character" BASIC = "basic" BY_TITLE = "by_title" + + +class IngestionMode(str, Enum): + hi_res = "hi-res" + fast = "fast" + custom = "custom" diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index 4c2d2c13d..c6ba01cd0 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -12,6 +12,8 @@ from pydantic import Json from core.base import ( + IngestionConfig, + IngestionMode, R2RException, RunType, UnprocessedChunk, @@ -44,6 +46,18 @@ MAX_CHUNKS_PER_REQUEST = 1024 * 100 +def merge_ingestion_config( + base: IngestionConfig, overrides: IngestionConfig +) -> IngestionConfig: + base_dict = base.model_dump() + overrides_dict = overrides.model_dump(exclude_unset=True) + + for k, v in overrides_dict.items(): + base_dict[k] = v + + return IngestionConfig(**base_dict) + + class DocumentsRouter(BaseRouterV3): def __init__( self, @@ -106,6 +120,29 @@ def _register_workflows(self): }, ) + def _prepare_ingestion_config( + self, + ingestion_mode: IngestionMode, + ingestion_config: Optional[IngestionConfig], + ) -> IngestionConfig: + # If not custom, start from defaults + if ingestion_mode != IngestionMode.custom: + effective_config = IngestionConfig.get_default( + ingestion_mode.value, app=self.providers.auth.config.app + ) + if ingestion_config: + effective_config = merge_ingestion_config( + effective_config, ingestion_config + ) + else: + # custom mode + effective_config = ingestion_config or IngestionConfig( + app=self.providers.auth.config.app + ) + + effective_config.validate_config() + return effective_config + def _setup_routes(self): @self.router.post( "/documents", @@ -199,7 +236,18 @@ async def create_document( None, description="Metadata to associate with the document, such as title, description, or custom fields.", ), - ingestion_config: Optional[Json[dict]] = Form( + ingestion_mode: IngestionMode = Form( + default=IngestionMode.custom, + description=( + "Ingestion modes:\n" + "- `hi-res`: Thorough ingestion with full summaries and enrichment.\n" + "- `fast`: Quick ingestion with minimal enrichment and no summaries.\n" + "- `custom`: Full control via `ingestion_config`.\n\n" + "If `filters` or `limit` (in `ingestion_config`) are provided alongside `hi-res` or `fast`, " + "they will override the default settings for that mode." + ), + ), + ingestion_config: Optional[Json[IngestionConfig]] = Form( None, description="An optional dictionary to override the default chunking configuration for the ingestion process. If not provided, the system will use the default server-side chunking configuration.", ), @@ -210,14 +258,23 @@ async def create_document( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedIngestionResponse: """ - Creates a new Document object from an input file or text content. The document will be processed - to create chunks for vector indexing and search. + Creates a new Document object from an input file, text content, or chunks. The chosen `ingestion_mode` determines + how the ingestion process is configured: + + **Ingestion Modes:** + - `hi-res`: Comprehensive parsing and enrichment, including summaries and possibly more thorough parsing. + - `fast`: Speed-focused ingestion that skips certain enrichment steps like summaries. + - `custom`: Provide a full `ingestion_config` to customize the entire ingestion process. Either a file or text content must be provided, but not both. Documents are shared through `Collections` which allow for tightly specified cross-user interactions. The ingestion process runs asynchronously and its progress can be tracked using the returned task_id. """ + effective_ingestion_config = self._prepare_ingestion_config( + ingestion_mode=ingestion_mode, + ingestion_config=ingestion_config, + ) if not file and not raw_text and not chunks: raise R2RException( status_code=422, @@ -275,6 +332,7 @@ async def create_document( ], "metadata": metadata, # Base metadata for the document "user": auth_user.model_dump_json(), + "ingestion_config": effective_ingestion_config.model_dump(), } # TODO - Modify create_chunks so that we can add chunks to existing document @@ -347,7 +405,7 @@ async def create_document( "document_id": str(document_id), "collection_ids": collection_ids, "metadata": metadata, - "ingestion_config": ingestion_config, + "ingestion_config": effective_ingestion_config.model_dump(), "user": auth_user.model_dump_json(), "size_in_bytes": content_length, "is_update": False, diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 7f226df0c..b35829116 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -7,7 +7,13 @@ from fastapi import Body, Depends from fastapi.responses import StreamingResponse -from core.base import GenerationConfig, Message, R2RException, SearchSettings +from core.base import ( + GenerationConfig, + Message, + R2RException, + SearchMode, + SearchSettings, +) from core.base.api.models import ( WrappedAgentResponse, WrappedCompletionResponse, @@ -23,6 +29,22 @@ from .base_router import BaseRouterV3 +def merge_search_settings( + base: SearchSettings, overrides: SearchSettings +) -> SearchSettings: + # Convert both to dict + base_dict = base.model_dump() + overrides_dict = overrides.model_dump(exclude_unset=True) + + # Update base_dict with values from overrides_dict + # This ensures that any field set in overrides takes precedence + for k, v in overrides_dict.items(): + base_dict[k] = v + + # Construct a new SearchSettings from the merged dict + return SearchSettings(**base_dict) + + class RetrievalRouterV3(BaseRouterV3): def __init__( self, @@ -38,6 +60,36 @@ def __init__( def _register_workflows(self): pass + def _prepare_search_settings( + self, + auth_user: Any, + search_mode: SearchMode, + search_settings: Optional[SearchSettings], + ) -> SearchSettings: + """ + Prepare the effective search settings based on the provided search_mode, + optional user-overrides in search_settings, and applied filters. + """ + + if search_mode != SearchMode.custom: + # Start from mode defaults + effective_settings = SearchSettings.get_default(search_mode.value) + if search_settings: + # Merge user-provided overrides + effective_settings = merge_search_settings( + effective_settings, search_settings + ) + else: + # Custom mode: use provided settings or defaults + effective_settings = search_settings or SearchSettings() + + # Apply user-specific filters + effective_settings.filters = self._select_filters( + auth_user, effective_settings + ) + + return effective_settings + def _select_filters( self, auth_user: Any, @@ -91,20 +143,34 @@ def _setup_routes(self): from r2r import R2RClient client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) + # if using auth, do client.login(...) - response =client.retrieval.search( + # Basic mode, no overrides + response = client.retrieval.search( query="Who is Aristotle?", - search_settings: { - filters: {"document_id": {"$eq": "3e157b3a-8469-51db-90d9-52e7d896b49b"}}, - use_semantic_search: true, - chunk_settings: { - limit: 20, # separate limit for chunk vs. graph - enabled: true - }, - graph_settings: { - enabled: true, - }, + search_mode="basic" + ) + + # Advanced mode with overrides + response = client.retrieval.search( + query="Who is Aristotle?", + search_mode="advanced", + search_settings={ + "filters": {"document_id": {"$eq": "3e157b3a-..."}}, + "limit": 5 + } + ) + + # Custom mode with full control + response = client.retrieval.search( + query="Who is Aristotle?", + search_mode="custom", + search_settings={ + "use_semantic_search": True, + "filters": {"category": {"$like": "%philosophy%"}}, + "limit": 20, + "chunk_settings": {"limit": 20}, + "graph_settings": {"enabled": True} } ) """ @@ -180,27 +246,68 @@ async def search_app( ..., description="Search query to find relevant documents", ), - search_settings: SearchSettings = Body( - default_factory=SearchSettings, - description="Settings for vector-based search", + search_mode: SearchMode = Body( + default=SearchMode.custom, + description=( + "Default value of `custom` allows full control over search settings.\n\n" + "Pre-configured search modes:\n" + "`basic`: A simple semantic-based search.\n" + "`advanced`: A more powerful hybrid search combining semantic and full-text.\n" + "`custom`: Full control via `search_settings`.\n\n" + "If `filters` or `limit` are provided alongside `basic` or `advanced`, " + "they will override the default settings for that mode." + ), + ), + search_settings: Optional[SearchSettings] = Body( + None, + description=( + "The search configuration object. If `search_mode` is `custom`, " + "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n" + "Common overrides include `filters` to narrow results and `limit` to control how many results are returned." + ), ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedSearchResponse: """ - Perform a search query on the vector database and knowledge graph and any other configured search engines. - - This endpoint allows for complex filtering of search results using PostgreSQL-based queries. - Filters can be applied to various fields such as document_id, and internal metadata values. - - Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`. + Perform a search query against vector and/or graph-based databases. + + **Search Modes:** + - `basic`: Defaults to semantic search. Simple and easy to use. + - `advanced`: Combines semantic search with full-text search for more comprehensive results. + - `custom`: Complete control over how search is performed. Provide a full `SearchSettings` object. + + **Filters:** + Apply filters directly inside `search_settings.filters`. For example: + ```json + { + "filters": {"document_id": {"$eq": "3e157b3a-..."}} + } + ``` + Supported operators: `$eq`, `$neq`, `$gt`, `$gte`, `$lt`, `$lte`, `$like`, `$ilike`, `$in`, `$nin`. + + **Limit:** + Control how many results you get by specifying `limit` inside `search_settings`. For example: + ```json + { + "limit": 20 + } + ``` + + **Examples:** + - Using `basic` mode and no overrides: + Just specify `search_mode="basic"`. + - Using `advanced` mode and applying a filter: + Specify `search_mode="advanced"` and include `search_settings={"filters": {...}, "limit": 5}` to override defaults. + - Using `custom` mode: + Provide the entire `search_settings` to define your search exactly as you want it. """ - search_settings.filters = self._select_filters( - auth_user, search_settings - ) + effective_settings = self._prepare_search_settings( + auth_user, search_mode, search_settings + ) results = await self.services["retrieval"].search( query=query, - search_settings=search_settings, + search_settings=effective_settings, ) return results @@ -318,9 +425,25 @@ async def search_app( @self.base_endpoint async def rag_app( query: str = Body(...), - search_settings: SearchSettings = Body( - default_factory=SearchSettings, - description="Settings for vector-based search", + search_mode: SearchMode = Body( + default=SearchMode.custom, + description=( + "Default value of `custom` allows full control over search settings.\n\n" + "Pre-configured search modes:\n" + "`basic`: A simple semantic-based search.\n" + "`advanced`: A more powerful hybrid search combining semantic and full-text.\n" + "`custom`: Full control via `search_settings`.\n\n" + "If `filters` or `limit` are provided alongside `basic` or `advanced`, " + "they will override the default settings for that mode." + ), + ), + search_settings: Optional[SearchSettings] = Body( + None, + description=( + "The search configuration object. If `search_mode` is `custom`, " + "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n" + "Common overrides include `filters` to narrow results and `limit` to control how many results are returned." + ), ), rag_generation_config: GenerationConfig = Body( default_factory=GenerationConfig, @@ -346,13 +469,13 @@ async def rag_app( The generation process can be customized using the `rag_generation_config` parameter. """ - search_settings.filters = self._select_filters( - auth_user, search_settings + effective_settings = self._prepare_search_settings( + auth_user, search_mode, search_settings ) response = await self.services["retrieval"].rag( query=query, - search_settings=search_settings, + search_settings=effective_settings, rag_generation_config=rag_generation_config, task_prompt_override=task_prompt_override, include_title_if_available=include_title_if_available, @@ -494,9 +617,25 @@ async def agent_app( deprecated=True, description="List of messages (deprecated, use message instead)", ), - search_settings: SearchSettings = Body( - default_factory=SearchSettings, - description="Settings for vector-based search", + search_mode: SearchMode = Body( + default=SearchMode.custom, + description=( + "Default value of `custom` allows full control over search settings.\n\n" + "Pre-configured search modes:\n" + "`basic`: A simple semantic-based search.\n" + "`advanced`: A more powerful hybrid search combining semantic and full-text.\n" + "`custom`: Full control via `search_settings`.\n\n" + "If `filters` or `limit` are provided alongside `basic` or `advanced`, " + "they will override the default settings for that mode." + ), + ), + search_settings: Optional[SearchSettings] = Body( + None, + description=( + "The search configuration object. If `search_mode` is `custom`, " + "these settings are used as-is. For `basic` or `advanced`, these settings will override the default mode configuration.\n\n" + "Common overrides include `filters` to narrow results and `limit` to control how many results are returned." + ), ), rag_generation_config: GenerationConfig = Body( default_factory=GenerationConfig, @@ -552,16 +691,15 @@ async def agent_app( information, providing detailed, factual responses with proper attribution to source documents. """ - search_settings.filters = self._select_filters( - auth_user=auth_user, - search_settings=search_settings, + effective_settings = self._prepare_search_settings( + auth_user, search_mode, search_settings ) try: response = await self.services["retrieval"].agent( message=message, messages=messages, - search_settings=search_settings, + search_settings=effective_settings, rag_generation_config=rag_generation_config, task_prompt_override=task_prompt_override, include_title_if_available=include_title_if_available, diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 9d15b4a3c..3f9ef458d 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -15,6 +15,7 @@ Message, R2RException, RunManager, + SearchMode, SearchSettings, manage_run, to_async_generator, diff --git a/py/sdk/models.py b/py/sdk/models.py index 277518b92..9ad98d6e7 100644 --- a/py/sdk/models.py +++ b/py/sdk/models.py @@ -4,6 +4,7 @@ GraphSearchResult, GraphSearchSettings, HybridSearchSettings, + IngestionMode, KGCommunityResult, KGCreationSettings, KGEnrichmentSettings, @@ -17,6 +18,7 @@ MessageType, R2RException, R2RSerializable, + SearchMode, SearchSettings, Token, User, @@ -43,6 +45,7 @@ "Token", "ChunkSearchResult", "SearchSettings", + "SearchMode", "KGEntityDeduplicationSettings", "RAGResponse", "CombinedSearchResponse", diff --git a/py/sdk/v3/documents.py b/py/sdk/v3/documents.py index a0b4dff6d..b9ece87af 100644 --- a/py/sdk/v3/documents.py +++ b/py/sdk/v3/documents.py @@ -12,6 +12,8 @@ WrappedDocumentsResponse, ) +from ..models import IngestionMode + class DocumentsSDK: """ @@ -27,9 +29,10 @@ async def create( raw_text: Optional[str] = None, chunks: Optional[list[str]] = None, id: Optional[str | UUID] = None, + ingestion_mode: Optional[str] = None, collection_ids: Optional[list[str | UUID]] = None, metadata: Optional[dict] = None, - ingestion_config: Optional[dict] = None, + ingestion_config: Optional[dict | IngestionMode] = None, run_with_orchestration: Optional[bool] = True, ) -> WrappedIngestionResponse: """ @@ -65,13 +68,17 @@ async def create( if metadata: data["metadata"] = json.dumps(metadata) if ingestion_config: + if not isinstance(ingestion_config, dict): + ingestion_config = ingestion_config.model_dump() + ingestion_config["app"] = {} data["ingestion_config"] = json.dumps(ingestion_config) if collection_ids: collection_ids = [str(collection_id) for collection_id in collection_ids] # type: ignore data["collection_ids"] = json.dumps(collection_ids) if run_with_orchestration is not None: data["run_with_orchestration"] = str(run_with_orchestration) - + if ingestion_mode is not None: + data["ingestion_mode"] = ingestion_mode if file_path: # Create a new file instance that will remain open during the request file_instance = open(file_path, "rb") diff --git a/py/sdk/v3/retrieval.py b/py/sdk/v3/retrieval.py index ed3a1438a..bead7a729 100644 --- a/py/sdk/v3/retrieval.py +++ b/py/sdk/v3/retrieval.py @@ -6,6 +6,7 @@ GraphSearchSettings, Message, RAGResponse, + SearchMode, SearchSettings, ) @@ -21,6 +22,7 @@ def __init__(self, client): async def search( self, query: str, + search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, ) -> CombinedSearchResponse: """ @@ -33,6 +35,9 @@ async def search( Returns: CombinedSearchResponse: The search response. """ + if search_mode and not isinstance(search_mode, str): + search_mode = search_mode.value + if search_settings and not isinstance(search_settings, dict): search_settings = search_settings.model_dump() @@ -40,6 +45,9 @@ async def search( "query": query, "search_settings": search_settings, } + if search_mode: + data["search_mode"] = search_mode + return await self.client._make_request( "POST", "retrieval/search", @@ -91,6 +99,7 @@ async def rag( self, query: str, rag_generation_config: Optional[dict | GenerationConfig] = None, + search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, task_prompt_override: Optional[str] = None, include_title_if_available: Optional[bool] = False, @@ -122,6 +131,8 @@ async def rag( "task_prompt_override": task_prompt_override, "include_title_if_available": include_title_if_available, } + if search_mode: + data["search_mode"] = search_mode if rag_generation_config and rag_generation_config.get( # type: ignore "stream", False @@ -144,6 +155,7 @@ async def agent( self, message: Optional[dict | Message] = None, rag_generation_config: Optional[dict | GenerationConfig] = None, + search_mode: Optional[str | SearchMode] = "custom", search_settings: Optional[dict | SearchSettings] = None, task_prompt_override: Optional[str] = None, include_title_if_available: Optional[bool] = False, @@ -177,6 +189,8 @@ async def agent( "conversation_id": conversation_id, "branch_id": branch_id, } + if search_mode: + data["search_mode"] = search_mode if message: cast_message: Message = ( diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index 3fcaf4037..d6f4c9a34 100644 --- a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -4,6 +4,7 @@ DocumentChunk, DocumentResponse, DocumentType, + IngestionMode, IngestionStatus, KGEnrichmentStatus, KGExtractionStatus, @@ -44,6 +45,7 @@ KGGlobalResult, KGRelationshipResult, KGSearchResultType, + SearchMode, SearchSettings, ) from .user import Token, TokenData, User @@ -110,6 +112,7 @@ "ChunkSearchResult", "SearchSettings", "HybridSearchSettings", + "SearchMode", # KG abstractions "KGCreationSettings", "KGEnrichmentSettings", diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 225e9af3b..12bf02a68 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -255,3 +255,131 @@ class DocumentChunk(R2RSerializable): class RawChunk(R2RSerializable): text: str + + +class IngestionMode(str, Enum): + hi_res = "hi-res" + fast = "fast" + custom = "custom" + + +class ChunkEnrichmentStrategy(str, Enum): + SEMANTIC = "semantic" + NEIGHBORHOOD = "neighborhood" + + def __str__(self) -> str: + return self.value + + +from .llm import GenerationConfig + + +class ChunkEnrichmentSettings(R2RSerializable): + """ + Settings for chunk enrichment. + """ + + enable_chunk_enrichment: bool = Field( + default=False, + description="Whether to enable chunk enrichment or not", + ) + strategies: list[ChunkEnrichmentStrategy] = Field( + default=[], + description="The strategies to use for chunk enrichment. Union of chunks obtained from each strategy is used as context.", + ) + forward_chunks: int = Field( + default=3, + description="The number after the current chunk to include in the LLM context while enriching", + ) + backward_chunks: int = Field( + default=3, + description="The number of chunks before the current chunk in the LLM context while enriching", + ) + semantic_neighbors: int = Field( + default=10, description="The number of semantic neighbors to include" + ) + semantic_similarity_threshold: float = Field( + default=0.7, + description="The similarity threshold for semantic neighbors", + ) + generation_config: GenerationConfig = Field( + default=GenerationConfig(), + description="The generation config to use for chunk enrichment", + ) + + +## TODO - Move ingestion config + + +class IngestionConfig(R2RSerializable): + provider: str = "r2r" + excluded_parsers: list[str] = ["mp4"] + chunk_enrichment_settings: ChunkEnrichmentSettings = ( + ChunkEnrichmentSettings() + ) + extra_parsers: dict[str, str] = {} + + audio_transcription_model: str = "openai/whisper-1" + + vision_img_prompt_name: str = "vision_img" + vision_img_model: str = "openai/gpt-4o" + + vision_pdf_prompt_name: str = "vision_pdf" + vision_pdf_model: str = "openai/gpt-4o" + + skip_document_summary: bool = False + document_summary_system_prompt: str = "default_system" + document_summary_task_prompt: str = "default_summary" + chunks_for_document_summary: int = 128 + document_summary_model: str = "openai/gpt-4o-mini" + + @property + def supported_providers(self) -> list[str]: + return ["r2r", "unstructured_local", "unstructured_api"] + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider {self.provider} is not supported.") + + @classmethod + def get_default(cls, mode: str) -> "IngestionConfig": + """Return default ingestion configuration for a given mode.""" + if mode == "hi-res": + # More thorough parsing, no skipping summaries, possibly larger `chunks_for_document_summary`. + return cls( + provider="r2r", + excluded_parsers=["mp4"], + chunk_enrichment_settings=ChunkEnrichmentSettings(), # default + extra_parsers={}, + audio_transcription_model="openai/whisper-1", + vision_img_prompt_name="vision_img", + vision_img_model="openai/gpt-4o", + vision_pdf_prompt_name="vision_pdf", + vision_pdf_model="openai/gpt-4o", + skip_document_summary=False, + document_summary_system_prompt="default_system", + document_summary_task_prompt="default_summary", + chunks_for_document_summary=256, # larger for hi-res + document_summary_model="openai/gpt-4o-mini", + ) + elif mode == "fast": + # Skip summaries and other enrichment steps for speed. + return cls( + provider="r2r", + excluded_parsers=["mp4"], + chunk_enrichment_settings=ChunkEnrichmentSettings(), # default + extra_parsers={}, + audio_transcription_model="openai/whisper-1", + vision_img_prompt_name="vision_img", + vision_img_model="openai/gpt-4o", + vision_pdf_prompt_name="vision_pdf", + vision_pdf_model="openai/gpt-4o", + skip_document_summary=True, # skip summaries + document_summary_system_prompt="default_system", + document_summary_task_prompt="default_summary", + chunks_for_document_summary=64, + document_summary_model="openai/gpt-4o-mini", + ) + else: + # For `custom` or any unrecognized mode, return a base config + return cls() diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 59a7398b5..74fbac3ea 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -424,3 +424,36 @@ def __init__(self, **data): def model_dump(self, *args, **kwargs): dump = super().model_dump(*args, **kwargs) return dump + + @classmethod + def get_default(cls, mode: str) -> "SearchSettings": + """Return default search settings for a given mode.""" + if mode == "basic": + # A simpler search that relies primarily on semantic search. + return cls( + use_semantic_search=True, + use_fulltext_search=False, + use_hybrid_search=False, + search_strategy="vanilla", + # Other relevant defaults can be provided here as needed + ) + elif mode == "advanced": + # A more powerful, combined search that leverages both semantic and fulltext. + return cls( + use_semantic_search=True, + use_fulltext_search=True, + use_hybrid_search=True, + search_strategy="hyde", + # Other advanced defaults as needed + ) + else: + # For 'custom' or unrecognized modes, return a basic empty config. + return cls() + + +class SearchMode(str, Enum): + """Search modes for the search endpoint.""" + + basic = "basic" + advanced = "advanced" + custom = "custom"