-
Notifications
You must be signed in to change notification settings - Fork 901
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Search. Part - 1. Tavily, You.com, SerpAPI.
- Loading branch information
1 parent
80890d9
commit 9ebee40
Showing
7 changed files
with
744 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from .search_provider import SearchProviderFactory | ||
|
||
|
||
class SearchClient: | ||
def __init__(self, provider_configs: dict = {}): | ||
""" | ||
Initialize the search client with provider configurations. | ||
Use the SearchProviderFactory to create provider instances. | ||
Args: | ||
provider_configs (dict): A dictionary containing provider configurations. | ||
Each key should be a provider string (e.g., "you.com" or "google"), | ||
and the value should be a dictionary of configuration options for that provider. | ||
For example: | ||
{ | ||
"you.com": {"api_key": "your_youcom_api_key"}, | ||
"google": { | ||
"api_key": "your_google_api_key", | ||
"cx": "your_google_cx" | ||
} | ||
} | ||
""" | ||
self.providers = {} | ||
self.provider_configs = provider_configs | ||
self._initialize_providers() | ||
|
||
def _initialize_providers(self): | ||
"""Helper method to initialize or update providers.""" | ||
for provider_key, config in self.provider_configs.items(): | ||
self._validate_provider_key(provider_key) | ||
self.providers[provider_key] = SearchProviderFactory.create_provider( | ||
provider_key, config | ||
) | ||
|
||
def _validate_provider_key(self, provider_key): | ||
""" | ||
Validate if the provider key corresponds to a supported provider. | ||
""" | ||
supported_providers = SearchProviderFactory.get_supported_providers() | ||
|
||
if provider_key not in supported_providers: | ||
raise ValueError( | ||
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. " | ||
"Make sure the provider string is formatted correctly as 'provider' or 'provider:function'." | ||
) | ||
|
||
def configure(self, provider_configs: dict = None): | ||
""" | ||
Configure the client with provider configurations. | ||
""" | ||
if provider_configs is None: | ||
return | ||
|
||
self.provider_configs.update(provider_configs) | ||
self._initialize_providers() | ||
|
||
def search(self, provider: str, query: str, **kwargs): | ||
""" | ||
Perform a search using the specified provider and query. | ||
Args: | ||
provider (str): The provider to use, can be in format "provider" or "provider:specific_function" | ||
query (str): The search query string | ||
**kwargs: Additional arguments to pass to the search provider | ||
Returns: | ||
List[SearchResult]: A list of search results from the provider | ||
Examples: | ||
>>> client.search("you.com", "who is messi") | ||
>>> client.search("you.com:get_news", "who is messi") | ||
""" | ||
provider_key, _, specific_function = provider.partition(":") | ||
|
||
# Validate if the provider is supported | ||
self._validate_provider_key(provider_key) | ||
|
||
# Initialize provider if not already initialized | ||
if provider_key not in self.providers: | ||
config = self.provider_configs.get(provider_key, {}) | ||
self.providers[provider_key] = SearchProviderFactory.create_provider( | ||
provider_key, config | ||
) | ||
|
||
provider_instance = self.providers.get(provider_key) | ||
if not provider_instance: | ||
raise ValueError(f"Could not load provider for '{provider_key}'.") | ||
|
||
# Perform the search using the provider | ||
return provider_instance.search( | ||
query, specific_function=specific_function, **kwargs | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from typing import List, Dict, TypedDict | ||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
import importlib | ||
import functools | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class SearchResult: | ||
"""Data class to store search results.""" | ||
|
||
title: str | ||
url: str | ||
content: str | ||
source: str | ||
|
||
|
||
class SearchProviderInterface(ABC): | ||
@abstractmethod | ||
def search(self, query: str, specific_function: str = None) -> List[SearchResult]: | ||
""" | ||
Search method that returns a list of search results. | ||
Args: | ||
query (str): The search query string | ||
specific_function (str, optional): Specific search function to use (e.g., 'get_news', 'web_search'). | ||
If None, uses the provider's default search function. | ||
Returns: | ||
List[SearchResult]: A list of dictionaries containing search results, | ||
where each dictionary has 'title', 'content', 'url', and 'source' fields | ||
""" | ||
pass | ||
|
||
|
||
class SearchProviderFactory: | ||
"""Factory to dynamically load search provider instances based on naming conventions.""" | ||
|
||
PROVIDERS_DIR = Path(__file__).parent / "search_providers" | ||
|
||
@classmethod | ||
def create_provider( | ||
cls, provider_key: str, config: Dict | ||
) -> SearchProviderInterface: | ||
"""Dynamically load and create an instance of a search provider.""" | ||
# Convert provider_key to the expected module and class names | ||
provider_class_name = f"{provider_key.capitalize()}SearchProvider" | ||
provider_module_name = f"{provider_key}_search_provider" | ||
|
||
module_path = f"aisuite.search_providers.{provider_module_name}" | ||
|
||
try: | ||
module = importlib.import_module(module_path) | ||
except ImportError as e: | ||
raise ImportError( | ||
f"Could not import module {module_path}: {str(e)}. " | ||
"Please ensure the search provider is supported by calling SearchProviderFactory.get_supported_providers()" | ||
) | ||
|
||
# Instantiate the provider class | ||
provider_class = getattr(module, provider_class_name) | ||
return provider_class(**config) | ||
|
||
@classmethod | ||
@functools.cache | ||
def get_supported_providers(cls) -> set[str]: | ||
"""List all supported search provider names based on files present in the search_providers directory.""" | ||
provider_files = Path(cls.PROVIDERS_DIR).glob("*_search_provider.py") | ||
return {file.stem.replace("_search_provider", "") for file in provider_files} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import os | ||
from typing import List | ||
from serpapi import ( | ||
GoogleSearch, | ||
) # Despite the name, this handles multiple search engines | ||
from aisuite.search_provider import SearchProviderInterface, SearchResult | ||
|
||
|
||
class SerpSearchProvider(SearchProviderInterface): | ||
"""SerpAPI implementation of the SearchProviderInterface supporting Google, Bing, YouTube and other search engines.""" | ||
|
||
def __init__(self, api_key: str = None): | ||
""" | ||
Initialize the SerpAPI provider with an API key. | ||
Args: | ||
api_key (str, optional): The API key for accessing SerpAPI. | ||
If not provided, will try to read from SERPAPI_API_KEY environment variable. | ||
""" | ||
self.api_key = api_key or os.getenv("SERP_API_KEY") | ||
if not self.api_key: | ||
raise ValueError( | ||
"SerpAPI key is required. Either pass it to the constructor or set SERP_API_KEY environment variable." | ||
) | ||
|
||
def youtube_search(self, query: str, **kwargs) -> List[SearchResult]: | ||
""" | ||
Perform a YouTube search using the SerpAPI. | ||
Args: | ||
query (str): The search query | ||
**kwargs: Additional parameters specific to YouTube search (e.g., type='video') | ||
Returns: | ||
List[SearchResult]: List of search results | ||
""" | ||
params = {"q": query, "api_key": self.api_key, "engine": "youtube", **kwargs} | ||
|
||
search = GoogleSearch(params) | ||
results = search.get_dict() | ||
|
||
# Handle YouTube video results | ||
video_results = results.get("video_results", []) | ||
|
||
return [ | ||
SearchResult( | ||
title=result.get("title", ""), | ||
url=result.get("link", ""), | ||
content=f"Duration: {result.get('duration', 'N/A')} | " | ||
f"Views: {result.get('views', 'N/A')} | " | ||
f"Channel: {result.get('channel', {}).get('name', 'N/A')} | " | ||
f"Description: {result.get('description', '')}", | ||
source="serpapi:youtube", | ||
) | ||
for result in video_results | ||
] | ||
|
||
def search( | ||
self, query: str, specific_function: str = None, **kwargs | ||
) -> List[SearchResult]: | ||
""" | ||
Perform a search using the SerpAPI. | ||
Note: Although we use GoogleSearch class, it supports multiple engines including Google, Bing, | ||
Baidu, Yahoo, YouTube and others through the 'engine' parameter. | ||
Args: | ||
query (str): The search query | ||
specific_function (str, optional): Search engine to use ('google', 'bing', 'youtube', etc.). | ||
Defaults to Google. | ||
**kwargs: Additional parameters to pass to the API (e.g., num=10, location="Austin, TX") | ||
Returns: | ||
List[SearchResult]: List of search results | ||
""" | ||
print( | ||
f"Calling SerpSearchProvider.search with query: {query}, specific_function: {specific_function}, kwargs: {kwargs}" | ||
) | ||
# Handle YouTube searches separately due to different result structure | ||
if specific_function == "youtube": | ||
return self.youtube_search(query, **kwargs) | ||
|
||
params = { | ||
"q": query, | ||
"api_key": self.api_key, | ||
"engine": "google", # Default to Google search | ||
**kwargs, | ||
} | ||
|
||
# Override engine if specific_function is provided | ||
if specific_function: | ||
params["engine"] = specific_function | ||
|
||
# Perform the search (GoogleSearch class handles all engine types) | ||
search = GoogleSearch(params) | ||
results = search.get_dict() | ||
|
||
# Handle organic search results | ||
organic_results = results.get("organic_results", []) | ||
|
||
return [ | ||
SearchResult( | ||
title=result.get("title", ""), | ||
url=result.get("link", ""), | ||
content=result.get("snippet", ""), | ||
source=f"serpapi:{params['engine']}", | ||
) | ||
for result in organic_results | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
from typing import List | ||
from tavily import TavilyClient | ||
from aisuite.search_provider import SearchProviderInterface, SearchResult | ||
|
||
|
||
class TavilySearchProvider(SearchProviderInterface): | ||
"""Tavily implementation of the SearchProviderInterface.""" | ||
|
||
def __init__(self, api_key: str = None): | ||
""" | ||
Initialize the Tavily provider with an API key. | ||
Args: | ||
api_key (str, optional): The API key for accessing Tavily API. | ||
If not provided, will try to read from TAVILY_API_KEY environment variable. | ||
""" | ||
self.api_key = api_key or os.getenv("TAVILY_API_KEY") | ||
if not self.api_key: | ||
raise ValueError( | ||
"Tavily API key is required. Either pass it to the constructor or set TAVILY_API_KEY environment variable." | ||
) | ||
self.client = TavilyClient(api_key=self.api_key) | ||
|
||
def search( | ||
self, query: str, specific_function: str = None, **kwargs | ||
) -> List[SearchResult]: | ||
""" | ||
Perform a search using the Tavily API. | ||
Args: | ||
query (str): The search query | ||
specific_function (str, optional): Not used for Tavily | ||
**kwargs: Additional parameters to pass to the API | ||
Returns: | ||
List[SearchResult]: List of search results | ||
""" | ||
response = self.client.search(query, **kwargs) | ||
|
||
return [ | ||
SearchResult( | ||
title=result.get("title", ""), | ||
url=result.get("url", ""), | ||
content=result.get("content", ""), | ||
source="tavily", | ||
) | ||
for result in response.get("results", []) | ||
] |
Oops, something went wrong.