Skip to content

Commit

Permalink
Search. Part - 1. Tavily, You.com, SerpAPI.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitprasad15 committed Dec 27, 2024
1 parent 80890d9 commit 9ebee40
Show file tree
Hide file tree
Showing 7 changed files with 744 additions and 1 deletion.
92 changes: 92 additions & 0 deletions aisuite/search_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from .search_provider import SearchProviderFactory


class SearchClient:
    """Route search queries to provider instances built by SearchProviderFactory."""

    def __init__(self, provider_configs: dict = None):
        """
        Initialize the search client with provider configurations.
        Use the SearchProviderFactory to create provider instances.
        Args:
            provider_configs (dict, optional): A dictionary containing provider configurations.
                Each key should be a provider string (e.g., "tavily" or "serp"),
                and the value should be a dictionary of constructor options for that provider.
                For example:
                {
                    "tavily": {"api_key": "your_tavily_api_key"},
                    "serp": {"api_key": "your_serpapi_api_key"}
                }
                Defaults to no pre-configured providers.
        Raises:
            ValueError: If any key is not a supported provider.
        """
        self.providers = {}
        # Take a private copy: configure() mutates this mapping, and sharing a
        # default dict (or the caller's dict) would leak configuration between
        # client instances.
        self.provider_configs = dict(provider_configs) if provider_configs else {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Helper method to initialize or update providers."""
        for provider_key, config in self.provider_configs.items():
            self._validate_provider_key(provider_key)
            self.providers[provider_key] = SearchProviderFactory.create_provider(
                provider_key, config
            )

    def _validate_provider_key(self, provider_key):
        """
        Validate if the provider key corresponds to a supported provider.
        Raises:
            ValueError: If the key is not reported by
                SearchProviderFactory.get_supported_providers().
        """
        supported_providers = SearchProviderFactory.get_supported_providers()

        if provider_key not in supported_providers:
            raise ValueError(
                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
                "Make sure the provider string is formatted correctly as 'provider' or 'provider:function'."
            )

    def configure(self, provider_configs: dict = None):
        """
        Configure the client with provider configurations.
        The given configurations are merged into the existing ones, and every
        configured provider is then (re)created. Passing None is a no-op.
        """
        if provider_configs is None:
            return

        self.provider_configs.update(provider_configs)
        self._initialize_providers()

    def search(self, provider: str, query: str, **kwargs):
        """
        Perform a search using the specified provider and query.
        Args:
            provider (str): The provider to use, can be in format "provider" or "provider:specific_function"
            query (str): The search query string
            **kwargs: Additional arguments to pass to the search provider
        Returns:
            List[SearchResult]: A list of search results from the provider
        Raises:
            ValueError: If the provider is unsupported or could not be loaded.
        Examples:
            >>> client.search("tavily", "who is messi")
            >>> client.search("serp:youtube", "who is messi")
        """
        provider_key, _, specific_function = provider.partition(":")

        # Validate if the provider is supported
        self._validate_provider_key(provider_key)

        # Initialize provider lazily if it was not configured up front
        if provider_key not in self.providers:
            config = self.provider_configs.get(provider_key, {})
            self.providers[provider_key] = SearchProviderFactory.create_provider(
                provider_key, config
            )

        provider_instance = self.providers.get(provider_key)
        if not provider_instance:
            raise ValueError(f"Could not load provider for '{provider_key}'.")

        # partition() yields "" when there is no ":function" suffix; pass None
        # instead so providers see the interface's documented default.
        return provider_instance.search(
            query, specific_function=specific_function or None, **kwargs
        )
70 changes: 70 additions & 0 deletions aisuite/search_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import List, Dict, TypedDict
from abc import ABC, abstractmethod
from pathlib import Path
import importlib
import functools
from dataclasses import dataclass


@dataclass
class SearchResult:
    """
    One normalized search hit, as returned by every search provider.
    Attributes:
        title: Title of the result.
        url: Link to the result.
        content: Text snippet or summary describing the result.
        source: Identifier of the provider that produced the result
            (e.g., "tavily" or "serpapi:google").
    """

    title: str
    url: str
    content: str
    source: str


class SearchProviderInterface(ABC):
    """Abstract base class that every search provider must implement."""

    @abstractmethod
    def search(
        self, query: str, specific_function: str = None, **kwargs
    ) -> List[SearchResult]:
        """
        Search method that returns a list of search results.
        Args:
            query (str): The search query string
            specific_function (str, optional): Specific search function to use (e.g., 'get_news', 'web_search').
                If None, uses the provider's default search function.
            **kwargs: Additional provider-specific parameters. SearchClient.search
                forwards its extra keyword arguments here, so implementations
                must accept them.
        Returns:
            List[SearchResult]: A list of SearchResult objects, each carrying
                'title', 'url', 'content', and 'source' fields
        """


class SearchProviderFactory:
    """Factory to dynamically load search provider instances based on naming conventions."""

    # Directory scanned for "<provider>_search_provider.py" modules.
    PROVIDERS_DIR = Path(__file__).parent / "search_providers"

    @classmethod
    def create_provider(
        cls, provider_key: str, config: Dict
    ) -> SearchProviderInterface:
        """
        Dynamically load and create an instance of a search provider.
        Args:
            provider_key (str): Provider name; the module
                "aisuite.search_providers.<provider_key>_search_provider" must define a
                class named "<Provider_key capitalized>SearchProvider".
            config (Dict): Keyword arguments passed to the provider's constructor.
                None is treated as an empty configuration.
        Returns:
            SearchProviderInterface: The instantiated provider.
        Raises:
            ImportError: If the provider module cannot be imported.
            AttributeError: If the module does not define the expected class.
        """
        # Convert provider_key to the expected module and class names
        provider_class_name = f"{provider_key.capitalize()}SearchProvider"
        provider_module_name = f"{provider_key}_search_provider"

        module_path = f"aisuite.search_providers.{provider_module_name}"

        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            # Chain explicitly so the original failure (e.g. a missing
            # third-party dependency of the provider) stays attached as the cause.
            raise ImportError(
                f"Could not import module {module_path}: {str(e)}. "
                "Please ensure the search provider is supported by calling SearchProviderFactory.get_supported_providers()"
            ) from e

        try:
            provider_class = getattr(module, provider_class_name)
        except AttributeError as e:
            # Same exception type as before, but naming the convention that was violated.
            raise AttributeError(
                f"Module {module_path} does not define the expected class {provider_class_name}."
            ) from e

        # Instantiate the provider class
        return provider_class(**(config or {}))

    @classmethod
    @functools.cache
    def get_supported_providers(cls) -> set[str]:
        """
        List all supported search provider names based on files present in the search_providers directory.
        The result is cached after the first call, so provider files added
        afterwards are not picked up within the same process.
        """
        provider_files = Path(cls.PROVIDERS_DIR).glob("*_search_provider.py")
        return {file.stem.replace("_search_provider", "") for file in provider_files}
108 changes: 108 additions & 0 deletions aisuite/search_providers/serp_search_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import os
from typing import List
from serpapi import (
GoogleSearch,
) # Despite the name, this handles multiple search engines
from aisuite.search_provider import SearchProviderInterface, SearchResult


class SerpSearchProvider(SearchProviderInterface):
    """SerpAPI implementation of the SearchProviderInterface supporting Google, Bing, YouTube and other search engines."""

    def __init__(self, api_key: str = None):
        """
        Initialize the SerpAPI provider with an API key.
        Args:
            api_key (str, optional): The API key for accessing SerpAPI.
                If not provided, will try to read from the SERP_API_KEY environment
                variable, then from SERPAPI_API_KEY.
        Raises:
            ValueError: If no API key is passed and neither environment variable is set.
        """
        # SERP_API_KEY is the variable the code has always read; SERPAPI_API_KEY
        # is accepted as a fallback because the docstring used to advertise it.
        self.api_key = (
            api_key or os.getenv("SERP_API_KEY") or os.getenv("SERPAPI_API_KEY")
        )
        if not self.api_key:
            raise ValueError(
                "SerpAPI key is required. Either pass it to the constructor or set SERP_API_KEY environment variable."
            )

    def youtube_search(self, query: str, **kwargs) -> List[SearchResult]:
        """
        Perform a YouTube search using the SerpAPI.
        Args:
            query (str): The search query
            **kwargs: Additional parameters specific to YouTube search (e.g., type='video')
        Returns:
            List[SearchResult]: List of search results, with duration, views,
                channel name and description folded into the 'content' field
        """
        params = {"q": query, "api_key": self.api_key, "engine": "youtube", **kwargs}

        search = GoogleSearch(params)
        results = search.get_dict()

        # Handle YouTube video results
        video_results = results.get("video_results", [])

        return [
            SearchResult(
                title=result.get("title", ""),
                url=result.get("link", ""),
                # "channel" may be absent or null; fall back to an empty mapping
                # so the lookup of its name cannot fail.
                content=f"Duration: {result.get('duration', 'N/A')} | "
                f"Views: {result.get('views', 'N/A')} | "
                f"Channel: {(result.get('channel') or {}).get('name', 'N/A')} | "
                f"Description: {result.get('description', '')}",
                source="serpapi:youtube",
            )
            for result in video_results
        ]

    def search(
        self, query: str, specific_function: str = None, **kwargs
    ) -> List[SearchResult]:
        """
        Perform a search using the SerpAPI.
        Note: Although we use GoogleSearch class, it supports multiple engines including Google, Bing,
        Baidu, Yahoo, YouTube and others through the 'engine' parameter.
        Args:
            query (str): The search query
            specific_function (str, optional): Search engine to use ('google', 'bing', 'youtube', etc.).
                Defaults to Google.
            **kwargs: Additional parameters to pass to the API (e.g., num=10, location="Austin, TX")
        Returns:
            List[SearchResult]: List of search results
        """
        # Local import keeps this block self-contained; a library should log
        # diagnostics rather than print to stdout on every search.
        import logging

        logging.getLogger(__name__).debug(
            "Calling SerpSearchProvider.search with query: %s, specific_function: %s, kwargs: %s",
            query,
            specific_function,
            kwargs,
        )
        # Handle YouTube searches separately due to different result structure
        if specific_function == "youtube":
            return self.youtube_search(query, **kwargs)

        params = {
            "q": query,
            "api_key": self.api_key,
            "engine": "google",  # Default to Google search
            **kwargs,
        }

        # Override engine if specific_function is provided
        if specific_function:
            params["engine"] = specific_function

        # Perform the search (GoogleSearch class handles all engine types)
        search = GoogleSearch(params)
        results = search.get_dict()

        # Handle organic search results
        organic_results = results.get("organic_results", [])

        return [
            SearchResult(
                title=result.get("title", ""),
                url=result.get("link", ""),
                content=result.get("snippet", ""),
                source=f"serpapi:{params['engine']}",
            )
            for result in organic_results
        ]
49 changes: 49 additions & 0 deletions aisuite/search_providers/tavily_search_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
from typing import List
from tavily import TavilyClient
from aisuite.search_provider import SearchProviderInterface, SearchResult


class TavilySearchProvider(SearchProviderInterface):
    """Search provider backed by the Tavily API."""

    def __init__(self, api_key: str = None):
        """
        Create the provider and its underlying TavilyClient.
        Args:
            api_key (str, optional): Tavily API key. When omitted, the
                TAVILY_API_KEY environment variable is used instead.
        Raises:
            ValueError: If no key is given and TAVILY_API_KEY is not set.
        """
        resolved_key = api_key if api_key else os.getenv("TAVILY_API_KEY")
        self.api_key = resolved_key
        if not resolved_key:
            raise ValueError(
                "Tavily API key is required. Either pass it to the constructor or set TAVILY_API_KEY environment variable."
            )
        self.client = TavilyClient(api_key=resolved_key)

    def search(
        self, query: str, specific_function: str = None, **kwargs
    ) -> List[SearchResult]:
        """
        Run a query against the Tavily API and normalize the hits.
        Args:
            query (str): The search query
            specific_function (str, optional): Accepted for interface
                compatibility; ignored by this provider
            **kwargs: Extra parameters forwarded to the Tavily client's search call
        Returns:
            List[SearchResult]: One entry per item in the response's "results" list
        """
        payload = self.client.search(query, **kwargs)

        hits = []
        for item in payload.get("results", []):
            hits.append(
                SearchResult(
                    title=item.get("title", ""),
                    url=item.get("url", ""),
                    content=item.get("content", ""),
                    source="tavily",
                )
            )
        return hits
Loading

0 comments on commit 9ebee40

Please sign in to comment.