Skip to content

Commit

Permalink
Sendgrid Email Provider Implementation (SciPhi-AI#1614) (SciPhi-AI#1618)
Browse files Browse the repository at this point in the history
* +sendgrid email provider

* Update py/tests/core/providers/email/test_email_providers.py

The template_id parameter shown here is an example and is not intended to represent actual data. I included it as a placeholder.



* Update py/tests/core/providers/email/test_email_providers.py

The template_id parameter shown here is an example and is not intended to represent actual data. I included it as a placeholder.



---------

Co-authored-by: logerzerox <[email protected]>
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 20, 2024
1 parent b963121 commit 5c99a84
Show file tree
Hide file tree
Showing 12 changed files with 376 additions and 47 deletions.
38 changes: 22 additions & 16 deletions py/core/base/providers/email.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# email_provider.py
import logging
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Dict
import os

from .base import Provider, ProviderConfig

Expand All @@ -13,27 +14,22 @@ class EmailConfig(ProviderConfig):
smtp_password: Optional[str] = None
from_email: Optional[str] = None
use_tls: Optional[bool] = True

sendgrid_api_key: Optional[str] = None
verify_email_template_id: Optional[str] = None
reset_password_template_id: Optional[str] = None
frontend_url: Optional[str] = None
@property
def supported_providers(self) -> list[str]:
return [
"smtp",
"console",
"sendgrid",
] # Could add more providers like AWS SES, SendGrid etc.

def validate_config(self) -> None:
pass
# if self.provider == "smtp":
# if not all(
# [
# self.smtp_server,
# self.smtp_port,
# self.smtp_username,
# self.smtp_password,
# self.from_email,
# ]
# ):
# raise ValueError("SMTP configuration is incomplete")
if self.provider == "sendgrid":
if not (self.sendgrid_api_key or os.getenv("SENDGRID_API_KEY")):
raise ValueError("SendGrid API key is required when using SendGrid provider")


logger = logging.getLogger(__name__)
Expand All @@ -55,17 +51,27 @@ async def send_email(
subject: str,
body: str,
html_body: Optional[str] = None,
*args,
**kwargs
) -> None:
pass

@abstractmethod
async def send_verification_email(
self, to_email: str, verification_code: str
self,
to_email: str,
verification_code: str,
*args,
**kwargs
) -> None:
pass

@abstractmethod
async def send_password_reset_email(
self, to_email: str, reset_token: str
self,
to_email: str,
reset_token: str,
*args,
**kwargs
) -> None:
pass
3 changes: 2 additions & 1 deletion py/core/main/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
SqlitePersistentLoggingProvider,
SupabaseAuthProvider,
UnstructuredIngestionProvider,
SendGridEmailProvider,
)


Expand All @@ -38,7 +39,7 @@ class R2RProviders(BaseModel):
HatchetOrchestrationProvider, SimpleOrchestrationProvider
]
logging: SqlitePersistentLoggingProvider
email: Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider]
email: Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider, SendGridEmailProvider]

class Config:
arbitrary_types_allowed = True
Expand Down
10 changes: 7 additions & 3 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe
from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider

from core.providers.email.sendgrid import SendGridEmailProvider

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig

