Skip to content

Commit 8562b73

Browse files
committed
Search. Part - 1. Tavily, You.com, SerpAPI.
1 parent 80890d9 commit 8562b73

7 files changed

+734
-1
lines changed

aisuite/search_client.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from .search_provider import SearchProviderFactory
2+
3+
4+
class SearchClient:
    """Entry point that routes search queries to per-provider instances.

    Provider instances are built through ``SearchProviderFactory`` and kept in
    ``self.providers`` keyed by provider name.
    """

    def __init__(self, provider_configs: dict = None):
        """
        Initialize the search client with provider configurations.
        Use the SearchProviderFactory to create provider instances.

        Args:
            provider_configs (dict, optional): A dictionary containing provider configurations.
                Each key should be a provider string (e.g., "tavily" or "serp"),
                and the value should be a dictionary of keyword arguments passed
                to that provider's constructor.
                For example:
                {
                    "tavily": {"api_key": "your_tavily_api_key"},
                    "serp": {"api_key": "your_serpapi_api_key"}
                }
                Defaults to no configured providers.

        Raises:
            ValueError: If a key in provider_configs is not a supported provider.
        """
        self.providers = {}
        # Take a copy: configure() mutates this mapping in place, so storing the
        # argument itself (or a shared default dict) would leak configuration
        # between clients and back into the caller's dictionary.
        self.provider_configs = dict(provider_configs) if provider_configs else {}
        self._initialize_providers()

    def _initialize_providers(self):
        """Helper method to initialize or update providers."""
        for provider_key, config in self.provider_configs.items():
            self._validate_provider_key(provider_key)
            self.providers[provider_key] = SearchProviderFactory.create_provider(
                provider_key, config
            )

    def _validate_provider_key(self, provider_key):
        """
        Validate if the provider key corresponds to a supported provider.

        Raises:
            ValueError: If provider_key is not in the set reported by
                SearchProviderFactory.get_supported_providers().
        """
        supported_providers = SearchProviderFactory.get_supported_providers()

        if provider_key not in supported_providers:
            raise ValueError(
                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
                "Make sure the provider string is formatted correctly as 'provider' or 'provider:function'."
            )

    def configure(self, provider_configs: dict = None):
        """
        Configure the client with provider configurations.

        The given entries are merged into the existing configuration and every
        configured provider is then (re)created. Passing None is a no-op.
        """
        if provider_configs is None:
            return

        self.provider_configs.update(provider_configs)
        self._initialize_providers()

    def search(self, provider: str, query: str, **kwargs):
        """
        Perform a search using the specified provider and query.

        Args:
            provider (str): The provider to use, can be in format "provider" or "provider:specific_function"
            query (str): The search query string
            **kwargs: Additional arguments to pass to the search provider

        Returns:
            List[SearchResult]: A list of search results from the provider

        Raises:
            ValueError: If the provider is not supported or could not be loaded.

        Examples:
            >>> client.search("tavily", "who is messi")
            >>> client.search("serp:youtube", "who is messi")
        """
        # Without a ":" the function part is the empty string, which providers
        # treat the same as "use the default search function".
        provider_key, _, specific_function = provider.partition(":")

        # Validate if the provider is supported
        self._validate_provider_key(provider_key)

        # Initialize provider if not already initialized
        if provider_key not in self.providers:
            config = self.provider_configs.get(provider_key, {})
            self.providers[provider_key] = SearchProviderFactory.create_provider(
                provider_key, config
            )

        provider_instance = self.providers.get(provider_key)
        if not provider_instance:
            raise ValueError(f"Could not load provider for '{provider_key}'.")

        # Perform the search using the provider
        return provider_instance.search(
            query, specific_function=specific_function, **kwargs
        )

aisuite/search_provider.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import List, Dict, TypedDict
2+
from abc import ABC, abstractmethod
3+
from pathlib import Path
4+
import importlib
5+
import functools
6+
from dataclasses import dataclass
7+
8+
9+
@dataclass
class SearchResult:
    """One normalized hit produced by a search provider."""

    # Headline of the hit.
    title: str
    # Link to the hit.
    url: str
    # Text summary or snippet for the hit.
    content: str
    # Identifier of the provider (and engine, if any) that produced the hit.
    source: str
