7
7
from tavily import TavilyClient
8
8
from urllib3 .util import parse_url
9
9
10
+ tool_name = os .getenv ("TAVILY_TOOL_NAME" , "Tavily" )
11
+ if len (sys .argv ) > 1 :
12
+ tool_name = sys .argv [1 ]
13
+
10
14
11
15
def main ():
12
16
if len (sys .argv ) < 2 :
13
17
print ("Usage: python main.py [search | extract]" )
14
18
sys .exit (1 )
15
19
16
20
command = sys .argv [1 ]
17
- client = TavilyClient () # env TAVILY_API_KEY required
18
21
19
22
match command :
20
- case "search" | "safe-search" :
23
+ case "site-search-context" :
24
+ site_search_context ()
25
+ sys .exit (0 )
26
+ case "search" | "site-search" :
27
+ client = TavilyClient () # env TAVILY_API_KEY required
21
28
query = os .getenv ("QUERY" , "" ).strip ()
22
29
if not query :
23
30
print ("No search query provided" )
24
31
sys .exit (1 )
25
32
26
- domains_str = os .getenv ("INCLUDE_DOMAINS" , "" )
27
- include_domains = [
28
- domain .strip () for domain in domains_str .split ("," ) if domain .strip ()
29
- ]
30
-
31
- # safe-search is a special case where we only allow certain domains
33
+ # site-search is a special case where we only allow certain domains
32
34
# this is a different command so that we can use the same code for different tool implementations
33
- if command == "safe-search" :
34
- include_domains = check_allowed_include_domains (include_domains )
35
-
36
- max_results = 10 # broader search if general,
35
+ if command == "site-search" :
36
+ include_domains = get_allowed_domains_or_fail ()
37
+ else :
38
+ domains_str = os .getenv ("INCLUDE_DOMAINS" , "" )
39
+ include_domains = [
40
+ domain .strip () for domain in domains_str .split ("," ) if domain .strip ()
41
+ ]
42
+
43
+ max_results = 5 # broader search if general,
37
44
if len (include_domains ) > 0 :
38
45
max_results = 3 * len (
39
46
include_domains
@@ -51,6 +58,7 @@ def main():
51
58
include_domains = include_domains ,
52
59
)
53
60
case "extract" :
61
+ client = TavilyClient () # env TAVILY_API_KEY required
54
62
url = parse_url (os .getenv ("URL" ).strip ())
55
63
56
64
# default to https:// if no scheme is provided
@@ -76,43 +84,33 @@ def main():
76
84
# print the response as a valid json object
77
85
print (json .dumps (response ))
78
86
79
-
80
- def check_allowed_include_domains (include_domains : List [str ]) -> List [str ]:
81
- # TAVILY_ALLOWED_DOMAINS has the TAVILY_ prefix as it will be set by Obot directly in the env,
82
- # while e.g. INCLUDE_DOMAINS is a tool parameter
83
- allowed_domains_str = os .getenv ("TAVILY_ALLOWED_DOMAINS" , "" )
84
- allowed_domains = [
85
- domain .strip () for domain in allowed_domains_str .split ("," ) if domain .strip ()
86
- ]
87
-
88
- if len (allowed_domains ) == 0 :
89
- print ("No allowed domains provided" )
87
+ def site_search_context ():
88
+ print (f"""WEBSITE KNOWLEDGE:
89
+ Use the { tool_name } website knowledge tool to search the following"
90
+ configured domains:
91
+ """ )
92
+ config = json .loads (os .getenv ("OBOT_WEBSITE_KNOWLEDGE" , "{}" ))
93
+ for site_def in config .get ("sites" , []):
94
+ site = site_def .get ("site" , "" )
95
+ description = site_def .get ("description" , "" )
96
+ if site :
97
+ print (f"DOMAIN: { site } \n " )
98
+ if description :
99
+ print (f"DESCRIPTION: { description } \n " )
100
+ print (f"""END WEBSITE KNOWLEDGE
101
+ """ )
102
+
103
+ def get_allowed_domains_or_fail () -> List [str ]:
104
+ result = []
105
+ config = json .loads (os .getenv ("OBOT_WEBSITE_KNOWLEDGE" , "{}" ))
106
+ for site_def in config .get ("sites" , []):
107
+ site = site_def .get ("site" , "" )
108
+ if site :
109
+ result .append (site )
110
+ if len (result ) == 0 :
111
+ logging .error ("No allowed domains found in OBOT_WEBSITE_KNOWLEDGE" )
90
112
sys .exit (1 )
91
-
92
- # allow not setting INCLUDE_DOMAINS - fallback to all allowed domains
93
- if len (include_domains ) == 0 :
94
- return allowed_domains
95
-
96
- allowed_include_domains = []
97
- disallowed_include_domains = []
98
-
99
- for domain in include_domains :
100
- if domain in allowed_domains :
101
- allowed_include_domains .append (domain )
102
- else :
103
- disallowed_include_domains .append (domain )
104
-
105
- if len (disallowed_include_domains ) > 0 :
106
- if os .getenv ("TAVILY_ALLOWED_DOMAINS_STRICT" , "" ).lower () == "true" :
107
- print (
108
- f"Tried to access domains { disallowed_include_domains } which are not listed in allowed domains { allowed_domains } "
109
- )
110
- sys .exit (1 )
111
- logging .warning (
112
- f"Filtered out { disallowed_include_domains } as they are not listed in allowed domains { allowed_domains } . Continuing with { allowed_include_domains } "
113
- )
114
- include_domains = allowed_include_domains
115
- return include_domains
113
+ return result
116
114
117
115
118
116
if __name__ == "__main__" :
0 commit comments