Expand Down Expand Up @@ -56,7 +58,7 @@ async def create_auth_provider(
crypto_provider: BCryptProvider,
database_provider: PostgresDBProvider,
email_provider: Union[
AsyncSMTPEmailProvider, ConsoleMockEmailProvider
AsyncSMTPEmailProvider, ConsoleMockEmailProvider, SendGridEmailProvider
],
*args,
**kwargs,
Expand Down Expand Up @@ -235,7 +237,7 @@ def create_llm_provider(
@staticmethod
async def create_email_provider(
email_config: Optional[EmailConfig] = None, *args, **kwargs
) -> Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider]:
) -> Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider, SendGridEmailProvider]:
"""Creates an email provider based on configuration."""
if not email_config:
raise ValueError(
Expand All @@ -246,6 +248,8 @@ async def create_email_provider(
return AsyncSMTPEmailProvider(email_config)
elif email_config.provider == "console_mock":
return ConsoleMockEmailProvider(email_config)
elif email_config.provider == "sendgrid":
return SendGridEmailProvider(email_config)
else:
raise ValueError(
f"Email provider {email_config.provider} not supported."
Expand All @@ -259,7 +263,7 @@ async def create_providers(
crypto_provider_override: Optional[BCryptProvider] = None,
database_provider_override: Optional[PostgresDBProvider] = None,
email_provider_override: Optional[
Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider]
Union[AsyncSMTPEmailProvider, ConsoleMockEmailProvider, SendGridEmailProvider]
] = None,
embedding_provider_override: Optional[
Union[
Expand Down
7 changes: 6 additions & 1 deletion py/core/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from .auth import R2RAuthProvider, SupabaseAuthProvider
from .crypto import BCryptConfig, BCryptProvider
from .database import PostgresDBProvider
from .email import AsyncSMTPEmailProvider, ConsoleMockEmailProvider
from .email import (
AsyncSMTPEmailProvider,
ConsoleMockEmailProvider,
SendGridEmailProvider,
)
from .embeddings import (
LiteLLMEmbeddingProvider,
OllamaEmbeddingProvider,
Expand Down Expand Up @@ -41,6 +45,7 @@
# Email
"AsyncSMTPEmailProvider",
"ConsoleMockEmailProvider",
"SendGridEmailProvider",
# Orchestration
"HatchetOrchestrationProvider",
"SimpleOrchestrationProvider",
Expand Down
29 changes: 10 additions & 19 deletions py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ async def register(self, email: str, password: str) -> UserResponse:
)

if self.config.require_email_verification:
# Generate verification code and send email
verification_code = (
self.crypto_provider.generate_verification_code()
)
Expand All @@ -157,10 +156,12 @@ async def register(self, email: str, password: str) -> UserResponse:
new_user.id, verification_code, expiry
)
new_user.verification_code_expiry = expiry
# TODO - Integrate email provider(s)

# Safely get first name, defaulting to email if name is None
first_name = new_user.name.split(" ")[0] if new_user.name else email.split("@")[0]

await self.email_provider.send_verification_email(
new_user.email, verification_code
new_user.email, verification_code, {"first_name": first_name}
)
else:
expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10)
Expand Down Expand Up @@ -307,8 +308,9 @@ async def request_password_reset(self, email: str) -> dict[str, str]:
user.id, reset_token, expiry
)

# TODO: Integrate with email provider to send reset link
await self.email_provider.send_password_reset_email(email, reset_token)
# Safely get first name, defaulting to email if name is None
first_name = user.name.split(" ")[0] if user.name else email.split("@")[0]
await self.email_provider.send_password_reset_email(email, reset_token, {"first_name": first_name})

return {"message": "If the email exists, a reset link has been sent"}

Expand Down Expand Up @@ -341,19 +343,6 @@ async def clean_expired_blacklisted_tokens(self):
await self.database_provider.clean_expired_blacklisted_tokens()

async def send_reset_email(self, email: str) -> dict:
"""
Generate a new verification code and send a reset email to the user.
Returns the verification code for testing/sandbox environments.
Args:
email (str): The email address of the user
Returns:
dict: Contains verification_code and message
Raises:
R2RException: If user is not found
"""
user = await self.database_provider.get_user_by_email(email)
if not user:
raise R2RException(status_code=404, message="User not found")
Expand All @@ -369,9 +358,11 @@ async def send_reset_email(self, email: str) -> dict:
expiry,
)

# Safely get first name, defaulting to email if name is None
first_name = user.name.split(" ")[0] if user.name else email.split("@")[0]
# Send verification email
await self.email_provider.send_verification_email(
email, verification_code
email, verification_code, {"first_name": first_name}
)

return {
Expand Down
7 changes: 6 additions & 1 deletion py/core/providers/email/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from .console_mock import ConsoleMockEmailProvider
from .smtp import AsyncSMTPEmailProvider
from .sendgrid import SendGridEmailProvider

__all__ = ["ConsoleMockEmailProvider", "AsyncSMTPEmailProvider"]
__all__ = [
"ConsoleMockEmailProvider",
"AsyncSMTPEmailProvider",
"SendGridEmailProvider",
]
14 changes: 12 additions & 2 deletions py/core/providers/email/console_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ async def send_email(
subject: str,
body: str,
html_body: Optional[str] = None,
*args,
**kwargs
) -> None:
logger.info(
f"""
Expand All @@ -28,7 +30,11 @@ async def send_email(
)

async def send_verification_email(
self, to_email: str, verification_code: str
self,
to_email: str,
verification_code: str,
*args,
**kwargs
) -> None:
logger.info(
f"""
Expand All @@ -42,7 +48,11 @@ async def send_verification_email(
)

async def send_password_reset_email(
self, to_email: str, reset_token: str
self,
to_email: str,
reset_token: str,
*args,
**kwargs
) -> None:
logger.info(
f"""
Expand Down
Loading

0 comments on commit 5c99a84

Please sign in to comment.