17+
18+
19+
class SearchProviderInterface(ABC):
    """Abstract base class that concrete search providers subclass."""

    @abstractmethod
    def search(
        self, query: str, specific_function: str = None, **kwargs
    ) -> "List[SearchResult]":
        """
        Search method that returns a list of search results.

        Args:
            query (str): The search query string
            specific_function (str, optional): Specific search function to use (e.g., 'get_news', 'web_search').
                If None, uses the provider's default search function.
            **kwargs: Additional provider-specific options. Declared here so the
                abstract signature matches how implementations are invoked.

        Returns:
            List[SearchResult]: A list of SearchResult objects, each carrying
                'title', 'content', 'url', and 'source' fields
        """
        pass
35+
36+
37+
class SearchProviderFactory:
    """Factory to dynamically load search provider instances based on naming conventions."""

    # Directory scanned for "<provider>_search_provider.py" modules.
    PROVIDERS_DIR = Path(__file__).parent / "search_providers"

    @classmethod
    def create_provider(
        cls, provider_key: str, config: Dict
    ) -> "SearchProviderInterface":
        """
        Dynamically load and create an instance of a search provider.

        Args:
            provider_key (str): Provider name; the module
                "aisuite.search_providers.<provider_key>_search_provider" is
                imported and the class "<Provider_key>SearchProvider" is looked
                up in it (name capitalized with str.capitalize()).
            config (Dict): Keyword arguments passed to the provider constructor.

        Returns:
            SearchProviderInterface: The newly constructed provider instance.

        Raises:
            ImportError: If the provider module cannot be imported. The original
                import failure is attached as the exception's cause.
            AttributeError: If the module does not define the expected class.
        """
        # Convert provider_key to the expected module and class names
        provider_class_name = f"{provider_key.capitalize()}SearchProvider"
        provider_module_name = f"{provider_key}_search_provider"

        module_path = f"aisuite.search_providers.{provider_module_name}"

        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            # Chain explicitly so the underlying failure (for example a missing
            # third-party dependency of the provider) stays in the traceback.
            raise ImportError(
                f"Could not import module {module_path}: {str(e)}. "
                "Please ensure the search provider is supported by calling SearchProviderFactory.get_supported_providers()"
            ) from e

        # Instantiate the provider class
        provider_class = getattr(module, provider_class_name)
        return provider_class(**config)

    @classmethod
    @functools.cache
    def get_supported_providers(cls) -> set[str]:
        """
        List all supported search provider names based on files present in the search_providers directory.

        The result is cached per class, so files added after the first call are
        not picked up.
        """
        provider_files = Path(cls.PROVIDERS_DIR).glob("*_search_provider.py")
        # Strip only the trailing suffix; str.replace would also remove the
        # marker from the middle of a provider name.
        return {file.stem.removesuffix("_search_provider") for file in provider_files}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
from typing import List
3+
from serpapi import (
4+
GoogleSearch,
5+
) # Despite the name, this handles multiple search engines
6+
from aisuite.search_provider import SearchProviderInterface, SearchResult
7+
8+
9+
class SerpSearchProvider(SearchProviderInterface):
    """SerpAPI implementation of the SearchProviderInterface supporting Google, Bing, YouTube and other search engines."""

    def __init__(self, api_key: str = None):
        """
        Initialize the SerpAPI provider with an API key.

        Args:
            api_key (str, optional): The API key for accessing SerpAPI.
                If not provided, will try to read from SERP_API_KEY environment variable.

        Raises:
            ValueError: If no key is passed and SERP_API_KEY is unset or empty.
        """
        self.api_key = api_key or os.getenv("SERP_API_KEY")
        if not self.api_key:
            raise ValueError(
                "SerpAPI key is required. Either pass it to the constructor or set SERP_API_KEY environment variable."
            )

    def youtube_search(self, query: str, **kwargs) -> List[SearchResult]:
        """
        Perform a YouTube search using the SerpAPI.

        Args:
            query (str): The search query
            **kwargs: Additional parameters specific to YouTube search (e.g., type='video')

        Returns:
            List[SearchResult]: One result per entry under "video_results" in the
                response; empty when that key is absent.
        """
        params = {"q": query, "api_key": self.api_key, "engine": "youtube", **kwargs}

        search = GoogleSearch(params)
        results = search.get_dict()

        # Handle YouTube video results
        video_results = results.get("video_results", [])

        return [
            SearchResult(
                title=result.get("title", ""),
                url=result.get("link", ""),
                # "channel" may be present but null; fall back to an empty
                # mapping so the name lookup cannot raise AttributeError.
                content=f"Duration: {result.get('duration', 'N/A')} | "
                f"Views: {result.get('views', 'N/A')} | "
                f"Channel: {(result.get('channel') or {}).get('name', 'N/A')} | "
                f"Description: {result.get('description', '')}",
                source="serpapi:youtube",
            )
            for result in video_results
        ]

    def search(
        self, query: str, specific_function: str = None, **kwargs
    ) -> List[SearchResult]:
        """
        Perform a search using the SerpAPI.
        Note: Although we use GoogleSearch class, it supports multiple engines including Google, Bing,
        Baidu, Yahoo, YouTube and others through the 'engine' parameter.

        Args:
            query (str): The search query
            specific_function (str, optional): Search engine to use ('google', 'bing', 'youtube', etc.).
                Defaults to Google.
            **kwargs: Additional parameters to pass to the API (e.g., num=10, location="Austin, TX")

        Returns:
            List[SearchResult]: One result per entry under "organic_results" in
                the response (or the output of youtube_search for 'youtube');
                empty when the response has no such entries.
        """
        # Imported locally so this block does not depend on a module-level import.
        import logging

        # Diagnostics go through logging rather than stdout so that library
        # users control whether (and where) they appear.
        logging.getLogger(__name__).debug(
            "Calling SerpSearchProvider.search with query: %s, specific_function: %s, kwargs: %s",
            query,
            specific_function,
            kwargs,
        )

        # Handle YouTube searches separately due to different result structure
        if specific_function == "youtube":
            return self.youtube_search(query, **kwargs)

        params = {
            "q": query,
            "api_key": self.api_key,
            "engine": "google",  # Default to Google search
            **kwargs,
        }

        # Override engine if specific_function is provided
        if specific_function:
            params["engine"] = specific_function

        # Perform the search (GoogleSearch class handles all engine types)
        search = GoogleSearch(params)
        results = search.get_dict()

        # Handle organic search results
        organic_results = results.get("organic_results", [])

        return [
            SearchResult(
                title=result.get("title", ""),
                url=result.get("link", ""),
                content=result.get("snippet", ""),
                source=f"serpapi:{params['engine']}",
            )
            for result in organic_results
        ]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
from typing import List
3+
from tavily import TavilyClient
4+
from aisuite.search_provider import SearchProviderInterface, SearchResult
5+
6+
7+
class TavilySearchProvider(SearchProviderInterface):
    """Search provider backed by the Tavily client."""

    def __init__(self, api_key: str = None):
        """
        Build the provider and its underlying Tavily client.

        Args:
            api_key (str, optional): Tavily API key. When omitted (or empty),
                the TAVILY_API_KEY environment variable is used instead.

        Raises:
            ValueError: If no key is available from either source.
        """
        resolved_key = api_key or os.getenv("TAVILY_API_KEY")
        self.api_key = resolved_key
        if not resolved_key:
            raise ValueError(
                "Tavily API key is required. Either pass it to the constructor or set TAVILY_API_KEY environment variable."
            )
        self.client = TavilyClient(api_key=resolved_key)

    def search(
        self, query: str, specific_function: str = None, **kwargs
    ) -> List[SearchResult]:
        """
        Run a query against Tavily and normalize the hits.

        Args:
            query (str): The search query
            specific_function (str, optional): Accepted for interface
                compatibility; ignored by this provider
            **kwargs: Extra options forwarded unchanged to the Tavily client

        Returns:
            List[SearchResult]: One entry per item under "results" in the
                response; empty when that key is absent
        """
        payload = self.client.search(query, **kwargs)

        hits = []
        for item in payload.get("results", []):
            hits.append(
                SearchResult(
                    title=item.get("title", ""),
                    url=item.get("url", ""),
                    content=item.get("content", ""),
                    source="tavily",
                )
            )
        return hits

0 commit comments

Comments
 (0)