Fix remaining type errors (#2010)
* Fix remaining type errors

* Add mypy install

* Fixes
NolanTrem authored Feb 26, 2025
1 parent f368978 commit 45b9267
Showing 20 changed files with 105 additions and 69 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/quality.yml
@@ -20,9 +20,9 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install pre-commit
-    - name: Run pre-commit hooks (excluding mypy)
-      env:
-        SKIP: mypy
+        pip install mypy
+        pip install types-requests types-toml types-aiofiles
+    - name: Run pre-commit hooks
       run: |
         pre-commit run --all-files
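
Note: the `types-*` packages are PEP 561 stub distributions that give mypy signatures for libraries shipping no inline types. A minimal sketch of the failure mode the extra install step avoids (the file name and function are illustrative, not from this repo):

    # check_stubs.py, illustrative only
    # Without `types-requests`, mypy reports roughly:
    #   error: Library stubs not installed for "requests"  [import-untyped]
    # With the stub package installed, the call type-checks normally.
    import requests

    def fetch_status(url: str) -> int:
        # Response.status_code is typed as int once the stubs are present.
        return requests.get(url, timeout=10).status_code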
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -42,7 +42,7 @@ repos:
     hooks:
       - id: mypy
         name: mypy
-        entry: bash -c 'cd py && python -m mypy .'
+        entry: bash -c 'cd py && python -m mypy --exclude "migrations" .'
         language: system
         types: [python]
         pass_filenames: false
9 changes: 7 additions & 2 deletions py/core/agent/__init__.py
@@ -1,5 +1,10 @@
-from .base import R2RAgent, R2RStreamingAgent, R2RStreamingReasoningAgent
-from .rag import (
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+from .base import (  # type: ignore
+    R2RAgent,
+    R2RStreamingAgent,
+    R2RStreamingReasoningAgent,
+)
+from .rag import (  # type: ignore
     R2RRAGAgent,
     R2RStreamingRAGAgent,
     R2RStreamingReasoningRAGAgent,
1 change: 1 addition & 0 deletions py/core/agent/base.py
@@ -1,3 +1,4 @@
+# type: ignore
 import asyncio
 import logging
 from abc import ABCMeta
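
Note: a `# type: ignore` comment on the first line of a file silences mypy for the entire module, which is why this commit drops a single marker at the top of each agent file rather than annotating every call site. A small illustration (made-up function, not from this repo):

    # type: ignore
    # Because the comment above is the first line of the module, mypy
    # suppresses all errors in this file, including the bad call below.

    def add(x: int, y: int) -> int:
        return x + y

    result = add("not", "ints")  # normally an [arg-type] error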
1 change: 1 addition & 0 deletions py/core/agent/rag.py
@@ -1,3 +1,4 @@
+# type: ignore
 import asyncio
 import json
 import logging
9 changes: 8 additions & 1 deletion py/core/base/agent/__init__.py
@@ -1,4 +1,11 @@
-from .agent import Agent, AgentConfig, Conversation, Tool, ToolResult
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
+from .agent import (  # type: ignore
+    Agent,
+    AgentConfig,
+    Conversation,
+    Tool,
+    ToolResult,
+)
 
 __all__ = [
     # Agent abstractions
1 change: 1 addition & 0 deletions py/core/base/agent/agent.py
@@ -1,3 +1,4 @@
+# type: ignore
 import asyncio
 import json
 import logging
4 changes: 3 additions & 1 deletion py/core/base/providers/base.py
@@ -7,13 +7,15 @@
 class InnerConfig(BaseModel, ABC):
     """A base provider configuration class."""
 
+    extra_fields: dict[str, Any] = {}
+
     class Config:
         populate_by_name = True
         arbitrary_types_allowed = True
         ignore_extra = True
 
     @classmethod
-    def create(cls: Type["ProviderConfig"], **kwargs: Any) -> "ProviderConfig":
+    def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig":
         base_args = cls.model_fields.keys()
         filtered_kwargs = {
             k: v if v != "None" else None
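
Note: the fix above aligns the `cls` annotation with the class `create` actually lives on, `InnerConfig`, instead of `ProviderConfig`. A self-contained re-creation of the pattern under Pydantic v2 (class and field names below are invented, not the real config):

    from abc import ABC
    from typing import Any, Type

    from pydantic import BaseModel

    class InnerConfig(BaseModel, ABC):
        extra_fields: dict[str, Any] = {}

        @classmethod
        def create(cls: Type["InnerConfig"], **kwargs: Any) -> "InnerConfig":
            # Keep only kwargs matching declared fields; map "None" -> None.
            base_args = cls.model_fields.keys()
            filtered_kwargs = {
                k: v if v != "None" else None
                for k, v in kwargs.items()
                if k in base_args
            }
            return cls(**filtered_kwargs)

    class SketchConfig(InnerConfig):
        provider: str | None = None

    cfg = SketchConfig.create(provider="None", unknown_key=123)
    assert cfg.provider is None  # "None" coerced; unknown_key filtered out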
12 changes: 7 additions & 5 deletions py/core/main/config.py
@@ -1,3 +1,4 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
 import logging
 import os
 from enum import Enum
@@ -7,7 +8,7 @@
 from pydantic import BaseModel
 
 from ..base.abstractions import GenerationConfig
-from ..base.agent.agent import AgentConfig
+from ..base.agent.agent import AgentConfig  # type: ignore
 from ..base.providers import AppConfig
 from ..base.providers.auth import AuthConfig
 from ..base.providers.crypto import CryptoConfig
@@ -123,9 +124,10 @@ def __init__(self, config_data: dict[str, Any]):
         IngestionConfig.set_default(**self.ingestion.dict())
 
         # override GenerationConfig defaults
-        GenerationConfig.set_default(
-            **self.completion.generation_config.dict()
-        )
+        if self.completion.generation_config:
+            GenerationConfig.set_default(
+                **self.completion.generation_config.dict()
+            )
 
     def _validate_config_section(
         self, config_data: dict[str, Any], section: str, keys: list
@@ -166,7 +168,7 @@ def load_default_config(cls) -> dict:
             return toml.load(f)
 
     @staticmethod
-    def _serialize_config(config_section: Any) -> dict:
+    def _serialize_config(config_section: Any):
         """Serialize config section while excluding internal state."""
         if isinstance(config_section, dict):
             return {
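
Note: the new `if self.completion.generation_config:` guard is standard Optional narrowing: mypy treats the attribute as possibly None, and the truthiness check narrows it inside the branch. A reduced sketch (stand-in classes, not the real config types):

    from typing import Optional

    class GenConfig:
        def dict(self) -> dict:
            return {"model": "example"}

    class Completion:
        generation_config: Optional[GenConfig] = None

    def apply_defaults(completion: Completion) -> None:
        # Unguarded, mypy reports: Item "None" of "Optional[GenConfig]"
        # has no attribute "dict"  [union-attr]
        if completion.generation_config:
            print(completion.generation_config.dict())  # narrowed to GenConfig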
9 changes: 7 additions & 2 deletions py/core/main/orchestration/__init__.py
@@ -1,5 +1,10 @@
-from .hatchet.graph_workflow import hatchet_graph_search_results_factory
-from .hatchet.ingestion_workflow import hatchet_ingestion_factory
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
+from .hatchet.graph_workflow import (  # type: ignore
+    hatchet_graph_search_results_factory,
+)
+from .hatchet.ingestion_workflow import (  # type: ignore
+    hatchet_ingestion_factory,
+)
 from .simple.graph_workflow import simple_graph_search_results_factory
 from .simple.ingestion_workflow import simple_ingestion_factory
1 change: 1 addition & 0 deletions py/core/main/orchestration/hatchet/graph_workflow.py
@@ -1,3 +1,4 @@
+# type: ignore
 import asyncio
 import contextlib
 import json
1 change: 1 addition & 0 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -1,3 +1,4 @@
+# type: ignore
 import asyncio
 import logging
 import uuid
9 changes: 4 additions & 5 deletions py/core/main/services/retrieval_service.py
@@ -1,11 +1,12 @@
+# FIXME: Once the agent is properly type annotated, remove the type: ignore comments
 import asyncio
 import json
 import logging
 import time
 import uuid
 from copy import deepcopy
 from datetime import datetime
-from typing import Any, Optional, cast
+from typing import Any, Optional
 from uuid import UUID
 
 import tiktoken
@@ -17,7 +18,7 @@
     R2RStreamingReasoningRAGAgent,
     SearchResultsCollector,
 )
-from core.agent.rag import (
+from core.agent.rag import (  # type: ignore
     GeminiXMLToolsStreamingReasoningRAGAgent,
     R2RXMLToolsStreamingReasoningRAGAgent,
 )
@@ -812,9 +813,7 @@ async def agent(
         )
 
         agent_config = deepcopy(self.config.agent)
-        agent_config.tools = cast(
-            type(agent_config.tools), override_tools or agent_config.tools
-        )
+        agent_config.tools = override_tools or agent_config.tools
 
         if rag_generation_config.stream:
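
Note: `typing.cast` is a static-only assertion with no runtime effect, so when the right-hand side already type-checks, deleting the wrapper changes nothing. A sketch of the before/after equivalence (simplified signature, not the real service method):

    from typing import Optional, cast

    def pick_tools(
        override_tools: Optional[list[str]], default_tools: list[str]
    ) -> list[str]:
        # Before: the cast told mypy what to assume but did nothing at runtime.
        before = cast(list[str], override_tools or default_tools)
        # After: mypy already infers list[str] for the plain expression.
        after = override_tools or default_tools
        assert before == after
        return after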
24 changes: 16 additions & 8 deletions py/core/providers/database/chunks.py
@@ -1129,11 +1129,19 @@ async def search_documents(
         where_clauses = []
         params: list[str | int | bytes] = [query_text]
 
+        search_over_body = getattr(settings, "search_over_body", True)
+        search_over_metadata = getattr(settings, "search_over_metadata", True)
+        metadata_weight = getattr(settings, "metadata_weight", 3.0)
+        title_weight = getattr(settings, "title_weight", 1.0)
+        metadata_keys = getattr(
+            settings, "metadata_keys", ["title", "description"]
+        )
+
         # Build the dynamic metadata field search expression
         metadata_fields_expr = " || ' ' || ".join(
             [
                 f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
-                for key in settings.metadata_keys  # type: ignore
+                for key in metadata_keys  # type: ignore
             ]
         )
@@ -1169,7 +1177,7 @@
                     ) as body_rank
                 FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
                 WHERE $1 != ''
-                {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if settings.search_over_body else ""}
+                {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if search_over_body else ""}
                 GROUP BY document_id
             ),
             -- Combined scores with document metadata
@@ -1180,20 +1188,20 @@
                     COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
                     COALESCE(b.body_rank, 0) as debug_body_rank,
                     CASE
-                        WHEN {str(settings.search_over_metadata).lower()} AND {str(settings.search_over_body).lower()} THEN
-                            COALESCE(m.metadata_rank, 0) * {settings.metadata_weight} + COALESCE(b.body_rank, 0) * {settings.title_weight}
-                        WHEN {str(settings.search_over_metadata).lower()} THEN
+                        WHEN {str(search_over_metadata).lower()} AND {str(search_over_body).lower()} THEN
+                            COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight}
+                        WHEN {str(search_over_metadata).lower()} THEN
                             COALESCE(m.metadata_rank, 0)
-                        WHEN {str(settings.search_over_body).lower()} THEN
+                        WHEN {str(search_over_body).lower()} THEN
                             COALESCE(b.body_rank, 0)
                         ELSE 0
                     END as rank
                 FROM metadata_scores m
                 FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
                 WHERE (
                     ($1 = '') OR
-                    ({str(settings.search_over_metadata).lower()} AND m.metadata_rank > 0) OR
-                    ({str(settings.search_over_body).lower()} AND b.body_rank > 0)
+                    ({str(search_over_metadata).lower()} AND m.metadata_rank > 0) OR
+                    ({str(search_over_body).lower()} AND b.body_rank > 0)
                 )
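
Note: routing the optional search knobs through `getattr` with defaults means the query builder no longer dereferences attributes mypy cannot prove exist on `settings`, and objects missing a field fall back cleanly at runtime. A cut-down sketch of the pattern (the `Settings` class here is invented):

    class Settings:
        # Deliberately defines only one of the knobs.
        metadata_weight = 5.0

    settings = Settings()

    search_over_body = getattr(settings, "search_over_body", True)  # fallback: True
    metadata_weight = getattr(settings, "metadata_weight", 3.0)  # present: 5.0
    metadata_keys = getattr(settings, "metadata_keys", ["title", "description"])

    print(search_over_body, metadata_weight, metadata_keys)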
2 changes: 1 addition & 1 deletion py/core/providers/database/graphs.py
@@ -542,7 +542,7 @@ async def _create_merged_entity(self, entities: list[Entity]) -> Entity:
         )
 
         # Merge metadata dictionaries
-        merged_metadata = {}
+        merged_metadata: dict[str, Any] = {}
         for entity in entities:
             if entity.metadata:
                 merged_metadata |= entity.metadata
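
Note: an empty `{}` gives mypy no value type to infer, producing a "Need type annotation for 'merged_metadata'" error under strict settings; the explicit `dict[str, Any]` annotation resolves it, and `|=` (PEP 584) then merges each entity's metadata in place. Sketch:

    from typing import Any

    entity_metadata: list[dict[str, Any] | None] = [
        {"source": "a", "score": 1},
        None,
        {"score": 2, "lang": "en"},
    ]

    merged_metadata: dict[str, Any] = {}  # annotation keeps mypy satisfied
    for metadata in entity_metadata:
        if metadata:
            merged_metadata |= metadata  # in-place dict merge; later keys win

    print(merged_metadata)  # {'source': 'a', 'score': 2, 'lang': 'en'}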
14 changes: 8 additions & 6 deletions py/core/providers/database/postgres.py
@@ -27,7 +27,7 @@
 from .users import PostgresUserHandler
 
 if TYPE_CHECKING:
-    from ..providers.crypto import BCryptCryptoProvider, NaClCryptoProvider
+    from ..crypto import BCryptCryptoProvider, NaClCryptoProvider
 
     CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider
@@ -134,7 +134,9 @@ def __init__(
             PostgresConnectionManager()
         )
         self.documents_handler = PostgresDocumentsHandler(
-            self.project_name, self.connection_manager, self.dimension
+            project_name=self.project_name,
+            connection_manager=self.connection_manager,
+            dimension=int(self.dimension),
         )
         self.token_handler = PostgresTokensHandler(
             self.project_name, self.connection_manager
@@ -146,10 +148,10 @@
             self.project_name, self.connection_manager, self.crypto_provider
         )
         self.chunks_handler = PostgresChunksHandler(
-            self.project_name,
-            self.connection_manager,
-            self.dimension,
-            self.quantization_type,
+            project_name=self.project_name,
+            connection_manager=self.connection_manager,
+            dimension=int(self.dimension),
+            quantization_type=(self.quantization_type),
         )
         self.conversations_handler = PostgresConversationsHandler(
             self.project_name, self.connection_manager
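
Note: switching the handler constructors to keyword arguments makes each argument checked by name rather than position, and `int(self.dimension)` looks like a coercion added so the argument is a plain `int` where the attribute's declared type is presumably broader (an assumption; the handler signature below is a stand-in, not the real class):

    class ChunksHandler:
        def __init__(
            self, project_name: str, connection_manager: object, dimension: int
        ) -> None:
            self.dimension = dimension

    dimension: int | str = 512  # suppose config parsing leaves the type broad

    handler = ChunksHandler(
        project_name="proj",
        connection_manager=object(),
        dimension=int(dimension),  # narrows int | str to int for the checker
    )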
50 changes: 25 additions & 25 deletions py/core/providers/ingestion/unstructured/base.py
@@ -6,7 +6,7 @@
 import time
 from copy import copy
 from io import BytesIO
-from typing import Any, AsyncGenerator, Optional
+from typing import Any, AsyncGenerator
 
 import httpx
 from unstructured_client import UnstructuredClient
@@ -46,30 +46,30 @@ class UnstructuredIngestionConfig(IngestionConfig):
     new_after_n_chars: int = 1500
     overlap: int = 64
 
-    coordinates: Optional[bool] = None
-    encoding: Optional[str] = None  # utf-8
-    extract_image_block_types: Optional[list[str]] = None
-    gz_uncompressed_content_type: Optional[str] = None
-    hi_res_model_name: Optional[str] = None
-    include_orig_elements: Optional[bool] = None
-    include_page_breaks: Optional[bool] = None
-
-    languages: Optional[list[str]] = None
-    multipage_sections: Optional[bool] = None
-    ocr_languages: Optional[list[str]] = None
+    coordinates: bool | None
+    encoding: str | None  # utf-8
+    extract_image_block_types: list[str] | None
+    gz_uncompressed_content_type: str | None
+    hi_res_model_name: str | None
+    include_orig_elements: bool | None
+    include_page_breaks: bool | None
+
+    languages: list[str] | None
+    multipage_sections: bool | None
+    ocr_languages: list[str] | None
     # output_format: Optional[str] = "application/json"
-    overlap_all: Optional[bool] = None
-    pdf_infer_table_structure: Optional[bool] = None
-
-    similarity_threshold: Optional[float] = None
-    skip_infer_table_types: Optional[list[str]] = None
-    split_pdf_concurrency_level: Optional[int] = None
-    split_pdf_page: Optional[bool] = None
-    starting_page_number: Optional[int] = None
-    strategy: Optional[str] = None
-    chunking_strategy: Optional[str | ChunkingStrategy] = None
-    unique_element_ids: Optional[bool] = None
-    xml_keep_tags: Optional[bool] = None
+    overlap_all: bool | None
+    pdf_infer_table_structure: bool | None
+
+    similarity_threshold: float | None
+    skip_infer_table_types: list[str] | None
+    split_pdf_concurrency_level: int | None
+    split_pdf_page: bool | None
+    starting_page_number: int | None
+    strategy: str | None
+    chunking_strategy: str | ChunkingStrategy | None  # type: ignore
+    unique_element_ids: bool | None
+    xml_keep_tags: bool | None
 
     def to_ingestion_request(self):
@@ -204,7 +204,7 @@ async def parse_fallback(
         parser_name: str,
     ) -> AsyncGenerator[FallbackElement, None]:
         contents = []
-        async for chunk in self.parsers[parser_name].ingest(
+        async for chunk in self.parsers[parser_name].ingest(  # type: ignore
            file_content, **ingestion_config
         ):  # type: ignore
             if isinstance(chunk, dict) and chunk.get("content"):
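
Note: `Optional[X]` and `X | None` (PEP 604) are the same type, so most of this rewrite is notational. One behavioral detail worth flagging: under Pydantic, a `bool | None` field with no `= None` default is required rather than defaulting to None, so dropping the defaults is a semantic change unless defaults are supplied elsewhere. A minimal illustration under Pydantic v2 (standalone model, not the real config):

    from pydantic import BaseModel, ValidationError

    class Sketch(BaseModel):
        required_flag: bool | None  # no default: must be passed, even as None
        optional_flag: bool | None = None  # defaults to None when omitted

    Sketch(required_flag=None)  # fine, explicitly provided

    try:
        Sketch()  # required_flag missing
    except ValidationError as exc:
        print(exc.errors()[0]["type"])  # "missing"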
11 changes: 6 additions & 5 deletions py/core/providers/orchestration/hatchet.py
@@ -1,3 +1,4 @@
+# FIXME: Once the Hatchet workflows are type annotated, remove the type: ignore comments
 import asyncio
 import logging
 from typing import Any, Callable, Optional
@@ -39,7 +40,7 @@ def failure(self, *args, **kwargs) -> Callable:
     def get_worker(self, name: str, max_runs: Optional[int] = None) -> Any:
         if not max_runs:
             max_runs = self.config.max_runs
-        self.worker = self.orchestrator.worker(name, max_runs)
+        self.worker = self.orchestrator.worker(name, max_runs)  # type: ignore
         return self.worker
 
     def concurrency(self, *args, **kwargs) -> Callable:
@@ -61,10 +62,10 @@ async def run_workflow(
         *args,
         **kwargs,
     ) -> Any:
-        task_id = self.orchestrator.admin.run_workflow(
+        task_id = self.orchestrator.admin.run_workflow(  # type: ignore
             workflow_name,
             parameters,
-            options=options,
+            options=options,  # type: ignore
             *args,
             **kwargs,
         )
@@ -84,7 +85,7 @@ def register_workflows(
             f"Registering workflows for {workflow} with messages {messages}."
         )
         if workflow == Workflow.INGESTION:
-            from core.main.orchestration.hatchet.ingestion_workflow import (
+            from core.main.orchestration.hatchet.ingestion_workflow import (  # type: ignore
                 hatchet_ingestion_factory,
             )
@@ -94,7 +95,7 @@
             self.worker.register_workflow(workflow)
 
         elif workflow == Workflow.GRAPH:
-            from core.main.orchestration.hatchet.graph_workflow import (
+            from core.main.orchestration.hatchet.graph_workflow import (  # type: ignore
                 hatchet_graph_search_results_factory,
             )
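
Note: unlike the module-wide markers earlier in this commit, these `# type: ignore` comments are line-scoped: each suppresses errors only on its own line, here where the Hatchet SDK objects are too loosely typed for mypy to see the attributes being used. A sketch of the scoping (stand-in class; the real object comes from the SDK):

    class Orchestrator:
        pass  # stand-in with no declared .worker attribute

    def get_worker(orchestrator: Orchestrator, name: str, max_runs: int):
        # mypy: "Orchestrator" has no attribute "worker"  [attr-defined]
        # The trailing comment exempts exactly this line; naming the error
        # code in brackets keeps it from hiding unrelated problems.
        return orchestrator.worker(name, max_runs)  # type: ignore[attr-defined]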
(Diffs for the remaining 2 changed files are not shown.)
