From 45b92670786067fc1304cd63227740ad05402b9c Mon Sep 17 00:00:00 2001 From: Nolan Tremelling <34580718+NolanTrem@users.noreply.github.com> Date: Wed, 26 Feb 2025 13:39:46 -0800 Subject: [PATCH] Fix remaining type errors (#2010) * Fix remaining type errors * Add mypy install * Fixes --- .github/workflows/quality.yml | 6 +-- .pre-commit-config.yaml | 2 +- py/core/agent/__init__.py | 9 +++- py/core/agent/base.py | 1 + py/core/agent/rag.py | 1 + py/core/base/agent/__init__.py | 9 +++- py/core/base/agent/agent.py | 1 + py/core/base/providers/base.py | 4 +- py/core/main/config.py | 12 +++-- py/core/main/orchestration/__init__.py | 9 +++- .../orchestration/hatchet/graph_workflow.py | 1 + .../hatchet/ingestion_workflow.py | 1 + py/core/main/services/retrieval_service.py | 9 ++-- py/core/providers/database/chunks.py | 24 ++++++--- py/core/providers/database/graphs.py | 2 +- py/core/providers/database/postgres.py | 14 +++--- .../providers/ingestion/unstructured/base.py | 50 +++++++++---------- py/core/providers/orchestration/hatchet.py | 11 ++-- py/shared/abstractions/graph.py | 2 +- py/shared/abstractions/llm.py | 6 +-- 20 files changed, 105 insertions(+), 69 deletions(-) diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index be5530c68..e66e3076d 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -20,9 +20,9 @@ jobs: run: | python -m pip install --upgrade pip pip install pre-commit + pip install mypy + pip install types-requests types-toml types-aiofiles - - name: Run pre-commit hooks (excluding mypy) - env: - SKIP: mypy + - name: Run pre-commit hooks run: | pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f34a5e73d..5432e3ee0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,7 @@ repos: hooks: - id: mypy name: mypy - entry: bash -c 'cd py && python -m mypy .' + entry: bash -c 'cd py && python -m mypy --exclude "migrations" .' language: system types: [python] pass_filenames: false diff --git a/py/core/agent/__init__.py b/py/core/agent/__init__.py index ba2eff690..c3260f5fc 100644 --- a/py/core/agent/__init__.py +++ b/py/core/agent/__init__.py @@ -1,5 +1,10 @@ -from .base import R2RAgent, R2RStreamingAgent, R2RStreamingReasoningAgent -from .rag import ( +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments +from .base import ( # type: ignore + R2RAgent, + R2RStreamingAgent, + R2RStreamingReasoningAgent, +) +from .rag import ( # type: ignore R2RRAGAgent, R2RStreamingRAGAgent, R2RStreamingReasoningRAGAgent, diff --git a/py/core/agent/base.py b/py/core/agent/base.py index 0856a92e9..810a5ed35 100644 --- a/py/core/agent/base.py +++ b/py/core/agent/base.py @@ -1,3 +1,4 @@ +# type: ignore import asyncio import logging from abc import ABCMeta diff --git a/py/core/agent/rag.py b/py/core/agent/rag.py index 6a9cfb881..dbf7874d0 100644 --- a/py/core/agent/rag.py +++ b/py/core/agent/rag.py @@ -1,3 +1,4 @@ +# type: ignore import asyncio import json import logging diff --git a/py/core/base/agent/__init__.py b/py/core/base/agent/__init__.py index 53414f437..815b9ae7f 100644 --- a/py/core/base/agent/__init__.py +++ b/py/core/base/agent/__init__.py @@ -1,4 +1,11 @@ -from .agent import Agent, AgentConfig, Conversation, Tool, ToolResult +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments +from .agent import ( # type: ignore + Agent, + AgentConfig, + Conversation, + Tool, + ToolResult, +) __all__ = [ # Agent abstractions diff --git a/py/core/base/agent/agent.py b/py/core/base/agent/agent.py index 1fad4ce2e..01a82159b 100644 --- a/py/core/base/agent/agent.py +++ b/py/core/base/agent/agent.py @@ -1,3 +1,4 @@ +# type: ignore import asyncio import json import logging diff --git a/py/core/base/providers/base.py b/py/core/base/providers/base.py index f90869cdb..1c29ee443 100644 --- a/py/core/base/providers/base.py +++ b/py/core/base/providers/base.py @@ -7,13 +7,15 @@ class InnerConfig(BaseModel, ABC): """A base provider configuration class.""" + extra_fields: dict[str, Any] = {} + class Config: populate_by_name = True arbitrary_types_allowed = True ignore_extra = True @classmethod - def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig": + def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig": base_args = cls.model_fields.keys() filtered_kwargs = { k: v if v != "None" else None diff --git a/py/core/main/config.py b/py/core/main/config.py index 2a0306205..831988616 100644 --- a/py/core/main/config.py +++ b/py/core/main/config.py @@ -1,3 +1,4 @@ +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments import logging import os from enum import Enum @@ -7,7 +8,7 @@ from pydantic import BaseModel from ..base.abstractions import GenerationConfig -from ..base.agent.agent import AgentConfig +from ..base.agent.agent import AgentConfig # type: ignore from ..base.providers import AppConfig from ..base.providers.auth import AuthConfig from ..base.providers.crypto import CryptoConfig @@ -123,9 +124,10 @@ def __init__(self, config_data: dict[str, Any]): IngestionConfig.set_default(**self.ingestion.dict()) # override GenerationConfig defaults - GenerationConfig.set_default( - **self.completion.generation_config.dict() - ) + if self.completion.generation_config: + GenerationConfig.set_default( + **self.completion.generation_config.dict() + ) def _validate_config_section( self, config_data: dict[str, Any], section: str, keys: list @@ -166,7 +168,7 @@ def load_default_config(cls) -> dict: return toml.load(f) @staticmethod - def _serialize_config(config_section: Any) -> dict: + def _serialize_config(config_section: Any): """Serialize config section while excluding internal state.""" if isinstance(config_section, dict): return { diff --git a/py/core/main/orchestration/__init__.py b/py/core/main/orchestration/__init__.py index 8cd77be41..19cb04280 100644 --- a/py/core/main/orchestration/__init__.py +++ b/py/core/main/orchestration/__init__.py @@ -1,5 +1,10 @@ -from .hatchet.graph_workflow import hatchet_graph_search_results_factory -from .hatchet.ingestion_workflow import hatchet_ingestion_factory +# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments +from .hatchet.graph_workflow import ( # type: ignore + hatchet_graph_search_results_factory, +) +from .hatchet.ingestion_workflow import ( # type: ignore + hatchet_ingestion_factory, +) from .simple.graph_workflow import simple_graph_search_results_factory from .simple.ingestion_workflow import simple_ingestion_factory diff --git a/py/core/main/orchestration/hatchet/graph_workflow.py b/py/core/main/orchestration/hatchet/graph_workflow.py index 26de8bcc1..cc128b0fb 100644 --- a/py/core/main/orchestration/hatchet/graph_workflow.py +++ b/py/core/main/orchestration/hatchet/graph_workflow.py @@ -1,3 +1,4 @@ +# type: ignore import asyncio import contextlib import json diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index 2e839c2a6..9fa935ab4 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -1,3 +1,4 @@ +# type: ignore import asyncio import logging import uuid diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index b079c8d98..960fe422a 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -1,3 +1,4 @@ +# FIXME: Once the agent is properly type annotated, remove the type: ignore comments import asyncio import json import logging @@ -5,7 +6,7 @@ import uuid from copy import deepcopy from datetime import datetime -from typing import Any, Optional, cast +from typing import Any, Optional from uuid import UUID import tiktoken @@ -17,7 +18,7 @@ R2RStreamingReasoningRAGAgent, SearchResultsCollector, ) -from core.agent.rag import ( +from core.agent.rag import ( # type: ignore GeminiXMLToolsStreamingReasoningRAGAgent, R2RXMLToolsStreamingReasoningRAGAgent, ) @@ -812,9 +813,7 @@ async def agent( ) agent_config = deepcopy(self.config.agent) - agent_config.tools = cast( - type(agent_config.tools), override_tools or agent_config.tools - ) + agent_config.tools = override_tools or agent_config.tools if rag_generation_config.stream: diff --git a/py/core/providers/database/chunks.py b/py/core/providers/database/chunks.py index 8bb22d0ac..404e359ce 100644 --- a/py/core/providers/database/chunks.py +++ b/py/core/providers/database/chunks.py @@ -1129,11 +1129,19 @@ async def search_documents( where_clauses = [] params: list[str | int | bytes] = [query_text] + search_over_body = getattr(settings, "search_over_body", True) + search_over_metadata = getattr(settings, "search_over_metadata", True) + metadata_weight = getattr(settings, "metadata_weight", 3.0) + title_weight = getattr(settings, "title_weight", 1.0) + metadata_keys = getattr( + settings, "metadata_keys", ["title", "description"] + ) + # Build the dynamic metadata field search expression metadata_fields_expr = " || ' ' || ".join( [ f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')" - for key in settings.metadata_keys # type: ignore + for key in metadata_keys # type: ignore ] ) @@ -1169,7 +1177,7 @@ async def search_documents( ) as body_rank FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} WHERE $1 != '' - {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""} + {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if search_over_body else ""} GROUP BY document_id ), -- Combined scores with document metadata @@ -1180,11 +1188,11 @@ async def search_documents( COALESCE(m.metadata_rank, 0) as debug_metadata_rank, COALESCE(b.body_rank, 0) as debug_body_rank, CASE - WHEN {str(settings.search_over_metadata).lower()} AND {str(settings.search_over_body).lower()} THEN - COALESCE(m.metadata_rank, 0) * {settings.metadata_weight} + COALESCE(b.body_rank, 0) * {settings.title_weight} - WHEN {str(settings.search_over_metadata).lower()} THEN + WHEN {str(search_over_metadata).lower()} AND {str(search_over_body).lower()} THEN + COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight} + WHEN {str(search_over_metadata).lower()} THEN COALESCE(m.metadata_rank, 0) - WHEN {str(settings.search_over_body).lower()} THEN + WHEN {str(search_over_body).lower()} THEN COALESCE(b.body_rank, 0) ELSE 0 END as rank @@ -1192,8 +1200,8 @@ async def search_documents( FULL OUTER JOIN body_scores b ON m.document_id = b.document_id WHERE ( ($1 = '') OR - ({str(settings.search_over_metadata).lower()} AND m.metadata_rank > 0) OR - ({str(settings.search_over_body).lower()} AND b.body_rank > 0) + ({str(search_over_metadata).lower()} AND m.metadata_rank > 0) OR + ({str(search_over_body).lower()} AND b.body_rank > 0) ) """ diff --git a/py/core/providers/database/graphs.py b/py/core/providers/database/graphs.py index d7857fd3e..a5a54a508 100644 --- a/py/core/providers/database/graphs.py +++ b/py/core/providers/database/graphs.py @@ -542,7 +542,7 @@ async def _create_merged_entity(self, entities: list[Entity]) -> Entity: ) # Merge metadata dictionaries - merged_metadata = {} + merged_metadata: dict[str, Any] = {} for entity in entities: if entity.metadata: merged_metadata |= entity.metadata diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index 7b3c8fac0..f1d84a131 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -27,7 +27,7 @@ from .users import PostgresUserHandler if TYPE_CHECKING: - from ..providers.crypto import BCryptCryptoProvider, NaClCryptoProvider + from ..crypto import BCryptCryptoProvider, NaClCryptoProvider CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider @@ -134,7 +134,9 @@ def __init__( PostgresConnectionManager() ) self.documents_handler = PostgresDocumentsHandler( - self.project_name, self.connection_manager, self.dimension + project_name=self.project_name, + connection_manager=self.connection_manager, + dimension=int(self.dimension), ) self.token_handler = PostgresTokensHandler( self.project_name, self.connection_manager @@ -146,10 +148,10 @@ def __init__( self.project_name, self.connection_manager, self.crypto_provider ) self.chunks_handler = PostgresChunksHandler( - self.project_name, - self.connection_manager, - self.dimension, - self.quantization_type, + project_name=self.project_name, + connection_manager=self.connection_manager, + dimension=int(self.dimension), + quantization_type=(self.quantization_type), ) self.conversations_handler = PostgresConversationsHandler( self.project_name, self.connection_manager diff --git a/py/core/providers/ingestion/unstructured/base.py b/py/core/providers/ingestion/unstructured/base.py index 6dc59a337..541a05f9b 100644 --- a/py/core/providers/ingestion/unstructured/base.py +++ b/py/core/providers/ingestion/unstructured/base.py @@ -6,7 +6,7 @@ import time from copy import copy from io import BytesIO -from typing import Any, AsyncGenerator, Optional +from typing import Any, AsyncGenerator import httpx from unstructured_client import UnstructuredClient @@ -46,30 +46,30 @@ class UnstructuredIngestionConfig(IngestionConfig): new_after_n_chars: int = 1500 overlap: int = 64 - coordinates: Optional[bool] = None - encoding: Optional[str] = None # utf-8 - extract_image_block_types: Optional[list[str]] = None - gz_uncompressed_content_type: Optional[str] = None - hi_res_model_name: Optional[str] = None - include_orig_elements: Optional[bool] = None - include_page_breaks: Optional[bool] = None - - languages: Optional[list[str]] = None - multipage_sections: Optional[bool] = None - ocr_languages: Optional[list[str]] = None + coordinates: bool | None + encoding: str | None # utf-8 + extract_image_block_types: list[str] | None + gz_uncompressed_content_type: str | None + hi_res_model_name: str | None + include_orig_elements: bool | None + include_page_breaks: bool | None + + languages: list[str] | None + multipage_sections: bool | None + ocr_languages: list[str] | None # output_format: Optional[str] = "application/json" - overlap_all: Optional[bool] = None - pdf_infer_table_structure: Optional[bool] = None - - similarity_threshold: Optional[float] = None - skip_infer_table_types: Optional[list[str]] = None - split_pdf_concurrency_level: Optional[int] = None - split_pdf_page: Optional[bool] = None - starting_page_number: Optional[int] = None - strategy: Optional[str] = None - chunking_strategy: Optional[str | ChunkingStrategy] = None - unique_element_ids: Optional[bool] = None - xml_keep_tags: Optional[bool] = None + overlap_all: bool | None + pdf_infer_table_structure: bool | None + + similarity_threshold: float | None + skip_infer_table_types: list[str] | None + split_pdf_concurrency_level: int | None + split_pdf_page: bool | None + starting_page_number: int | None + strategy: str | None + chunking_strategy: str | ChunkingStrategy | None # type: ignore + unique_element_ids: bool | None + xml_keep_tags: bool | None def to_ingestion_request(self): import json @@ -204,7 +204,7 @@ async def parse_fallback( parser_name: str, ) -> AsyncGenerator[FallbackElement, None]: contents = [] - async for chunk in self.parsers[parser_name].ingest( + async for chunk in self.parsers[parser_name].ingest( # type: ignore file_content, **ingestion_config ): # type: ignore if isinstance(chunk, dict) and chunk.get("content"): diff --git a/py/core/providers/orchestration/hatchet.py b/py/core/providers/orchestration/hatchet.py index 210367777..941e2048d 100644 --- a/py/core/providers/orchestration/hatchet.py +++ b/py/core/providers/orchestration/hatchet.py @@ -1,3 +1,4 @@ +# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments import asyncio import logging from typing import Any, Callable, Optional @@ -39,7 +40,7 @@ def failure(self, *args, **kwargs) -> Callable: def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any: if not max_runs: max_runs = self.config.max_runs - self.worker = self.orchestrator.worker(name, max_runs) + self.worker = self.orchestrator.worker(name, max_runs) # type: ignore return self.worker def concurrency(self, *args, **kwargs) -> Callable: @@ -61,10 +62,10 @@ async def run_workflow( *args, **kwargs, ) -> Any: - task_id = self.orchestrator.admin.run_workflow( + task_id = self.orchestrator.admin.run_workflow( # type: ignore workflow_name, parameters, - options=options, + options=options, # type: ignore *args, **kwargs, ) @@ -84,7 +85,7 @@ def register_workflows( f"Registering workflows for {workflow} with messages {messages}." ) if workflow == Workflow.INGESTION: - from core.main.orchestration.hatchet.ingestion_workflow import ( + from core.main.orchestration.hatchet.ingestion_workflow import ( # type: ignore hatchet_ingestion_factory, ) @@ -94,7 +95,7 @@ def register_workflows( self.worker.register_workflow(workflow) elif workflow == Workflow.GRAPH: - from core.main.orchestration.hatchet.graph_workflow import ( + from core.main.orchestration.hatchet.graph_workflow import ( # type: ignore hatchet_graph_search_results_factory, ) diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 0b8fd2364..aaa5fe7a9 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -17,7 +17,7 @@ class Entity(R2RSerializable): name: str description: Optional[str] = None category: Optional[str] = None - metadata: Optional[dict[str, Any] | str] = None + metadata: Optional[dict[str, Any]] = None id: Optional[UUID] = None parent_id: Optional[UUID] = None # graph_id | document_id diff --git a/py/shared/abstractions/llm.py b/py/shared/abstractions/llm.py index fb84c98c6..56da66dd6 100644 --- a/py/shared/abstractions/llm.py +++ b/py/shared/abstractions/llm.py @@ -2,7 +2,7 @@ import json from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeAlias from openai.types.chat import ChatCompletion, ChatCompletionChunk from pydantic import BaseModel, Field @@ -12,8 +12,8 @@ if TYPE_CHECKING: from .search import AggregateSearchResult -LLMChatCompletion = ChatCompletion -LLMChatCompletionChunk = ChatCompletionChunk +LLMChatCompletion: TypeAlias = ChatCompletion +LLMChatCompletionChunk: TypeAlias = ChatCompletionChunk class RAGCompletion: