diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index 6f146c627..ae92a5ef4 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Optional, Union +from typing import Any, Optional from core.agent import R2RRAGAgent, R2RStreamingRAGAgent from core.base import ( @@ -25,8 +25,10 @@ logger = logging.getLogger() from core.database import PostgresDatabaseProvider -from core.providers import ( # PostgresDatabaseProvider, +from core.providers import ( AsyncSMTPEmailProvider, + BcryptCryptoConfig, + BCryptCryptoProvider, ConsoleMockEmailProvider, HatchetOrchestrationProvider, LiteLLMCompletionProvider, @@ -53,16 +55,16 @@ def __init__(self, config: R2RConfig): @staticmethod async def create_auth_provider( auth_config: AuthConfig, - crypto_provider: NaClCryptoProvider, + crypto_provider: BCryptCryptoProvider | NaClCryptoProvider, database_provider: PostgresDatabaseProvider, - email_provider: Union[ - AsyncSMTPEmailProvider, - ConsoleMockEmailProvider, - SendGridEmailProvider, - ], + email_provider: ( + AsyncSMTPEmailProvider + | ConsoleMockEmailProvider + | SendGridEmailProvider + ), *args, **kwargs, - ) -> Union[R2RAuthProvider, SupabaseAuthProvider]: + ) -> R2RAuthProvider | SupabaseAuthProvider: if auth_config.provider == "r2r": r2r_auth = R2RAuthProvider( @@ -82,9 +84,15 @@ async def create_auth_provider( @staticmethod def create_crypto_provider( crypto_config: CryptoConfig, *args, **kwargs - ) -> NaClCryptoProvider: + ) -> BCryptCryptoProvider | NaClCryptoProvider: if crypto_config.provider == "bcrypt": - return NaClCryptoProvider(NaClCryptoConfig(**crypto_config.dict())) + return BCryptCryptoProvider( + BcryptCryptoConfig(**crypto_config.model_dump()) + ) + if crypto_config.provider == "nacl": + return NaClCryptoProvider( + NaClCryptoConfig(**crypto_config.model_dump()) + ) else: raise ValueError( f"Crypto provider {crypto_config.provider} not supported." @@ -94,12 +102,10 @@ def create_crypto_provider( def create_ingestion_provider( ingestion_config: IngestionConfig, database_provider: PostgresDatabaseProvider, - llm_provider: Union[ - LiteLLMCompletionProvider, OpenAICompletionProvider - ], + llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider, *args, **kwargs, - ) -> Union[R2RIngestionProvider, UnstructuredIngestionProvider]: + ) -> R2RIngestionProvider | UnstructuredIngestionProvider: config_dict = ( ingestion_config.model_dump() @@ -135,7 +141,7 @@ def create_ingestion_provider( @staticmethod def create_orchestration_provider( config: OrchestrationConfig, *args, **kwargs - ) -> Union[HatchetOrchestrationProvider, SimpleOrchestrationProvider]: + ) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider: if config.provider == "hatchet": orchestration_provider = HatchetOrchestrationProvider(config) orchestration_provider.get_worker("r2r-worker") @@ -152,7 +158,7 @@ def create_orchestration_provider( async def create_database_provider( self, db_config: DatabaseConfig, - crypto_provider: NaClCryptoProvider, + crypto_provider: BCryptCryptoProvider | NaClCryptoProvider, *args, **kwargs, ) -> PostgresDatabaseProvider: @@ -184,11 +190,11 @@ async def create_database_provider( @staticmethod def create_embedding_provider( embedding: EmbeddingConfig, *args, **kwargs - ) -> Union[ - LiteLLMEmbeddingProvider, - OllamaEmbeddingProvider, - OpenAIEmbeddingProvider, - ]: + ) -> ( + LiteLLMEmbeddingProvider + | OllamaEmbeddingProvider + | OpenAIEmbeddingProvider + ): embedding_provider: Optional[EmbeddingProvider] = None if embedding.provider == "openai": @@ -220,7 +226,7 @@ def create_embedding_provider( @staticmethod def create_llm_provider( llm_config: CompletionConfig, *args, **kwargs - ) -> Union[LiteLLMCompletionProvider, OpenAICompletionProvider]: + ) -> LiteLLMCompletionProvider | OpenAICompletionProvider: llm_provider: Optional[CompletionProvider] = None if llm_config.provider == "openai": llm_provider = OpenAICompletionProvider(llm_config) @@ -237,13 +243,15 @@ def create_llm_provider( @staticmethod async def create_email_provider( email_config: Optional[EmailConfig] = None, *args, **kwargs - ) -> Union[ - AsyncSMTPEmailProvider, ConsoleMockEmailProvider, SendGridEmailProvider - ]: + ) -> ( + AsyncSMTPEmailProvider + | ConsoleMockEmailProvider + | SendGridEmailProvider + ): """Creates an email provider based on configuration.""" if not email_config: raise ValueError( - f"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." + "No email configuration provided for email provider, please add `[email]` to your `r2r.toml`." ) if email_config.provider == "smtp": @@ -260,29 +268,27 @@ async def create_email_provider( async def create_providers( self, auth_provider_override: Optional[ - Union[R2RAuthProvider, SupabaseAuthProvider] + R2RAuthProvider | SupabaseAuthProvider + ] = None, + crypto_provider_override: Optional[ + BCryptCryptoProvider | NaClCryptoProvider ] = None, - crypto_provider_override: Optional[NaClCryptoProvider] = None, database_provider_override: Optional[PostgresDatabaseProvider] = None, email_provider_override: Optional[ - Union[ - AsyncSMTPEmailProvider, - ConsoleMockEmailProvider, - SendGridEmailProvider, - ] + AsyncSMTPEmailProvider + | ConsoleMockEmailProvider + | SendGridEmailProvider ] = None, embedding_provider_override: Optional[ - Union[ - LiteLLMEmbeddingProvider, - OpenAIEmbeddingProvider, - OllamaEmbeddingProvider, - ] + LiteLLMEmbeddingProvider + | OpenAIEmbeddingProvider + | OllamaEmbeddingProvider ] = None, ingestion_provider_override: Optional[ - Union[R2RIngestionProvider, UnstructuredIngestionProvider] + R2RIngestionProvider | UnstructuredIngestionProvider ] = None, llm_provider_override: Optional[ - Union[OpenAICompletionProvider, LiteLLMCompletionProvider] + OpenAICompletionProvider | LiteLLMCompletionProvider ] = None, orchestration_provider_override: Optional[Any] = None, *args, diff --git a/py/core/providers/crypto/bcrypt.py b/py/core/providers/crypto/bcrypt.py index 02cc3f47e..dccafe3e3 100644 --- a/py/core/providers/crypto/bcrypt.py +++ b/py/core/providers/crypto/bcrypt.py @@ -13,6 +13,8 @@ from core.base import CryptoConfig, CryptoProvider +DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Replace or load from env or secrets manager + class BcryptCryptoConfig(CryptoConfig): provider: str = "bcrypt" @@ -29,11 +31,20 @@ def validate_config(self) -> None: super().validate_config() if self.provider not in self.supported_providers: raise ValueError(f"Unsupported crypto provider: {self.provider}") - if not self.secret_key: - # In production, fail here if no secret key is provided. - raise ValueError( - "No secret key provided for BcryptCryptoProvider." - ) + if self.bcrypt_rounds < 4 or self.bcrypt_rounds > 31: + raise ValueError("bcrypt_rounds must be between 4 and 31") + + def verify_password( + self, plain_password: str, hashed_password: str + ) -> bool: + try: + # First try to decode as base64 (new format) + stored_hash = base64.b64decode(hashed_password.encode("utf-8")) + except: + # If that fails, treat as raw bcrypt hash (old format) + stored_hash = hashed_password.encode("utf-8") + + return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash) class BCryptCryptoProvider(CryptoProvider, ABC): @@ -47,7 +58,11 @@ def __init__(self, config: BcryptCryptoConfig): # Load the secret key for JWT # No fallback defaults: fail if not provided - self.secret_key = self.config.secret_key + self.secret_key = ( + config.secret_key + or os.getenv("R2R_SECRET_KEY") + or DEFAULT_BCRYPT_SECRET_KEY + ) if not self.secret_key: raise ValueError( "No secret key provided for BcryptCryptoProvider." @@ -64,7 +79,13 @@ def get_password_hash(self, password: str) -> str: def verify_password( self, plain_password: str, hashed_password: str ) -> bool: - stored_hash = base64.b64decode(hashed_password.encode("utf-8")) + try: + # First try to decode as base64 (new format) + stored_hash = base64.b64decode(hashed_password.encode("utf-8")) + except: + # If that fails, treat as raw bcrypt hash (old format) + stored_hash = hashed_password.encode("utf-8") + return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash) def generate_verification_code(self, length: int = 32) -> str: