Skip to content

Commit

Permalink
Fix Crypto Providers in Factory (SciPhi-AI#1727)
Browse files Browse the repository at this point in the history
* Fix crypto provider factory implementation

* Fix initialization in bcrypt
  • Loading branch information
NolanTrem authored Dec 24, 2024
1 parent e0e0eaa commit 0ddfd13
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 48 deletions.
88 changes: 47 additions & 41 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from typing import Any, Optional, Union
from typing import Any, Optional

from core.agent import R2RRAGAgent, R2RStreamingRAGAgent
from core.base import (
Expand All @@ -25,8 +25,10 @@

logger = logging.getLogger()
from core.database import PostgresDatabaseProvider
from core.providers import ( # PostgresDatabaseProvider,
from core.providers import (
AsyncSMTPEmailProvider,
BcryptCryptoConfig,
BCryptCryptoProvider,
ConsoleMockEmailProvider,
HatchetOrchestrationProvider,
LiteLLMCompletionProvider,
Expand All @@ -53,16 +55,16 @@ def __init__(self, config: R2RConfig):
@staticmethod
async def create_auth_provider(
auth_config: AuthConfig,
crypto_provider: NaClCryptoProvider,
crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
database_provider: PostgresDatabaseProvider,
email_provider: Union[
AsyncSMTPEmailProvider,
ConsoleMockEmailProvider,
SendGridEmailProvider,
],
email_provider: (
AsyncSMTPEmailProvider
| ConsoleMockEmailProvider
| SendGridEmailProvider
),
*args,
**kwargs,
) -> Union[R2RAuthProvider, SupabaseAuthProvider]:
) -> R2RAuthProvider | SupabaseAuthProvider:
if auth_config.provider == "r2r":

r2r_auth = R2RAuthProvider(
Expand All @@ -82,9 +84,15 @@ async def create_auth_provider(
@staticmethod
def create_crypto_provider(
crypto_config: CryptoConfig, *args, **kwargs
) -> NaClCryptoProvider:
) -> BCryptCryptoProvider | NaClCryptoProvider:
if crypto_config.provider == "bcrypt":
return NaClCryptoProvider(NaClCryptoConfig(**crypto_config.dict()))
return BCryptCryptoProvider(
BcryptCryptoConfig(**crypto_config.model_dump())
)
if crypto_config.provider == "nacl":
return NaClCryptoProvider(
NaClCryptoConfig(**crypto_config.model_dump())
)
else:
raise ValueError(
f"Crypto provider {crypto_config.provider} not supported."
Expand All @@ -94,12 +102,10 @@ def create_crypto_provider(
def create_ingestion_provider(
ingestion_config: IngestionConfig,
database_provider: PostgresDatabaseProvider,
llm_provider: Union[
LiteLLMCompletionProvider, OpenAICompletionProvider
],
llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
*args,
**kwargs,
) -> Union[R2RIngestionProvider, UnstructuredIngestionProvider]:
) -> R2RIngestionProvider | UnstructuredIngestionProvider:

config_dict = (
ingestion_config.model_dump()
Expand Down Expand Up @@ -135,7 +141,7 @@ def create_ingestion_provider(
@staticmethod
def create_orchestration_provider(
config: OrchestrationConfig, *args, **kwargs
) -> Union[HatchetOrchestrationProvider, SimpleOrchestrationProvider]:
) -> HatchetOrchestrationProvider | SimpleOrchestrationProvider:
if config.provider == "hatchet":
orchestration_provider = HatchetOrchestrationProvider(config)
orchestration_provider.get_worker("r2r-worker")
Expand All @@ -152,7 +158,7 @@ def create_orchestration_provider(
async def create_database_provider(
self,
db_config: DatabaseConfig,
crypto_provider: NaClCryptoProvider,
crypto_provider: BCryptCryptoProvider | NaClCryptoProvider,
*args,
**kwargs,
) -> PostgresDatabaseProvider:
Expand Down Expand Up @@ -184,11 +190,11 @@ async def create_database_provider(
@staticmethod
def create_embedding_provider(
embedding: EmbeddingConfig, *args, **kwargs
) -> Union[
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
OpenAIEmbeddingProvider,
]:
) -> (
LiteLLMEmbeddingProvider
| OllamaEmbeddingProvider
| OpenAIEmbeddingProvider
):
embedding_provider: Optional[EmbeddingProvider] = None

if embedding.provider == "openai":
Expand Down Expand Up @@ -220,7 +226,7 @@ def create_embedding_provider(
@staticmethod
def create_llm_provider(
llm_config: CompletionConfig, *args, **kwargs
) -> Union[LiteLLMCompletionProvider, OpenAICompletionProvider]:
) -> LiteLLMCompletionProvider | OpenAICompletionProvider:
llm_provider: Optional[CompletionProvider] = None
if llm_config.provider == "openai":
llm_provider = OpenAICompletionProvider(llm_config)
Expand All @@ -237,13 +243,15 @@ def create_llm_provider(
@staticmethod
async def create_email_provider(
email_config: Optional[EmailConfig] = None, *args, **kwargs
) -> Union[
AsyncSMTPEmailProvider, ConsoleMockEmailProvider, SendGridEmailProvider
]:
) -> (
AsyncSMTPEmailProvider
| ConsoleMockEmailProvider
| SendGridEmailProvider
):
"""Creates an email provider based on configuration."""
if not email_config:
raise ValueError(
f"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
"No email configuration provided for email provider, please add `[email]` to your `r2r.toml`."
)

if email_config.provider == "smtp":
Expand All @@ -260,29 +268,27 @@ async def create_email_provider(
async def create_providers(
self,
auth_provider_override: Optional[
Union[R2RAuthProvider, SupabaseAuthProvider]
R2RAuthProvider | SupabaseAuthProvider
] = None,
crypto_provider_override: Optional[
BCryptCryptoProvider | NaClCryptoProvider
] = None,
crypto_provider_override: Optional[NaClCryptoProvider] = None,
database_provider_override: Optional[PostgresDatabaseProvider] = None,
email_provider_override: Optional[
Union[
AsyncSMTPEmailProvider,
ConsoleMockEmailProvider,
SendGridEmailProvider,
]
AsyncSMTPEmailProvider
| ConsoleMockEmailProvider
| SendGridEmailProvider
] = None,
embedding_provider_override: Optional[
Union[
LiteLLMEmbeddingProvider,
OpenAIEmbeddingProvider,
OllamaEmbeddingProvider,
]
LiteLLMEmbeddingProvider
| OpenAIEmbeddingProvider
| OllamaEmbeddingProvider
] = None,
ingestion_provider_override: Optional[
Union[R2RIngestionProvider, UnstructuredIngestionProvider]
R2RIngestionProvider | UnstructuredIngestionProvider
] = None,
llm_provider_override: Optional[
Union[OpenAICompletionProvider, LiteLLMCompletionProvider]
OpenAICompletionProvider | LiteLLMCompletionProvider
] = None,
orchestration_provider_override: Optional[Any] = None,
*args,
Expand Down
35 changes: 28 additions & 7 deletions py/core/providers/crypto/bcrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from core.base import CryptoConfig, CryptoProvider

DEFAULT_BCRYPT_SECRET_KEY = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM" # Replace or load from env or secrets manager


class BcryptCryptoConfig(CryptoConfig):
provider: str = "bcrypt"
Expand All @@ -29,11 +31,20 @@ def validate_config(self) -> None:
super().validate_config()
if self.provider not in self.supported_providers:
raise ValueError(f"Unsupported crypto provider: {self.provider}")
if not self.secret_key:
# In production, fail here if no secret key is provided.
raise ValueError(
"No secret key provided for BcryptCryptoProvider."
)
if self.bcrypt_rounds < 4 or self.bcrypt_rounds > 31:
raise ValueError("bcrypt_rounds must be between 4 and 31")

def verify_password(
self, plain_password: str, hashed_password: str
) -> bool:
try:
# First try to decode as base64 (new format)
stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
except:
# If that fails, treat as raw bcrypt hash (old format)
stored_hash = hashed_password.encode("utf-8")

return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)


class BCryptCryptoProvider(CryptoProvider, ABC):
Expand All @@ -47,7 +58,11 @@ def __init__(self, config: BcryptCryptoConfig):

# Load the secret key for JWT
# No fallback defaults: fail if not provided
self.secret_key = self.config.secret_key
self.secret_key = (
config.secret_key
or os.getenv("R2R_SECRET_KEY")
or DEFAULT_BCRYPT_SECRET_KEY
)
if not self.secret_key:
raise ValueError(
"No secret key provided for BcryptCryptoProvider."
Expand All @@ -64,7 +79,13 @@ def get_password_hash(self, password: str) -> str:
def verify_password(
self, plain_password: str, hashed_password: str
) -> bool:
stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
try:
# First try to decode as base64 (new format)
stored_hash = base64.b64decode(hashed_password.encode("utf-8"))
except:
# If that fails, treat as raw bcrypt hash (old format)
stored_hash = hashed_password.encode("utf-8")

return bcrypt.checkpw(plain_password.encode("utf-8"), stored_hash)

def generate_verification_code(self, length: int = 32) -> str:
Expand Down

0 comments on commit 0ddfd13

Please sign in to comment.